@@ -38,7 +38,7 @@ def run(args):
3838 from unsloth import FastLanguageModel
3939 from datasets import load_dataset
4040 from transformers .utils import strtobool
41- from trl import SFTTrainer
41+ from trl import SFTTrainer , SFTConfig
4242 from transformers import TrainingArguments
4343 from unsloth import is_bfloat16_supported
4444 import logging
@@ -100,7 +100,7 @@ def formatting_prompts_func(examples):
100100 print ("Data is formatted and ready!" )
101101
102102 # Configure training arguments
103- training_args = TrainingArguments (
103+ training_args = SFTConfig (
104104 per_device_train_batch_size = args .per_device_train_batch_size ,
105105 gradient_accumulation_steps = args .gradient_accumulation_steps ,
106106 warmup_steps = args .warmup_steps ,
@@ -115,17 +115,16 @@ def formatting_prompts_func(examples):
115115 seed = args .seed ,
116116 output_dir = args .output_dir ,
117117 report_to = args .report_to ,
118+ max_length = args .max_seq_length ,
119+ dataset_num_proc = 2 ,
120+ packing = False ,
118121 )
119122
120123 # Initialize trainer
121124 trainer = SFTTrainer (
122125 model = model ,
123- tokenizer = tokenizer ,
126+ processing_class = tokenizer ,
124127 train_dataset = dataset ,
125- dataset_text_field = "text" ,
126- max_seq_length = args .max_seq_length ,
127- dataset_num_proc = 2 ,
128- packing = False ,
129128 args = training_args ,
130129 )
131130
0 commit comments