
silent multi-GPU corruption, with the layer-40 repro and before/after evidence #6831

Description

@xvdev09

[Bug] weight_dequant_block launches Triton kernel without device context → silent nondeterministic corruption/NaN on multi-GPU device_map models

Environment

  • unsloth 2026.6.9, unsloth_zoo 2026.6.7
  • torch 2.10.0+cu128, triton 3.6.0, CUDA 12.8
  • 8x NVIDIA H200 SXM (SM90), model sharded via explicit device_map (model parallelism)
  • transformers 5.13.0.dev0 (git main)
  • Model: GLM-5.2 FP8 block-quantized (GlmMoeDsaForCausalLM, ~695GB, layers spread across 8 GPUs)

Summary

weight_dequant_block in unsloth/kernels/fp8.py launches weight_dequant_kernel without setting the CUDA device context to match the input tensor's device:

weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)

When a model is sharded across GPUs with device_map (the only way to fine-tune ~700GB FP8 MoE checkpoints), expert weights live on cuda:1..cuda:7 while the calling thread's current device is frequently cuda:0. Triton takes the launch device and stream from the current CUDA device, not from the tensor arguments, so the kernel is launched on the wrong GPU.
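One way to make this failure mode loud while debugging is to compare each pointer argument's device with the current device right before the launch. The sketch below is hypothetical (check_launch_device is not an unsloth or Triton API). It takes the current device index as a parameter, which the caller would get from torch.cuda.current_device().

```python
def check_launch_device(current_index, **tensors):
    """Hypothetical debug check to run just before a Triton launch.

    current_index: torch.cuda.current_device() at the call site.
    tensors: pointer arguments keyed by name, e.g. x=x, s=s, y=y.
    Raises instead of letting the kernel run under the wrong device.
    """
    for name, t in tensors.items():
        dev = t.device
        # Only CUDA tensors are checked; CPU inputs are out of scope here.
        if dev.type == "cuda" and dev.index != current_index:
            raise RuntimeError(
                f"{name} is on cuda:{dev.index} but the current device is "
                f"cuda:{current_index}; launch under torch.cuda.device({name}.device)"
            )
```

Called as check_launch_device(torch.cuda.current_device(), x=x, s=s, y=y) just before weight_dequant_kernel[grid](...), it should raise on the cuda:4 case instead of letting the kernel run against the wrong context.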

Severity: silent corruption, not a crash

Depending on context state at call time, the result is:

  • correct output (context happened to match), or
  • silently corrupted dequantized weights — nondeterministic garbage that flows into the MoE forward, or
  • NaN losses.

Observed symptoms before diagnosis (all with layers on GPUs 1-7, single-GPU-hosted layers unaffected):

  • loss=nan from training step 1, intermittently — same input flipped between NaN and finite across processes
  • eval-mode forward on identical input wobbled 0.146 → 0.218 across back-to-back trials (no dropout)
  • grouped_mm vs native_torch MoE backends "disagreed" by 2+ points of loss (5.8 vs 7.8) — both were consuming corrupted dequants

Repro

Load any FP8 block-quantized MoE with a multi-GPU device_map, then call an MoE layer hosted on a non-default GPU directly:

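import torch

# model: GLM-5.2 FP8 loaded with a multi-GPU device_map; layer 40 landed on cuda:4
mlp = model.model.layers[40].mlp
hidden = torch.randn(1, 384, model.config.hidden_size, dtype=torch.bfloat16, device="cuda:4") * 0.02
print(torch.cuda.current_device())      # typically 0 here, not 4
out = mlp(hidden)                       # raises:
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)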

The same call on a layer hosted on cuda:0 works. In full-model forwards the mismatch does not raise — it silently corrupts instead, which is the dangerous part.
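Because the full-model path corrupts silently, a CPU reference is useful for checking dequantized weights directly. This sketch assumes the DeepSeek-V3-style block layout, where each block_size x block_size tile of x shares one scale in s, so s has shape (ceil(M / block_size), ceil(N / block_size)). The helper names are hypothetical, and it uses NumPy rather than torch so it runs without a GPU.

```python
import numpy as np


def block_dequant_reference(x, s, block_size=128):
    """Reference block dequant: y[i, j] = x[i, j] * s[i // B, j // B].

    Assumed layout: x is (M, N); s is (ceil(M / B), ceil(N / B)),
    one scale per B x B tile.
    """
    m, n = x.shape
    # Broadcast each scale over its tile, then crop the ragged edge tiles.
    s_full = np.repeat(np.repeat(s, block_size, axis=0), block_size, axis=1)[:m, :n]
    return x.astype(np.float32) * s_full.astype(np.float32)


def dequant_matches_reference(y, x, s, block_size=128, rtol=1e-2, atol=1e-6):
    """True if kernel output y agrees with the reference; NaNs count as mismatches.

    rtol=1e-2 leaves headroom for bf16 rounding of the kernel output.
    """
    ref = block_dequant_reference(x, s, block_size)
    return bool(np.allclose(np.asarray(y, dtype=np.float32), ref, rtol=rtol, atol=atol))
```

NumPy has no native float8 dtype, so the FP8 weight, its scales and the kernel output would be upcast in torch first, e.g. x.float().cpu().numpy().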

Fix (verified)

Wrap the launch in the tensor's device context:

with torch.cuda.device(x.device):
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)

After the patch:

  • the layer on cuda:4 behaves identically to the layer on cuda:0 (5 repeated forwards, max diff ~3e-05, consistent with bf16 atomics noise; no NaN)
  • full-model eval loss became stable across runs (native_torch backend: bit-identical across trials)
  • both MoE backends now agree (3.163 vs 3.167)
  • loss values dropped from corrupted 5.8/7.8 to a consistent 3.17 on the same input — the corruption was inflating loss the whole time
  • full LoRA training runs with healthy finite losses
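The repeated-forward check can be expressed as a small helper. This is a sketch, not the script behind these numbers: forward is assumed to be a zero-argument callable returning a CPU array, for example lambda: mlp(hidden).float().cpu().numpy() run under torch.no_grad().

```python
import numpy as np


def repeat_forward_stats(forward, n_trials=5):
    """Run forward() n_trials times; return (max_abs_diff_vs_first, any_nan).

    With the device fix in place the diff should sit at numerical-noise
    level and any_nan should be False.
    """
    outs = np.stack([np.asarray(forward(), dtype=np.float32) for _ in range(n_trials)])
    if np.isnan(outs).any():
        return float("nan"), True
    # Compare every trial against the first one.
    return float(np.abs(outs - outs[0]).max()), False
```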

Note: any other Triton launch sites in the FP8/MoE path that can receive non-default-device tensors likely need the same guard.
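A reusable form of the guard could cover every launch site in one place. The decorator below is hypothetical (launches_on_input_device and device_guard are illustrative names, not unsloth APIs). It uses the first positional argument's device, and non-CUDA inputs fall through to a no-op context. PyTorch's built-in torch.cuda.device_of(tensor) context manager behaves similarly and could be substituted for device_guard.

```python
import contextlib
import functools


def device_guard(tensor):
    """Make tensor's CUDA device current for the duration of a launch.

    Non-CUDA inputs get a no-op context, so the guard is safe to apply
    unconditionally.
    """
    device = tensor.device
    if device.type != "cuda":
        return contextlib.nullcontext()
    import torch  # only needed once a CUDA tensor is actually seen
    return torch.cuda.device(device)


def launches_on_input_device(fn):
    """Hypothetical decorator for Triton launch wrappers.

    Runs fn under the CUDA device of its first positional argument,
    mirroring the `with torch.cuda.device(x.device):` fix.
    """
    @functools.wraps(fn)
    def wrapper(x, *args, **kwargs):
        with device_guard(x):
            return fn(x, *args, **kwargs)
    return wrapper
```

Since it only follows the first argument, a launch site whose pointer arguments could sit on different GPUs would still need an explicit check that they share a device.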

Related

Found together with the int32 offset overflow in the same kernel (separate issue). Checkpoint integrity was verified clean (CPU-side scan of all shards: 59,044 fp8 tensors + 59,044 scales, zero anomalies), ruling out data corruption.
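A CPU-side scan along these lines is one way to rule out checkpoint corruption. The anomaly criteria (non-finite values, non-positive scales, scale shapes that do not match the block grid) and the weight / weight_scale_inv naming convention are assumptions, not necessarily what was checked for this issue.

```python
import math

import numpy as np


def scan_fp8_checkpoint(tensors, block_size=128, scale_suffix="_scale_inv"):
    """Hypothetical integrity scan over a dict of name -> float32 array.

    FP8 weights are assumed to be upcast to float32 already.
    Returns (n_weights, n_scales, anomalies) with anomalies as (name, reason).
    """
    anomalies = []
    scales = {k: v for k, v in tensors.items() if k.endswith(scale_suffix)}
    n_weights = 0
    for name, scale in scales.items():
        wname = name[: -len(scale_suffix)]  # "...weight_scale_inv" -> "...weight"
        w = tensors.get(wname)
        if w is None:
            anomalies.append((name, "scale without weight"))
            continue
        n_weights += 1
        m, n = w.shape
        expected = (math.ceil(m / block_size), math.ceil(n / block_size))
        if scale.shape != expected:
            anomalies.append((name, f"shape {scale.shape} != {expected}"))
        if not np.isfinite(scale).all() or (scale <= 0).any():
            anomalies.append((name, "non-finite or non-positive scale"))
        if not np.isfinite(w).all():
            anomalies.append((wname, "non-finite weight"))
    return n_weights, len(scales), anomalies
```

Per shard, the dictionary could be built with safetensors' safe_open(path, framework="pt", device="cpu"), upcasting each tensor with f.get_tensor(name).float().numpy().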
