Skip to content

Commit 91d3af2

Browse files
authored
merge updates for Mi300x ops (#267)
* add initial r1 ops for testing * add deepseek_r1_sigmoid_top_k_f32 * add glu_expert_bf16xf8_block_scal * fix glu_expert_bf16xf8_block_scal for cat([gate, up], dim=0) * update cached bf16xf8_block_scal * add gemm_nt_bf16xfp8_block_scal * add modules * enhance glu perf for bs=32 * fuse routed_scaled in topk_gating * using rotary_lookup_bf16 instead of rotary_emb_bf16 * moving head/tail ops to C++ * add partial absorb fusion * using gate_gemm_out_bf16 * add gemm_gate_up_silu_bf16xf8_s_16x16 * use torch::matmul for some gemm * handle scaling at non-ending dim * add glu_expert_bf16xf8_block_scal_16x16 back * fine tune bf16 deviation * add test_allreduce_bf16 * update deepseek_sigmoid_top_8_static_v2 with scaling 2.5 * add glu_expert_bf16xf8_block_scal_16x16_fnuz * restore previous interface to avoid compatibility break * add multi_head_latent_rope_bf16 * add system.from_url() * force deepseek_sigmoid_top_8_static_v2's dtype compatible with sglang * isolate NCCL-dependent APIs with macros
1 parent 490917d commit 91d3af2

30 files changed

Lines changed: 629 additions & 86 deletions

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include tutel/ops/*

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def install(use_cuda, use_nccl):
113113
install_requires=[
114114
"numpy",
115115
],
116+
include_package_data=True,
116117
zip_safe=False,
117118
extras_require={
118119
'test': [

tutel/custom/antares_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ at::Tensor call(const void *key, const std::vector<at::Tensor> &ts, const std::v
181181
static std::unordered_map<std::string, decltype(torch::kInt8)> key_to_dtype = {
182182
{"int8", torch::kInt8}, {"int16", torch::kInt16}, {"int32", torch::kInt32}, {"int64", torch::kInt64},
183183
{"bfloat8", at::kFloat8_e5m2}, {"float8", at::kFloat8_e4m3fn}, {"bfloat16", torch::kBFloat16}, {"float16", torch::kFloat16}, {"float32", torch::kFloat32}, {"float64", torch::kFloat64},
184-
{"bfloat2x16", at::kComplexHalf}, {"float2x16", at::kComplexHalf}, {"float2x32", at::kComplexFloat},
184+
{"bfloat2x16", torch::kInt32}, {"float2x16", torch::kInt32}, {"float2x32", torch::kInt64},
185185
};
186186

187187
auto dtype_it = key_to_dtype.find(o_type[1]);

tutel/custom/custom_kernel.cpp

Lines changed: 540 additions & 84 deletions
Large diffs are not rendered by default.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT license.
4+
5+
import torch
6+
import time
7+
import argparse
8+
9+
from tutel import system, net
10+
11+
parser = argparse.ArgumentParser()
12+
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
13+
parser.add_argument('--count', type=int, default=229376)
14+
parser.add_argument('--loop', type=int, default=50)
15+
parser.add_argument('--warmup', type=int, default=5, help='Number of warmup iterations')
16+
args = parser.parse_args()
17+
18+
parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
19+
local_device = parallel_env.local_device
20+
21+
x = torch.randn([args.count], device=local_device, dtype=torch.float32)
22+
23+
if args.device == 'cuda':
24+
wait = lambda: torch.cuda.synchronize() or time.perf_counter()
25+
else:
26+
wait = lambda: time.perf_counter()
27+
28+
# Run args.warmup + args.loop iterations in a single loop; warmup is not separated out and no timing is taken in this script
29+
with torch.no_grad():
30+
for _ in range(args.warmup + args.loop):
31+
torch.ops.tutel_ops.test_allreduce_bf16(args.count)
32+

tutel/ops/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
import os
5+
import tutel_custom_kernel
6+
7+
if 'OP_LOADER' not in os.environ:
8+
os.environ['OP_LOADER'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.')
75.5 KB
Binary file not shown.
75.6 KB
Binary file not shown.

tutel/ops/fused_silu_mul_bf16.mod

7.73 KB
Binary file not shown.
7.11 KB
Binary file not shown.

0 commit comments

Comments
 (0)