Description
Context
I was fine-tuning DeepSeek R1 on medical data by following DataCamp's article, but trainer.train() fails with a matrix-multiplication runtime error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (2158x4096 and 1x8388608)
Here's the link to the DataCamp article -> https://www.datacamp.com/tutorial/fine-tuning-deepseek-r1-reasoning-model
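For context, this is roughly what my notebook does, following the tutorial. The model name, LoRA/trainer hyperparameters, and dataset handling below are copied from the article rather than my exact cells (the real medical dataset is replaced by a toy stand-in here), so treat it as a sketch of the setup:

```python
# Rough sketch of the setup from the DataCamp tutorial (not my exact notebook cells).
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import Dataset

# Load the distilled R1 model in 4-bit, as the article does.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/DeepSeek-R1-Distill-Llama-8B",
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach LoRA adapters to the attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    use_gradient_checkpointing="unsloth",
)

# Toy stand-in for the medical reasoning data the article formats into a "text" column.
dataset = Dataset.from_dict({"text": ["Question: ...\nResponse: ..."] * 8})

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=60,
        learning_rate=2e-4,
        fp16=True,
        output_dir="outputs",
    ),
)

# Start the fine-tuning process -> this is where the error below is raised.
trainer_stats = trainer.train()
```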
Error
RuntimeError Traceback (most recent call last)
in <cell line: 2>()
1 # Start the fine-tuning process
----> 2 trainer_stats = trainer.train()
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2239 hf_hub_utils.enable_progress_bars()
2240 else:
-> 2241 return inner_training_loop(
2242 args=args,
2243 resume_from_checkpoint=resume_from_checkpoint,
/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
/usr/local/lib/python3.10/dist-packages/unsloth/models/_utils.py in _unsloth_training_step(self, model, inputs, num_items_in_batch)
/kaggle/working/unsloth_compiled_cache/UnslothSFTTrainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
738
739 def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
--> 740 outputs = super().compute_loss(
741 model,
742 inputs,
/usr/local/lib/python3.10/dist-packages/unsloth/models/_utils.py in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
1058 )
1059 pass
-> 1060 return self._old_compute_loss(model, inputs, *args, **kwargs)
1061 pass
1062
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3757 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3758 inputs = {**inputs, **loss_kwargs}
-> 3759 outputs = model(**inputs)
3760 # Save past state if it exists
3761 # TODO: this needs to be fixed and made cleaner later.
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
191 return self.module(*inputs[0], **module_kwargs[0])
192 replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
--> 193 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
194 return self.gather(outputs, self.output_device)
195
/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
210 self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
211 ) -> List[Any]:
--> 212 return parallel_apply(
213 replicas, inputs, kwargs, self.device_ids[: len(replicas)]
214 )
/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
124 output = results[i]
125 if isinstance(output, ExceptionWrapper):
--> 126 output.reraise()
127 outputs.append(output)
128 return outputs
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
731 # instantiate since we don't know how to
732 raise RuntimeError(msg) from None
--> 733 raise exception
734
735
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 1226, in PeftModelForCausalLM_fast_forward
return self.base_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 197, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 1062, in _CausalLM_fast_forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 854, in LlamaModel_fast_forward
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 503, in decorate_fwd
return fwd(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth_zoo/gradient_checkpointing.py", line 147, in forward
output = forward_function(hidden_states, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 533, in LlamaDecoderLayer_fast_forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py", line 388, in LlamaAttention_fast_forward
Q, K, V = self.apply_qkv(self, hidden_states)
File "/usr/local/lib/python3.10/dist-packages/unsloth/kernels/fast_lora.py", line 366, in apply_lora_qkv
Q, K, V = LoRA_QKV.apply(X,
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 503, in decorate_fwd
return fwd(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/unsloth/kernels/fast_lora.py", line 259, in forward
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
File "/usr/local/lib/python3.10/dist-packages/unsloth/kernels/utils.py", line 462, in matmul_lora
out = torch_matmul(X, W, out = out)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2158x4096 and 1x8388608)
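For reference, the failing multiplication can be restated in isolation. Everything in the comments below is my own interpretation of the shapes and of the DataParallel frames above, not something I have confirmed:

```python
import torch

# Shapes taken verbatim from the error message.
mat1 = torch.randn(2158, 4096)   # (tokens in the batch, hidden_size)
mat2 = torch.randn(1, 8388608)   # the tensor the Q projection weight shows up as

torch.matmul(mat1, mat2)         # -> RuntimeError: mat1 and mat2 shapes cannot be multiplied

# My (possibly wrong) reading: 8388608 == 4096 * 4096 / 2, i.e. the byte count of a
# 4-bit-quantized 4096x4096 projection weight, so the raw quantized storage appears to
# be used directly instead of a dequantized (4096, 4096) matrix. The outer frames show
# torch.nn.DataParallel replicating the model across the GPUs, which is where I suspect
# the quantization state gets dropped.
```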