Skip to content

Commit 26f3c3d

Browse files
authored
Merge pull request #268 from microsoft/mi300x
synchronize Mi300x ops for MLA rope
2 parents 91d3af2 + 67eb279 commit 26f3c3d

5 files changed

Lines changed: 41 additions & 45 deletions

File tree

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ We compare three solutions that support <ins>Full-Precision Inference (PPL = 0)
1616

1717
## What's New:
1818

19+
- Tutel v0.4.1: Adds fused MLA support for DeepSeek R1 on AMD MI300x8:
20+
```sh
21+
>> Example:
22+
23+
docker run -it --rm --ipc=host --privileged -p 8000:8000 \
24+
-v /:/host -w /host$(pwd) tutelgroup/deepseek-671b:mi300x8-chat-20250319 \
25+
--model_path ./deepseek-ai/DeepSeek-R1 --prompt "Calculate the result of: 1 / (sqrt(5) - sqrt(3))"
26+
```
27+
1928
- Tutel v0.4.0: Accelerates DeepSeek R1 full-precision chat on AMD MI300x8 (support for more platforms will be added in later versions):
2029
```sh
2130
>> Example:

tutel/custom/custom_kernel.cpp

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,9 +1126,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> warp_multi_head_latent_r
11261126
CHECK_EQ(w_q_b_proj.dim(), 4);
11271127
CHECK_EQ(k_b_proj.dtype(), torch::kBFloat16);
11281128
CHECK_EQ(k_b_proj.dim(), 3);
1129+
CHECK_CONTIGUOUS(k_b_proj.transpose(1, 2));
11291130

1130-
auto qh = torch::matmul(w_q_b_proj, q.view({batch * seqlen, q.size(2)}).t()).view({2, n_local_heads, 128, batch * seqlen});
1131-
auto q_output = torch::matmul(qh[0].transpose(1, 2), k_b_proj);
1131+
auto qh = ((batch * seqlen <= 4) ? \
1132+
antares::ops::call("gmv_bf16", {w_q_b_proj.view({-1, w_q_b_proj.size(-1)}).view(torch::kInt32), q.view({batch * seqlen, q.size(2)}).view(torch::kInt32)}, {}) : \
1133+
torch::matmul(w_q_b_proj, q.view({batch * seqlen, q.size(2)}).t())).view({2, n_local_heads, 128, batch * seqlen});
1134+
auto q_output = (batch * seqlen == 1) ? \
1135+
antares::ops::call("bmv_bf16", {qh[0].squeeze(-1).view(torch::kInt32), k_b_proj.transpose(1, 2).view(torch::kInt32)}, {}).unsqueeze(1) : (batch * seqlen <= 4 ? \
1136+
antares::ops::call("bmm_bf16", {qh[0], k_b_proj.transpose(1, 2)}, {}) : torch::matmul(qh[0].transpose(1, 2), k_b_proj));
11321137

11331138
q_output = antares::ops::call("rope_q_bf16", {q_output.view({n_local_heads, batch * seqlen, 16, 32}), cos_sin, qh[1].view({n_local_heads, 2, 32, 2, batch * seqlen}), positions}, {}).view({batch, seqlen, n_local_heads, 576});
11341139
return {q_output, k_output, kv};
@@ -1171,19 +1176,27 @@ torch::Tensor warp_deepseek_r1_attn_f16xf8_block_scal(
11711176
if (it2 == wkv_b_.end()) {
11721177
auto _ = antares::ops::call("to_bfloat16_3d", {kv_b_proj.unsqueeze(0), kv_b_proj_scal.unsqueeze(0)}, {}).
11731178
view({n_local_heads, 2, -1, kv_b_proj.size(-1)}).permute({1, 0, 2, 3}).contiguous(); // 2, H, 128, 512
1174-
wkv_b_[kv_b_proj_scal.data_ptr()] = {_.select(0, 0), _.select(0, 1).transpose(1, 2)}; // H, D(128), C(512)
1179+
wkv_b_[kv_b_proj_scal.data_ptr()] = {_.select(0, 0).transpose(1, 2).contiguous().transpose(1, 2), _.select(0, 1)}; // H, D(128), C(512)
11751180
it2 = wkv_b_.find(kv_b_proj_scal.data_ptr());
11761181
}
11771182
auto wkc = std::get<0>(it2->second), wvc = std::get<1>(it2->second);
11781183
auto qkv = warp_gemm_nt_bf16xfp8_block_scal(data, qkv_a_proj, qkv_a_proj_scal); // [B, S, 1536 + 512 + 64]
11791184

1180-
auto positions = torch::full({batch}, pos, torch::TensorOptions().dtype(torch::kInt64).device(data.device()));
1185+
static torch::Tensor posperm = torch::arange(0, cos_sin.size(0), torch::TensorOptions().dtype(torch::kInt64).device(data.device()));
1186+
auto positions = batch == 1 ? posperm.narrow(0, pos, 1) : torch::full({batch}, pos, torch::TensorOptions().dtype(torch::kInt64).device(data.device()));
11811187
auto inputs = warp_multi_head_latent_rope_bf16(qkv, cos_sin, positions, q_a_norm, kv_a_norm, it->second, wkc, n_local_heads);
11821188
key_cache.narrow(0, pos, 1).copy_(std::get<1>(inputs).permute({1, 0, 2}));
1189+
11831190
auto Q = std::get<0>(inputs), C = key_cache.narrow(0, 0, pos + seqlen); // S2, B, (512 + 64)
1184-
auto scores = at::einsum("bshc,tbc->bsht", {Q, C}) * 0.1352337788608801f;
1185-
Q = at::einsum("bsht,tbc->bshc", {at::softmax(scores, -1), C}).narrow(-1, 0, 512);
1186-
Q = at::einsum("bshc,hcd->bshd", {Q, wvc}).contiguous();
1191+
if (batch == 1 && seqlen == 1) {
1192+
auto scores = antares::ops::call("scores_bf16", {Q.view({n_heads, Q.size(-1)}).view(torch::kInt32), C.squeeze(1).view(torch::kInt32)}, {0.1352337788608801f});
1193+
Q = torch::matmul(at::softmax(scores, -1), C.squeeze(1));
1194+
Q = antares::ops::call("logits_bf16", {Q.view(torch::kInt32), wvc.view(torch::kInt32)}, {});
1195+
} else {
1196+
auto scores = at::einsum("bshc,tbc->bsht", {Q, C}) * 0.1352337788608801f;
1197+
Q = at::einsum("bsht,tbc->bshc", {at::softmax(scores, -1), C}).narrow(-1, 0, 512);
1198+
Q = at::einsum("bshc,hdc->bshd", {Q, wvc}).contiguous();
1199+
}
11871200
Q = warp_gemm_nt_bf16xfp8_block_scal(Q.view({batch, seqlen, -1}), o_proj, o_proj_scal);
11881201
return Q;
11891202
}
@@ -1200,48 +1213,22 @@ torch::Tensor warp_deepseek_r1_attn_f16xf8_block_scal(
12001213
auto q_pe_out = torch::empty_like(q_pe);
12011214
antares::ops::call("rotary_lookup_bf16", {cos_sin.select(0, pos).select(0, 0), cos_sin.select(0, pos).select(0, 1), q_pe.view({-1, 32, 2}), q_pe_out.view({-1, 2, 32})}, {}, false, 0, 3);
12021215

1203-
if (val_cache.numel() > 1) {
1204-
kv = warp_gemm_nt_bf16xfp8_block_scal(warp_rmsnorm_bf16(kv, kv_a_norm, 1e-6f), kv_b_proj, kv_b_proj_scal);
1205-
1206-
antares::ops::call("cache_fill_bf16", {q_pe_out, k_pe_out, query_states, key_cache.select(0, pos)}, {128}, false, 0, 3);
1216+
kv = warp_gemm_nt_bf16xfp8_block_scal(warp_rmsnorm_bf16(kv, kv_a_norm, 1e-6f), kv_b_proj, kv_b_proj_scal);
1217+
antares::ops::call("cache_fill_bf16", {q_pe_out, k_pe_out, query_states, key_cache.select(0, pos)}, {128}, false, 0, 3);
12071218
// [B,S,H,64] [B,S,64] [B,S,H,128:] [B,H,128:]
12081219

1209-
antares::ops::call("cache_move_bf16", {kv.view({batch, seqlen, n_heads, 2, 128}), key_cache.narrow(0, pos, seqlen), val_cache.narrow(0, pos, seqlen)}, {}, false, 0, 2);
1220+
antares::ops::call("cache_move_bf16", {kv.view({batch, seqlen, n_heads, 2, 128}), key_cache.narrow(0, pos, seqlen), val_cache.narrow(0, pos, seqlen)}, {}, false, 0, 2);
12101221
// [B,S,H,2,M] [S,B,H,:128] [S,B,H,:128]
12111222

1212-
auto key_states = key_cache.narrow(0, 0, pos + seqlen).view({1, pos + seqlen, batch * n_heads, 192});
1213-
auto value_states = val_cache.narrow(0, 0, pos + seqlen).view({1, pos + seqlen, batch * n_heads, 128});
1214-
query_states = query_states.permute({1, 0, 2, 3}).view({1, seqlen, -1, 192});
1215-
CHECK_EQ(query_states.size(1), 1);
1223+
auto key_states = key_cache.narrow(0, 0, pos + seqlen).view({1, pos + seqlen, batch * n_heads, 192});
1224+
auto value_states = val_cache.narrow(0, 0, pos + seqlen).view({1, pos + seqlen, batch * n_heads, 128});
1225+
query_states = query_states.permute({1, 0, 2, 3}).view({1, seqlen, -1, 192});
1226+
CHECK_EQ(query_states.size(1), 1);
12161227

1217-
auto lm = torch::empty({2, batch * n_heads, 64}, torch::TensorOptions().dtype(torch::kBFloat16).device(query_states.device()));
1218-
auto attn_output = antares::ops::call("self_attn_infer_bf16", {query_states.squeeze(0).squeeze(0), key_states.squeeze(0), value_states.squeeze(0), lm}, {0.1352337788608801f});
1219-
xb = torch::matmul(antares::ops::call("self_attn_reduce_bf16", {lm}, {}).unsqueeze(1), attn_output).to(query_states.dtype());
1220-
1221-
// xb = std::get<0>(at::native::_scaled_dot_product_attention_math(query_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), key_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), value_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), {}, 0, false, {}, 0.1352337788608801)).permute({0, 2, 1, 3}).to(query_states.dtype());
1222-
} else {
1223-
kv = torch::cat({warp_rmsnorm_bf16(kv, kv_a_norm, 1e-6f), k_pe_out}, -1); // [B, S, 512]
1224-
key_cache.narrow(0, pos, seqlen).narrow(1, 0, batch) = kv.permute({1, 0, 2}); // [S, B, 512 + 64]
1225-
1226-
static std::unordered_map<void*, torch::Tensor> wkv_b_;
1227-
auto it = wkv_b_.find(kv_b_proj_scal.data_ptr());
1228-
if (it == wkv_b_.end()) {
1229-
wkv_b_[kv_b_proj_scal.data_ptr()] = antares::ops::call("to_bfloat16_3d", {kv_b_proj.unsqueeze(0), kv_b_proj_scal.unsqueeze(0)}, {}).
1230-
view({n_heads, 2, -1, kv_b_proj.size(-1)}).permute({1, 0, 2, 3}).contiguous(); // 2, H, 128, 512
1231-
it = wkv_b_.find(kv_b_proj_scal.data_ptr());
1232-
}
1233-
auto _0 = it->second.select(0, 0), _1 = it->second.select(0, 1); // H, D(128), C(512)
1234-
// k_pe_out, q_pe_out -- 1, 1, 64 | 1, 1, 16, 64
1235-
auto q_nope = query_states.narrow(-1, 0, 128).contiguous(); // B, S, H, D
1236-
q_nope = at::einsum("bshd,hdc->bshc", {q_nope, _0}).contiguous(); // B, S, H, C(512)
1237-
q_nope = torch::cat({q_nope, q_pe_out}, -1);
1238-
1239-
auto R = key_cache.narrow(0, 0, pos + seqlen); // S2, B, (512 + 64)
1240-
auto scores_ = at::einsum("bshc,tbc->bsht", {q_nope, R}) * 0.1352337788608801f;
1241-
_0 = at::einsum("bsht,tbc->bshc", {at::softmax(scores_, -1), R});
1242-
_0 = _0.narrow(-1, 0, 512).contiguous();
1243-
xb = at::einsum("bshc,hdc->bshd", {_0, _1}).contiguous();
1244-
}
1228+
auto lm = torch::empty({2, batch * n_heads, 64}, torch::TensorOptions().dtype(torch::kBFloat16).device(query_states.device()));
1229+
auto attn_output = antares::ops::call("self_attn_infer_bf16", {query_states.squeeze(0).squeeze(0), key_states.squeeze(0), value_states.squeeze(0), lm}, {0.1352337788608801f});
1230+
xb = torch::matmul(antares::ops::call("self_attn_reduce_bf16", {lm}, {}).unsqueeze(1), attn_output).to(query_states.dtype());
1231+
// xb = std::get<0>(at::native::_scaled_dot_product_attention_math(query_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), key_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), value_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), {}, 0, false, {}, 0.1352337788608801)).permute({0, 2, 1, 3}).to(query_states.dtype());
12451232
xb = warp_gemm_nt_bf16xfp8_block_scal(xb.view({batch, seqlen, -1}), o_proj, o_proj_scal);
12461233
return xb;
12471234
}
@@ -1405,7 +1392,7 @@ void warp_deepseek_r1_prepare_weights(
14051392
::cos_sin = cos_sin;
14061393

14071394
int n_layers = o_projs.size();
1408-
bool use_lora = getenv("LORA") ? (std::atoi(getenv("LORA")) == 1) : (batch >= 4);
1395+
bool use_lora = getenv("LORA") ? (std::atoi(getenv("LORA")) == 1) : true;
14091396
if (use_lora) {
14101397
// kv_lora_rank + qk_rope_head_dim
14111398
::key_cache = torch::zeros({n_layers, max_seq_len, batch, 512 + 64}, torch::TensorOptions().dtype(token_emb.dtype()).device(token_emb.device()));

tutel/ops/bmm_bf16.mod

5.89 KB
Binary file not shown.

tutel/ops/bmv_bf16.mod

6.06 KB
Binary file not shown.

tutel/ops/gmv_bf16.mod

6.12 KB
Binary file not shown.

0 commit comments

Comments
 (0)