
Matrix mismatch runtime error #1875

@dhruv1710

Description

Context

I was fine-tuning DeepSeek R1 with medical data by following DataCamp's article (my setup is roughly the one sketched below), but when I call trainer.train() I get a matrix multiplication runtime error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (2158x4096 and 1x8388608).

Here's the link to the DataCamp article: https://www.datacamp.com/tutorial/fine-tuning-deepseek-r1-reasoning-model
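
For reference, this is roughly what I'm running, condensed from the tutorial. The model name, dataset, prompt template, and hyperparameters below are my recollection of the tutorial's defaults, so my actual notebook may differ slightly in the details:

```python
# Condensed sketch of my training setup (values assumed from the DataCamp tutorial,
# not copied verbatim from my notebook).
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments

max_seq_length = 2048

# Load the 4-bit base model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/DeepSeek-R1-Distill-Llama-8B",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
)

# Attach LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

# Medical reasoning dataset, flattened into a single "text" field
prompt = """### Question:
{}

### Response:
<think>
{}
</think>
{}"""

def format_examples(examples):
    texts = []
    for q, cot, resp in zip(examples["Question"], examples["Complex_CoT"], examples["Response"]):
        texts.append(prompt.format(q, cot, resp) + tokenizer.eos_token)
    return {"text": texts}

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train[:500]")
dataset = dataset.map(format_examples, batched=True)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=60,
        warmup_steps=5,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        optim="adamw_8bit",
        output_dir="outputs",
    ),
)

trainer_stats = trainer.train()   # <- this is where the RuntimeError below is raised
```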

Error

RuntimeError Traceback (most recent call last)
in <cell line: 2>()
1 # Start the fine-tuning process
----> 2 trainer_stats = trainer.train()

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2239 hf_hub_utils.enable_progress_bars()
2240 else:
-> 2241 return inner_training_loop(
2242 args=args,
2243 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

/usr/local/lib/python3.10/dist-packages/unsloth/models/_utils.py in _unsloth_training_step(self, model, inputs, num_items_in_batch)

/kaggle/working/unsloth_compiled_cache/UnslothSFTTrainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
738
739 def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
--> 740 outputs = super().compute_loss(
741 model,
742 inputs,

/usr/local/lib/python3.10/dist-packages/unsloth/models/_utils.py in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
1058 )
1059 pass
-> 1060 return self._old_compute_loss(model, inputs, *args, **kwargs)
1061 pass
1062

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3757 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3758 inputs = {**inputs, **loss_kwargs}
-> 3759 outputs = model(**inputs)
3760 # Save past state if it exists
3761 # TODO: this needs to be fixed and made cleaner later.

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
191 return self.module(*inputs[0], **module_kwargs[0])
192 replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
--> 193 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
194 return self.gather(outputs, self.output_device)
195

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
210 self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
211 ) -> List[Any]:
--> 212 return parallel_apply(
213 replicas, inputs, kwargs, self.device_ids[: len(replicas)]
214 )

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
124 output = results[i]
125 if isinstance(output, ExceptionWrapper):
--> 126 output.reraise()
127 outputs.append(output)
128 return outputs

/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
731 # instantiate since we don't know how to
732 raise RuntimeError(msg) from None
--> 733 raise exception
734
735

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 1226, in PeftModelForCausalLM_fast_forward
return self.base_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 197, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 1062, in _CausalLM_fast_forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 854, in LlamaModel_fast_forward
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 503, in decorate_fwd
return fwd(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth_zoo/gradient_checkpointing.py", line 147, in forward
output = forward_function(hidden_states, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 533, in LlamaDecoderLayer_fast_forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 388, in LlamaAttention_fast_forward
Q, K, V = self.apply_qkv(self, hidden_states)
File "/usr/local/lib/python3.10/dist-packages/unsloth/kernels/fast_lora.py", line 366, in apply_lora_qkv
Q, K, V = LoRA_QKV.apply(X,
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 503, in decorate_fwd
return fwd(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/kernels/fast_lora.py", line 259, in forward
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
File "/usr/local/lib/python3.10/dist-packages/unsloth/kernels/utils.py", line 462, in matmul_lora
out = torch_matmul(X, W, out = out)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2158x4096 and 1x8388608)
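
For what it's worth, two things I noticed in the traceback: the call goes through torch/nn/parallel/data_parallel.py, so it looks like more than one GPU is being used on this Kaggle notebook, and mat2's shape 1x8388608 matches the size of a packed 4-bit buffer for a 4096x4096 projection rather than a dequantized weight matrix. I don't know whether that's the actual cause, but here is the arithmetic (my own check, not from the tutorial):

```python
# bitsandbytes 4-bit packs two values per byte, so a 4096 x 4096 linear layer's
# packed buffer holds 4096 * 4096 // 2 elements.
hidden = 4096
print(hidden * hidden // 2)   # 8388608 -- the same number as mat2's second dimension
```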
