1 change: 1 addition & 0 deletions unsloth/kernels/cross_entropy_loss.py
@@ -300,6 +300,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
         if n_chunks == 1:
             # For small vocabs <= 65336 like Llama, Mistral
             BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
+            if is_cdna(): num_warps = num_warps // 2
             logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)

             with torch_gpu_device(device):
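Note on the kernel change: the new `is_cdna()` guard halves Triton's `num_warps` on AMD CDNA GPUs (MI100/MI200/MI300 class), presumably because CDNA wavefronts are 64 lanes wide versus NVIDIA's 32-lane warps, so the same `num_warps` value would otherwise launch twice as many threads per block. The sketch below shows what such a check could look like; `is_cdna_sketch`, the architecture list, and the use of `gcnArchName` are illustrative assumptions, not Unsloth's actual helper.

```python
import torch

# Hypothetical sketch of an is_cdna()-style check: NOT Unsloth's actual helper.
# Assumes a ROCm build of PyTorch, where get_device_properties() exposes
# gcnArchName (e.g. "gfx90a" for MI200, "gfx942" for MI300).
CDNA_ARCHS = ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942")

def is_cdna_sketch() -> bool:
    if not torch.cuda.is_available() or torch.version.hip is None:
        return False  # CUDA or CPU-only build: not an AMD CDNA device
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return arch.split(":")[0] in CDNA_ARCHS

# With the patch above, the kernel launch settings would then become:
# BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
# if is_cdna_sketch(): num_warps //= 2   # CDNA wavefronts are 64 lanes, not 32
```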
7 changes: 1 addition & 6 deletions unsloth/models/mistral.py
@@ -298,12 +298,7 @@ def MistralForCausalLM_fast_forward(
         else:
             RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
         # < 1024 Normal Unsloth uses less VRAM!
-        if DEVICE_TYPE == "hip":
-            # [TODO] AMD GPUs fail on chunked_cross_entropy loss!
-            # RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
-            RETURN_LOGITS = False
-        elif bsz*q_len <= 1024:
-            RETURN_LOGITS = True
+        if bsz * q_len <= 1024: RETURN_LOGITS = True

         if not RETURN_LOGITS and labels is not None:
             n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
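Note on the model change: with the kernel fix above, the HIP-specific workaround that always disabled RETURN_LOGITS on AMD is no longer needed, so AMD now follows the same rule as other devices: logits are returned directly only when `bsz * q_len <= 1024`, otherwise Unsloth's chunked cross-entropy path is used, which avoids materializing the full logits tensor. A rough back-of-envelope calculation (illustrative numbers, not from the diff) shows why that threshold matters:

```python
# Rough VRAM estimate for the full logits tensor, illustrating why
# returning logits is avoided for large bsz * q_len. The sizes below
# are illustrative assumptions, not values taken from the diff.
bsz, q_len, vocab_size = 4, 4096, 128256   # e.g. a Llama-3-sized vocab
bytes_per_elem = 2                         # bf16/fp16 logits

logits_bytes = bsz * q_len * vocab_size * bytes_per_elem
print(f"{logits_bytes / 2**30:.1f} GiB")   # ~3.9 GiB just for the logits

# At bsz * q_len <= 1024 the same tensor is at most
# 1024 * 128256 * 2 bytes (~0.25 GiB), so returning it directly is cheap.
```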