From 0c5f25970833dfa95d3d5b4213b058f14a9c3f7d Mon Sep 17 00:00:00 2001 From: Xun Wang Date: Thu, 23 Oct 2025 06:13:12 -0500 Subject: [PATCH] fix cross entropy loss issue for small vocab size on amd gpu --- unsloth/kernels/cross_entropy_loss.py | 1 + unsloth/models/mistral.py | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index e0af458a9..d3b618582 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -300,6 +300,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) + if is_cdna(): num_warps = num_warps // 2 logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device) with torch_gpu_device(device): diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index b547739df..faab2d30b 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -298,12 +298,7 @@ def MistralForCausalLM_fast_forward( else: RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! - if DEVICE_TYPE == "hip": - # [TODO] AMD GPUs fail on chunked_cross_entropy loss! - # RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument - RETURN_LOGITS = False - elif bsz*q_len <= 1024: - RETURN_LOGITS = True + if bsz * q_len <= 1024: RETURN_LOGITS = True if not RETURN_LOGITS and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)