Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,15 @@ def build_module(
is_standalone,
torch_exclude,
hipify=True,
prebuild=0,
):
lock_path = f"{bd_dir}/lock_{md_name}"
startTS = time.perf_counter()
target_name = f"{md_name}.so" if not is_standalone else md_name

def MainFunc():
recopy_ck()
if prebuild != 1:
recopy_ck()
if AITER_REBUILD == 1:
rm_module(md_name)
clear_build(md_name)
Expand All @@ -289,7 +291,12 @@ def MainFunc():
if os.path.exists(f"{get_user_jit_dir()}/{target_name}"):
os.remove(f"{get_user_jit_dir()}/{target_name}")

sources = rename_cpp_to_cu(srcs, src_dir)
if prebuild != 2:
sources = rename_cpp_to_cu(srcs, src_dir)
else:
sources = rename_cpp_to_cu(
[get_user_jit_dir() + "/../../csrc/rocm_ops.cpp"], opbd_dir + "/srcs"
)

flags_cc = ["-O3", "-std=c++20"]
flags_hip = [
Expand Down Expand Up @@ -352,11 +359,12 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
sources += rename_cpp_to_cu([blob_dir], src_dir, recurisve=True)
return sources

if isinstance(blob_gen_cmd, list):
for s_blob_gen_cmd in blob_gen_cmd:
sources = exec_blob(s_blob_gen_cmd, op_dir, src_dir, sources)
else:
sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources)
if prebuild != 2:
if isinstance(blob_gen_cmd, list):
for s_blob_gen_cmd in blob_gen_cmd:
sources = exec_blob(s_blob_gen_cmd, op_dir, src_dir, sources)
else:
sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources)

