diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 16a23f885..6274f2e5d 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -300,7 +300,7 @@ def MistralForCausalLM_fast_forward( # < 1024 Normal Unsloth uses less VRAM! if bsz * q_len <= 1024: RETURN_LOGITS = True - if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: + if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and os.environ.get("UNSLOTH_ENABLE_CCE", "1") != "0" and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) loss = fused_linear_cross_entropy(