Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/detoxifying_a_lm.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ Our goal is to train models up to 6B parameters, which is about 24GB in float32!
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by a factor of 2:

```python
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", dtype=torch.bfloat16)
```

and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If one wants to train a model in mixed precision, they should not load the model with `torch_dtype`; instead, they should specify the mixed-precision argument when calling `accelerate config`.
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If one wants to train a model in mixed precision, they should not load the model with `dtype`; instead, they should specify the mixed-precision argument when calling `accelerate config`.

- Use shared layers: Since the PPO algorithm requires both the active and reference models to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/dpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ model = AutoModelForCausalLM.from_pretrained(
load_in_4bit=True,
quantization_config=bnb_config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
dtype=torch.bfloat16,
device_map="auto",
)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ accelerate launch \
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
--learning_rate 1e-5 \
--gradient_checkpointing \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_vllm \
Expand Down
4 changes: 2 additions & 2 deletions docs/source/iterative_sft_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ from trl import IterativeSFTConfig

config = IterativeSFTConfig(
# Model initialization parameters
model_init_kwargs={"torch_dtype": "bfloat16"},
model_init_kwargs={"dtype": "bfloat16"},

# Data preprocessing parameters
max_length=512,
Expand All @@ -104,7 +104,7 @@ You can control how the model is initialized by passing keyword arguments to `mo
```python
config = IterativeSFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
"dtype": "bfloat16",
"device_map": "auto",
"trust_remote_code": True,
}
Expand Down
6 changes: 3 additions & 3 deletions docs/source/sft_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,16 @@ While training and evaluating we record the following reward metrics:
You can directly pass the kwargs of the [`~transformers.AutoModelForCausalLM.from_pretrained()`] method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to

```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
```

you can do so by passing the `model_init_kwargs={"torch_dtype": torch.bfloat16}` argument to the [`SFTConfig`].
you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`SFTConfig`].

```python
from trl import SFTConfig

training_args = SFTConfig(
model_init_kwargs={"torch_dtype": torch.bfloat16},
model_init_kwargs={"dtype": torch.bfloat16},
)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
if __name__ == "__main__":
ckpt = config.hub_model_id

model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def formatting_prompts_func(example):

# load the model and tokenizer
print("[INFO] loading the model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)

# adding pad and eos tokens if not provided in the tokenizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,10 @@ class ScriptArguments:
if peft_config.task_type == "SEQ_CLS":
# The sequence classification task is used for the reward model in PPO
model = AutoModelForSequenceClassification.from_pretrained(
script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
script_args.base_model_name, num_labels=1, dtype=torch.bfloat16
)
else:
model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(script_args.base_model_name, return_dict=True, dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ class ScriptArguments:
lora_dropout=0.1,
)

model = AutoModelForSequenceClassification.from_pretrained(
script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16
)
model = AutoModelForSequenceClassification.from_pretrained(script_args.model_name, num_labels=1, dtype=torch.bfloat16)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Expand Down
2 changes: 1 addition & 1 deletion examples/research_projects/stack_llama_2/scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(
"dpo/final_checkpoint",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
dtype=torch.float16,
load_in_4bit=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,16 @@ def return_prompt_and_responses(samples) -> dict[str, str]:
set_seed(script_args.seed)

# 1. load a pretrained model
torch_dtype = torch.float
dtype = torch.float
if script_args.model_dtype == "float16":
torch_dtype = torch.float16
dtype = torch.float16
elif script_args.model_dtype == "bfloat16":
torch_dtype = torch.bfloat16
dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch_dtype,
dtype=dtype,
load_in_4bit=script_args.load_in_4bit,
device_map={"": Accelerator().local_process_index},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def create_datasets(tokenizer, args, seed=None):
else:
torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", dtype=torch.bfloat16)
model = model.merge_and_unload()

output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@


for model_id in tqdm(MODELS_TO_TEST):
model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def collator(data):

# Now let's build the model, the reference model, and the tokenizer. We first load the model
# in bfloat16 to save memory using `transformers`.
model = AutoModelForCausalLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(config.model_name, dtype=torch.bfloat16)
# And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`.
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

Expand Down Expand Up @@ -186,7 +186,7 @@ def collator(data):
toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id)
# We load the toxicity model in fp16 to save memory.
toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, torch_dtype=torch.float16).to(
toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, dtype=torch.float16).to(
ppo_trainer.accelerator.device
)

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/bco.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def mean_pooling(model_output, attention_mask):
"nomic-ai/nomic-embed-text-v1.5",
trust_remote_code=model_args.trust_remote_code,
safe_serialization=True,
torch_dtype=torch.bfloat16,
dtype=torch.bfloat16,
device_map="auto",
)
embedding_model = accelerator.prepare_model(embedding_model)
Expand Down
6 changes: 2 additions & 4 deletions examples/scripts/dpo_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,12 @@
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
Expand Down
10 changes: 4 additions & 6 deletions examples/scripts/dpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
--gradient_accumulation_steps 32 \
--dataset_num_proc 32 \
--output_dir dpo_idefics_rlaif-v \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--lora_target_modules all-linear
Expand All @@ -51,7 +51,7 @@
--gradient_accumulation_steps 32 \
--dataset_num_proc 32 \
--output_dir dpo_idefics_rlaif-v \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--lora_target_modules all-linear
Expand Down Expand Up @@ -86,15 +86,13 @@
################
# Model & Tokenizer
################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
dtype=model_args.dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
Expand All @@ -96,7 +96,7 @@
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
dtype=model_args.dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
Expand Down
10 changes: 4 additions & 6 deletions examples/scripts/grpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
--learning_rate 1e-5 \
--gradient_checkpointing \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_vllm \
Expand All @@ -53,7 +53,7 @@
--model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \
--output_dir grpo-SmolVLM2-2.2B-Instruct \
--learning_rate 1e-5 \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_peft \
Expand Down Expand Up @@ -95,14 +95,12 @@
################
# Model & Processor
################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
Expand Down
8 changes: 3 additions & 5 deletions examples/scripts/gspo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
--model_name_or_path Qwen/Qwen3-0.6B \
--output_dir gspo-Qwen3-0.6B \
--learning_rate 1e-5 \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_peft \
Expand Down Expand Up @@ -81,14 +81,12 @@
################
# Model & Processor
################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
Expand Down
8 changes: 3 additions & 5 deletions examples/scripts/gspo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--output_dir gspo-Qwen2.5-VL-3B-Instruct \
--learning_rate 1e-5 \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_peft \
Expand Down Expand Up @@ -82,14 +82,12 @@
################
# Model & Processor
################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
Expand Down
8 changes: 3 additions & 5 deletions examples/scripts/mpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
--gradient_accumulation_steps 8 \
--dataset_num_proc 1 \
--output_dir dpo_idefics_rlaif-v \
--torch_dtype bfloat16 \
--dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj \
Expand Down Expand Up @@ -70,16 +70,14 @@
################
# Model & Processor
################
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
trust_remote_code=model_args.trust_remote_code,
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
Expand Down
6 changes: 2 additions & 4 deletions examples/scripts/nash_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
dtype=dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
Expand Down
Loading
Loading