# TODO: Move all torch api into torch folder
old_bd_include_dir = f"{op_dir}/build/include"
Expand Down Expand Up @@ -393,6 +401,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
is_standalone=is_standalone,
torch_exclude=torch_exclude,
hipify=hipify,
prebuild=prebuild,
)
if is_python_module and not is_standalone:
shutil.copy(f"{opbd_dir}/{target_name}", f"{get_user_jit_dir()}")
Expand Down Expand Up @@ -459,10 +468,10 @@ def convert(d_ops: dict):
with open(this_dir + "/optCompilerConfig.json", "r") as file:
data = json.load(file)
if isinstance(data, dict):
# parse all ops
# parse all ops, return list
if ops_name == "all":
all_ops_list = []
d_all_ops = {
"srcs": [],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_include": [],
Expand All @@ -477,13 +486,22 @@ def convert(d_ops: dict):
if ops_name in exclude:
continue
single_ops = convert(d_ops)
d_single_ops = {
"md_name": ops_name,
"srcs": single_ops["srcs"],
"flags_extra_cc": single_ops["flags_extra_cc"],
"flags_extra_hip": single_ops["flags_extra_hip"],
"extra_include": single_ops["extra_include"],
"blob_gen_cmd": single_ops["blob_gen_cmd"],
}
for k in d_all_ops.keys():
if isinstance(single_ops[k], list):
d_all_ops[k] += single_ops[k]
elif isinstance(single_ops[k], str) and single_ops[k] != "":
d_all_ops[k].append(single_ops[k])
all_ops_list.append(d_single_ops)

return d_all_ops
return all_ops_list, d_all_ops
# ops_name was not found in the json config.
elif data.get(ops_name) == None:
logger.warning(
Expand Down
57 changes: 43 additions & 14 deletions aiter/jit/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ def _jit_compile(
keep_intermediates=True,
torch_exclude=False,
hipify=True,
prebuild=0,
) -> None:
if is_python_module and is_standalone:
raise ValueError(
Expand Down Expand Up @@ -1234,6 +1235,7 @@ def _jit_compile(
is_python_module=is_python_module,
is_standalone=is_standalone,
torch_exclude=torch_exclude,
prebuild=prebuild,
)
elif verbose:
print(
Expand Down Expand Up @@ -1316,6 +1318,7 @@ def _write_ninja_file_and_build_library(
is_python_module: bool,
is_standalone: bool = False,
torch_exclude: bool = False,
prebuild: int = 0,
) -> None:
verify_ninja_availability()

Expand All @@ -1324,7 +1327,7 @@ def _write_ninja_file_and_build_library(
if with_cuda is None:
with_cuda = any(map(_is_cuda_file, sources))
extra_ldflags = _prepare_ldflags(
extra_ldflags or [], with_cuda, verbose, is_standalone, torch_exclude
extra_ldflags or [], with_cuda, verbose, is_standalone, torch_exclude, prebuild
)
build_file_path = os.path.join(build_directory, "build.ninja")
if verbose:
Expand All @@ -1343,6 +1346,7 @@ def _write_ninja_file_and_build_library(
is_python_module=is_python_module,
is_standalone=is_standalone,
torch_exclude=torch_exclude,
prebuild=prebuild,
)

if verbose:
Expand All @@ -1368,7 +1372,9 @@ def verify_ninja_availability():
raise RuntimeError("Ninja is required to load C++ extensions")


def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone, torch_exclude):
def _prepare_ldflags(
extra_ldflags, with_cuda, verbose, is_standalone, torch_exclude, prebuild
):
extra_ldflags.append("-mcmodel=large")
extra_ldflags.append("-ffunction-sections")
extra_ldflags.append("-fdata-sections ")
Expand All @@ -1380,15 +1386,18 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone, torch_exc
_TORCH_PATH = os.path.join(os.path.dirname(torch.__file__))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, "lib")
extra_ldflags.append(f"-L{TORCH_LIB_PATH}")
extra_ldflags.append("-lc10")
if with_cuda:
extra_ldflags.append("-lc10_hip" if IS_HIP_EXTENSION else "-lc10_cuda")
extra_ldflags.append("-ltorch_cpu")
if with_cuda:
extra_ldflags.append("-ltorch_hip" if IS_HIP_EXTENSION else "-ltorch_cuda")
extra_ldflags.append("-ltorch")
if not is_standalone:
extra_ldflags.append("-ltorch_python")
if prebuild != 1:
extra_ldflags.append("-lc10")
if with_cuda:
extra_ldflags.append("-lc10_hip" if IS_HIP_EXTENSION else "-lc10_cuda")
extra_ldflags.append("-ltorch_cpu")
if with_cuda:
extra_ldflags.append(
"-ltorch_hip" if IS_HIP_EXTENSION else "-ltorch_cuda"
)
extra_ldflags.append("-ltorch")
if not is_standalone:
extra_ldflags.append("-ltorch_python")

if is_standalone:
extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")
Expand All @@ -1398,7 +1407,8 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone, torch_exc
print("Detected CUDA files, patching ldflags", file=sys.stderr)

extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
extra_ldflags.append("-lamdhip64")
if prebuild != 1:
extra_ldflags.append("-lamdhip64")
return extra_ldflags


Expand Down Expand Up @@ -1529,6 +1539,7 @@ def _write_ninja_file_to_build_library(
is_python_module,
is_standalone,
torch_exclude,
prebuild=0,
) -> None:
extra_cflags = [flag.strip() for flag in extra_cflags]
extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
Expand Down Expand Up @@ -1561,7 +1572,10 @@ def _write_ninja_file_to_build_library(
user_includes = [os.path.abspath(file) for file in extra_include_paths]

if not torch_exclude:
common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}")
if prebuild == 0:
common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}")
else:
common_cflags.append(f"-DTORCH_EXTENSION_NAME=aiter_")
common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H")
common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
Expand All @@ -1576,6 +1590,8 @@ def _write_ninja_file_to_build_library(
cuda_flags = ["-DWITH_HIP"] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
cuda_flags += extra_cuda_cflags
cuda_flags += _get_rocm_arch_flags(cuda_flags)
if prebuild == 1:
cuda_flags += ["-fvisibility=default -DEXPORT_SYMBOLS"]

def object_file_path(source_file: str) -> str:
# '/path/to/file.cpp' -> 'file'
Expand All @@ -1593,6 +1609,8 @@ def object_file_path(source_file: str) -> str:

ext = EXEC_EXT if is_standalone else LIB_EXT
library_target = f"{name}{ext}"
if prebuild == 2:
library_target = "aiter_.so"

_write_ninja_file(
path=path,
Expand All @@ -1606,6 +1624,7 @@ def object_file_path(source_file: str) -> str:
ldflags=ldflags,
library_target=library_target,
with_cuda=with_cuda,
prebuild=prebuild,
)


Expand All @@ -1621,6 +1640,7 @@ def _write_ninja_file(
ldflags,
library_target,
with_cuda,
prebuild=0,
) -> None:
r"""Write a ninja file that does the desired compiling and linking.

Expand Down Expand Up @@ -1671,7 +1691,6 @@ def sanitize_flags(flags):
flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
flags.append(f'ldflags = {" ".join(ldflags)}')

# Turn into absolute paths so we can emit them into the ninja build
# file wherever it is.
Expand Down Expand Up @@ -1701,7 +1720,17 @@ def sanitize_flags(flags):
source_file = source_file.replace(" ", "$ ")
object_file = object_file.replace(" ", "$ ")
build.append(f"build {object_file}: {rule} {source_file}")
if prebuild == 2:
o_path = path.split("build/aiter_")[0]
ldflags.append(f"-Wl,-rpath={o_path}")

for root, dirs, files in os.walk(o_path):
for file in files:
mid_file_dir = o_path + file
if file.endswith(".so") and mid_file_dir not in objects:
objects.append(mid_file_dir)

flags.append(f'ldflags = {" ".join(ldflags)}')
if cuda_dlink_post_cflags:
devlink_out = os.path.join(os.path.dirname(objects[0]), "dlink.o")
devlink_rule = ["rule cuda_devlink"]
Expand Down
15 changes: 14 additions & 1 deletion csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import dataclass
from aiter.jit.utils.chip_info import get_gfx
import os
import sys

this_dir = os.path.dirname(os.path.abspath(__file__))
# Put the directory that holds `chip_info.py` at the front of sys.path so it can
# be imported as a plain top-level module. Two layouts are probed: when an
# `aiter_meta` directory exists three levels above this file the utils are
# taken from `<that root>/aiter/jit/utils` (pip install mode); otherwise they
# are taken from `aiter/jit/utils` two levels above this file (develop mode).
_install_root = os.path.abspath(f"{this_dir}/../../../")
AITER_CORE_DIR = (
    os.path.join(_install_root, "aiter/jit/utils")  # pip install mode
    if os.path.exists(os.path.join(_install_root, "aiter_meta"))
    else os.path.abspath(f"{this_dir}/../../aiter/jit/utils")  # develop mode
)
sys.path.insert(0, AITER_CORE_DIR)

from chip_info import get_gfx # noqa: E402


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion csrc/rocm_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
NORM_PYBIND;
POS_ENCODING_PYBIND;
ATTENTION_PYBIND;
// MOE_CK_2STAGES_PYBIND;
MOE_CK_2STAGES_PYBIND;
QUANT_PYBIND;
ATTENTION_ASM_PYBIND;
ATTENTION_RAGGED_PYBIND;
Expand Down
68 changes: 58 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BuildExtension,
IS_HIP_EXTENSION,
)
from multiprocessing import Pool

ck_dir = os.environ.get("CK_DIR", f"{this_dir}/3rdparty/composable_kernel")
PACKAGE_NAME = "aiter"
Expand Down Expand Up @@ -45,7 +46,6 @@
exclude_ops = [
"libmha_fwd",
"libmha_bwd",
"module_moe_ck2stages",
"module_fmha_v3_fwd",
"module_mha_fwd",
"module_mha_varlen_fwd",
Expand All @@ -56,25 +56,73 @@
"module_mha_varlen_bwd",
]

all_opts_args_build = core.get_args_of_build("all", exclude=exclude_ops)
# remove pybind, because there are already duplicates in rocm_opt
new_list = [el for el in all_opts_args_build["srcs"] if "pybind.cu" not in el]
all_opts_args_build["srcs"] = new_list
all_opts_args_build, prebuild_link_param = core.get_args_of_build(
    "all", exclude=exclude_ops
)
# Wipe the artifacts of earlier builds before prebuilding. This is done with
# shutil/os rather than `os.system("rm -rf ...")`: the JIT directory path was
# interpolated unquoted into a shell command, which breaks (or deletes the
# wrong thing) when the path contains spaces or shell metacharacters, and the
# command's exit status was silently dropped.
user_jit_dir = core.get_user_jit_dir()
shutil.rmtree(f"{user_jit_dir}/build", ignore_errors=True)
if os.path.isdir(user_jit_dir):
    for entry in os.listdir(user_jit_dir):
        # Mirror the shell glob `*.so`, which does not match dot-files.
        if not entry.endswith(".so") or entry.startswith("."):
            continue
        stale = os.path.join(user_jit_dir, entry)
        if os.path.isdir(stale) and not os.path.islink(stale):
            shutil.rmtree(stale, ignore_errors=True)
        else:
            os.remove(stale)
prebuild_dir = f"{user_jit_dir}/build/aiter_/build"
core.recopy_ck()
# exist_ok guards against a leftover tree that could not be fully removed above.
os.makedirs(prebuild_dir + "/srcs", exist_ok=True)

def build_one_module(one_opt_args):
    """Build a single op module with ``prebuild=1``.

    ``one_opt_args`` is a mapping that must provide the keys ``"md_name"``,
    ``"srcs"``, ``"flags_extra_cc"``, ``"flags_extra_hip"``,
    ``"blob_gen_cmd"`` and ``"extra_include"``. ``-DPREBUILD_KERNELS`` is
    appended to both flag lists (the lists in ``one_opt_args`` are not
    mutated) and everything is forwarded to ``core.build_module``.
    """
    prebuild_define = ["-DPREBUILD_KERNELS"]
    build_kwargs = {
        "md_name": one_opt_args["md_name"],
        "srcs": one_opt_args["srcs"],
        "flags_extra_cc": one_opt_args["flags_extra_cc"] + prebuild_define,
        "flags_extra_hip": one_opt_args["flags_extra_hip"] + prebuild_define,
        "blob_gen_cmd": one_opt_args["blob_gen_cmd"],
        "extra_include": one_opt_args["extra_include"],
        "extra_ldflags": None,
        "verbose": False,
        "is_python_module": True,
        "is_standalone": False,
        "torch_exclude": False,
        "prebuild": 1,
    }
    core.build_module(**build_kwargs)

# step 1, build *.cu -> module*.so
# Use roughly 80% of the CPUs, but never fewer than one worker:
# os.cpu_count() may return None, and on a single-CPU host int(0.8 * 1) is 0,
# which makes Pool raise "Number of processes must be at least 1".
prebuild_workers = max(1, int(0.8 * (os.cpu_count() or 1)))
with Pool(processes=prebuild_workers) as pool:
    pool.map(build_one_module, all_opts_args_build)

# Collect the `include` directory of every sub-directory of csrc whose name
# starts with "ck_batched_gemm" ...
ck_batched_gemm_folders = [
    f"{this_dir}/csrc/{name}/include"
    for name in os.listdir(f"{this_dir}/csrc")
    if os.path.isdir(os.path.join(f"{this_dir}/csrc", name))
    and name.startswith("ck_batched_gemm")
]
# ... and of every sub-directory whose name starts with "ck_gemm_a".
ck_gemm_folders = [
    f"{this_dir}/csrc/{name}/include"
    for name in os.listdir(f"{this_dir}/csrc")
    if os.path.isdir(os.path.join(f"{this_dir}/csrc", name))
    and name.startswith("ck_gemm_a")
]
ck_gemm_inc = ck_batched_gemm_folders + ck_gemm_folders

# Merge all of them into the single prebuild include directory, batched-gemm
# folders first, and finish with the shared csrc/include tree.
dst = f"{prebuild_dir}/include"
for src in ck_gemm_inc:
    shutil.copytree(src, dst, dirs_exist_ok=True)
shutil.copytree(f"{this_dir}/csrc/include", dst, dirs_exist_ok=True)

# step 2, link module*.so -> aiter_.so
core.build_module(
md_name="aiter_",
srcs=all_opts_args_build["srcs"] + [f"{this_dir}/csrc"],
flags_extra_cc=all_opts_args_build["flags_extra_cc"]
srcs=[f"{prebuild_dir}/srcs/rocm_ops.cu"],
flags_extra_cc=prebuild_link_param["flags_extra_cc"]
+ ["-DPREBUILD_KERNELS"],
flags_extra_hip=all_opts_args_build["flags_extra_hip"]
flags_extra_hip=prebuild_link_param["flags_extra_hip"]
+ ["-DPREBUILD_KERNELS"],
blob_gen_cmd=all_opts_args_build["blob_gen_cmd"],
extra_include=all_opts_args_build["extra_include"],
blob_gen_cmd=prebuild_link_param["blob_gen_cmd"],
extra_include=prebuild_link_param["extra_include"],
extra_ldflags=None,
verbose=False,
is_python_module=True,
is_standalone=False,
torch_exclude=False,
prebuild=2,
)
else:
raise NotImplementedError("Only ROCM is supported")
Expand Down