Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions trl/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@

from accelerate import logging
from datasets import load_dataset
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

from trl import (
Expand All @@ -93,7 +93,7 @@

def main(script_args, training_args, model_args, dataset_args):
################
-# Model init kwargs & Tokenizer
+# Model init kwargs
################
model_kwargs = dict(
revision=model_args.model_revision,
Expand All @@ -118,11 +118,6 @@ def main(script_args, training_args, model_args, dataset_args):
else:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

-# Create tokenizer
-tokenizer = AutoTokenizer.from_pretrained(
-model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
-)

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
logger.warning(
Expand All @@ -145,7 +140,6 @@ def main(script_args, training_args, model_args, dataset_args):
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)

Expand Down
Loading