Skip to content

Commit 2decab1

Browse files
jlamypoirier and claude committed
Fix kernel_1 backward on the CPU interpreter path (#499)
The single-pass launch queried the CUDA SM count unconditionally to bound the program count, which broke `triton_normalization_backward` under `TRITON_INTERPRET=1` on CPU (no CUDA). The bound is a GPU-occupancy heuristic, so skip it off-GPU and use one program per tile.

Also cover the two-pass path in the kernel test (it was only exercising single-pass) and drop the now-constant num_stages from the launch config.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 0c5d45c commit 2decab1

2 files changed

Lines changed: 16 additions & 9 deletions

File tree

fast_llm/functional/triton/normalization.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -376,7 +376,7 @@ def _kernel_1_wide_single_pass(block_size: int, n_rows: int, has_bias: bool, max
376376

377377
def _triton_normalization_backward_kernel_1_config(
378378
n_cols: int, n_rows: int, has_bias: bool
379-
) -> tuple[int, int, bool, int, int]:
379+
) -> tuple[int, int, bool, int]:
380380
block_size = triton.next_power_of_2(n_cols)
381381
max_block_size = TritonConfig.MAX_BLOCK_SIZE_BYTES // 4
382382
if block_size <= _KERNEL_1_NARROW_MAX_COLS:
@@ -398,7 +398,7 @@ def _triton_normalization_backward_kernel_1_config(
398398
block_size_row = _KERNEL_1_WIDE_ROWS
399399
two_pass = True
400400
num_warps = 16
401-
return block_size_col, block_size_row, two_pass, num_warps, 1
401+
return block_size_col, block_size_row, two_pass, num_warps
402402

403403

404404
def triton_normalization_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> torch.Tensor:
@@ -424,14 +424,19 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
424424
block_size_m = 64
425425
block_size_n = 8
426426

427-
block_size_col, block_size_row, two_pass, num_warps, num_stages = _triton_normalization_backward_kernel_1_config(
427+
block_size_col, block_size_row, two_pass, num_warps = _triton_normalization_backward_kernel_1_config(
428428
n_cols, n_rows, has_bias
429429
)
430430

431431
num_tiles = triton.cdiv(n_rows, block_size_row)
432432
# Single pass grid-strides the rows, so the program count (the partial-buffer height) is bounded
433-
# below the tile count; two pass keeps one program per tile.
434-
num_blocks_row = num_tiles if two_pass else min(num_tiles, _kernel_1_target_row_blocks(grad_output.device))
433+
# below the tile count; two pass keeps one program per tile. The bound is a GPU-occupancy
434+
# heuristic, so it is skipped off-GPU (e.g. the Triton interpreter on CPU), where querying the SM
435+
# count would fail.
436+
if two_pass or grad_output.device.type != "cuda":
437+
num_blocks_row = num_tiles
438+
else:
439+
num_blocks_row = min(num_tiles, _kernel_1_target_row_blocks(grad_output.device))
435440

436441
grad_input = torch.empty_like(grad_output)
437442

@@ -472,7 +477,7 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
472477
block_size_row,
473478
two_pass,
474479
num_warps=num_warps,
475-
num_stages=num_stages,
480+
num_stages=1,
476481
)
477482
if parameter_grad:
478483
triton_normalization_backward_kernel_2[(triton.cdiv(n_cols, block_size_n),)](

tests/functional/test_triton_kernels.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -106,13 +106,15 @@ def test_triton_rotary(num_tokens, num_heads, head_size, testing_device):
106106

107107

108108
@requires_triton
109+
# (32, 128) exercises the single-pass backward; (32, 8192) crosses into the two-pass path.
110+
@pytest.mark.parametrize(("num_rows", "num_cols"), [(32, 128), (32, 8192)])
109111
@pytest.mark.parametrize("has_bias", [True, False])
110112
@pytest.mark.parametrize("zero_centered", [True, False])
111-
def test_triton_normalization(has_bias, zero_centered, testing_device):
112-
input_ = torch.randn(32, 128, device=testing_device, requires_grad=True)
113+
def test_triton_normalization(num_rows, num_cols, has_bias, zero_centered, testing_device):
114+
input_ = torch.randn(num_rows, num_cols, device=testing_device, requires_grad=True)
113115
output_grad = torch.randn_like(input_)
114116

115-
weight = torch.randn(128, device=testing_device, requires_grad=True)
117+
weight = torch.randn(num_cols, device=testing_device, requires_grad=True)
116118
weight.grad_buffer = torch.empty_like(weight)
117119
weight.param_grad_is_zero = True
118120

0 commit comments

Comments
 (0)