|
8 | 8 |
|
9 | 9 | _is_hip = is_hip() |
10 | 10 |
|
| 11 | + |
11 | 12 | fused_softcap_autotune = triton.autotune( |
12 | 13 | configs=[ |
13 | 14 | triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4), |
@@ -189,21 +190,16 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal |
189 | 190 | assert x.shape == residual.shape and x.dtype == residual.dtype |
190 | 191 | output, mid = torch.empty_like(x), torch.empty_like(x) |
191 | 192 | bs, hidden_dim = x.shape |
192 | | - |
193 | | - min_num_warps = 16 if _is_hip else 32 |
194 | | - |
195 | 193 | if autotune: |
196 | 194 | fused_dual_residual_rmsnorm_kernel_autotune[(bs,)]( |
197 | 195 | output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim |
198 | 196 | ) |
199 | 197 | else: |
| 198 | + max_warps = 16 if _is_hip else 32 |
200 | 199 | config = { |
201 | 200 | "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), |
202 | 201 | "num_warps": max( |
203 | | - min( |
204 | | - triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps |
205 | | - ), |
206 | | - 4, |
| 202 | + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4 |
207 | 203 | ), |
208 | 204 | } |
209 | 205 |
|
@@ -260,13 +256,11 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False): |
260 | 256 | else: |
261 | 257 | output = torch.empty_like(x) |
262 | 258 | bs, hidden_dim = x.shape |
263 | | - |
264 | | - min_num_warps = 16 if _is_hip else 32 |
265 | | - |
| 259 | + max_warps = 16 if _is_hip else 32 |
266 | 260 | config = { |
267 | 261 | "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), |
268 | 262 | "num_warps": max( |
269 | | - min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4 |
| 263 | + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4 |
270 | 264 | ), |
271 | 265 | } |
272 | 266 |
|
@@ -331,6 +325,75 @@ def forward_native( |
331 | 325 | return self.rmsnorm2.forward_native(residual), residual |
332 | 326 |
|
333 | 327 |
|
| 328 | +@triton.jit |
| 329 | +def experts_combine_kernel( |
| 330 | + out_hidden_states, |
| 331 | + moe_hidden_states, |
| 332 | + mlp_hidden_states, |
| 333 | + combine_k: tl.constexpr, |
| 334 | + hidden_dim: tl.constexpr, |
| 335 | + BLOCK_SIZE: tl.constexpr, |
| 336 | +): |
| 337 | + pid = tl.program_id(0) |
| 338 | + start_index_mlp = pid * hidden_dim |
| 339 | + start_index_rmoe = pid * hidden_dim * combine_k |
| 340 | + offsets = tl.arange(0, BLOCK_SIZE) |
| 341 | + mask = offsets < hidden_dim |
| 342 | + combine_k_offsets = tl.arange(0, combine_k) |
| 343 | + |
| 344 | + moe_x = tl.load( |
| 345 | + moe_hidden_states |
| 346 | + + start_index_rmoe |
| 347 | + + combine_k_offsets[:, None] * hidden_dim |
| 348 | + + offsets[None, :], |
| 349 | + mask=mask[None, :], |
| 350 | + other=0.0, |
| 351 | + ) |
| 352 | + moe_x = tl.sum(moe_x, axis=0) |
| 353 | + mlp_x = tl.load(mlp_hidden_states + start_index_mlp + offsets, mask=mask, other=0.0) |
| 354 | + combined_x = (moe_x + mlp_x) / 1.4142135623730951 |
| 355 | + |
| 356 | + tl.store(out_hidden_states + start_index_mlp + offsets, combined_x, mask=mask) |
| 357 | + |
| 358 | + |
def experts_combine_triton(moe_hidden_states, mlp_hidden_states, output_buffer=None):
    """Combine MoE and MLP branch outputs via ``experts_combine_kernel``.

    Args:
        moe_hidden_states: contiguous tensor, either ``(bs, hidden_dim)``
            (already summed over experts) or ``(bs, combine_k, hidden_dim)``.
        mlp_hidden_states: contiguous ``(bs, hidden_dim)`` tensor.
        output_buffer: optional preallocated storage; when given, it is
            reinterpreted as ``mlp_hidden_states.dtype`` and must hold at
            least ``mlp_hidden_states.numel()`` elements.

    Returns:
        ``(bs, hidden_dim)`` tensor holding ``(sum(moe) + mlp) / sqrt(2)``.
    """
    assert moe_hidden_states.is_contiguous()
    assert mlp_hidden_states.is_contiguous()

    # A 2-D moe tensor was pre-combined upstream; otherwise dim 1 carries
    # the per-expert outputs that the kernel reduces.
    combine_k = 1 if moe_hidden_states.dim() == 2 else moe_hidden_states.shape[1]

    if output_buffer is None:
        out_hidden_states = torch.empty_like(mlp_hidden_states)
    else:
        needed = mlp_hidden_states.numel()
        flat_buffer = output_buffer.view(mlp_hidden_states.dtype).reshape(-1)
        assert flat_buffer.numel() >= needed
        out_hidden_states = flat_buffer[:needed].reshape(mlp_hidden_states.shape)

    bs, hidden_dim = mlp_hidden_states.shape

    # One program per row; block wide enough for the whole hidden dim.
    # Roughly one warp per 1024 elements, clamped to [4, 8] (not tuned).
    block_size = triton.next_power_of_2(hidden_dim)
    warps = max(min(triton.next_power_of_2(triton.cdiv(hidden_dim, 1024)), 8), 4)

    experts_combine_kernel[(bs,)](
        out_hidden_states,
        moe_hidden_states,
        mlp_hidden_states,
        combine_k,
        hidden_dim,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return out_hidden_states
| 395 | + |
| 396 | + |
334 | 397 | # gelu on first half of vector |
335 | 398 | @triton.jit |
336 | 399 | def gelu_and_mul_kernel( |
@@ -400,10 +463,11 @@ def gelu_and_mul_triton( |
400 | 463 | out_scales = scales |
401 | 464 | static_scale = True |
402 | 465 |
|
| 466 | + max_warps = 16 if _is_hip else 32 |
403 | 467 | config = { |
404 | 468 | # 8 ele per thread (not tuned) |
405 | 469 | "num_warps": max( |
406 | | - min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4 |
| 470 | + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4 |
407 | 471 | ), |
408 | 472 | } |
409 | 473 |
|
|
0 commit comments