diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6d215e1fe8e..483461990d1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -235,7 +235,7 @@ jobs: uv pip install ".[dev]" uv pip install accelerate==1.4.0 uv pip install datasets==3.0.0 - uv pip install transformers==4.55.0 + uv pip install transformers==4.56.0 - name: Test with pytest run: | diff --git a/docs/source/detoxifying_a_lm.md b/docs/source/detoxifying_a_lm.md index 9e3530d7c2d..442a8230436 100644 --- a/docs/source/detoxifying_a_lm.md +++ b/docs/source/detoxifying_a_lm.md @@ -93,10 +93,10 @@ Our goal is to train models up to 6B parameters, which is about 24GB in float32! - Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2: ```python -model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16) +model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", dtype=torch.bfloat16) ``` -and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`. +and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `dtype` and specify the mixed precision argument when calling `accelerate config`. - Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this: diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md index 35b87da0401..2c17154bf36 100644 --- a/docs/source/dpo_trainer.md +++ b/docs/source/dpo_trainer.md @@ -255,7 +255,7 @@ model = AutoModelForCausalLM.from_pretrained( load_in_4bit=True, quantization_config=bnb_config, attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, device_map="auto", ) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index ace444e07c3..a88682f4e7f 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -573,7 +573,7 @@ accelerate launch \ --output_dir grpo-Qwen2.5-VL-3B-Instruct \ --learning_rate 1e-5 \ --gradient_checkpointing \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ diff --git a/docs/source/iterative_sft_trainer.md b/docs/source/iterative_sft_trainer.md index 54968ead78b..4092f0c0655 100644 --- a/docs/source/iterative_sft_trainer.md +++ b/docs/source/iterative_sft_trainer.md @@ -79,7 +79,7 @@ from trl import IterativeSFTConfig config = IterativeSFTConfig( # Model initialization parameters - model_init_kwargs={"torch_dtype": "bfloat16"}, + model_init_kwargs={"dtype": "bfloat16"}, # Data preprocessing parameters max_length=512, @@ -104,7 +104,7 @@ You can control how the model is initialized by passing keyword arguments to `mo ```python config = IterativeSFTConfig( model_init_kwargs={ - "torch_dtype": "bfloat16", + "dtype": "bfloat16", "device_map": "auto", "trust_remote_code": True, } diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index a68278f9092..1a13b078e17 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -130,16 +130,16 @@ While training and evaluating we record the following reward metrics: You can directly pass the kwargs of the [`~transformers.AutoModelForCausalLM.from_pretrained()`] method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to ```python -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16) +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16) ``` -you can do so by passing the `model_init_kwargs={"torch_dtype": torch.bfloat16}` argument to the [`SFTConfig`]. +you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`SFTConfig`]. ```python from trl import SFTConfig training_args = SFTConfig( - model_init_kwargs={"torch_dtype": torch.bfloat16}, + model_init_kwargs={"dtype": torch.bfloat16}, ) ``` diff --git a/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py b/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py index 9359dda33eb..3225582cbaf 100644 --- a/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py +++ b/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py @@ -40,7 +40,7 @@ def generate_tokens_with_assistance(model, inputs, assistant_early_exit): if __name__ == "__main__": ckpt = config.hub_model_id - model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(ckpt) prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: " diff --git a/examples/research_projects/layer_skip/scripts/layer_skip_sft.py b/examples/research_projects/layer_skip/scripts/layer_skip_sft.py index 71d2f6868a6..bc75dc1b0dd 100644 --- a/examples/research_projects/layer_skip/scripts/layer_skip_sft.py +++ b/examples/research_projects/layer_skip/scripts/layer_skip_sft.py @@ -43,7 +43,7 @@ def formatting_prompts_func(example): # load the model and tokenizer print("[INFO] loading the model and tokenizer...") - model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True) # adding pad and eos tokens if not provided in the tokenizer diff --git a/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py index 9c7974ae0fa..0f07ec3a580 100644 --- a/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py +++ b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py @@ -42,12 +42,10 @@ class ScriptArguments: if peft_config.task_type == "SEQ_CLS": # The sequence classification task is used for the reward model in PPO model = AutoModelForSequenceClassification.from_pretrained( - script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + script_args.base_model_name, num_labels=1, dtype=torch.bfloat16 ) else: - model = AutoModelForCausalLM.from_pretrained( - script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 - ) + model = AutoModelForCausalLM.from_pretrained(script_args.base_model_name, return_dict=True, dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) diff --git a/examples/research_projects/stack_llama/scripts/reward_modeling.py b/examples/research_projects/stack_llama/scripts/reward_modeling.py index c25b1fd61ef..780b192e07c 100644 --- a/examples/research_projects/stack_llama/scripts/reward_modeling.py +++ b/examples/research_projects/stack_llama/scripts/reward_modeling.py @@ -168,9 +168,7 @@ class ScriptArguments: lora_dropout=0.1, ) -model = AutoModelForSequenceClassification.from_pretrained( - script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16 -) +model = AutoModelForSequenceClassification.from_pretrained(script_args.model_name, num_labels=1, dtype=torch.bfloat16) model = get_peft_model(model, peft_config) model.print_trainable_parameters() diff --git a/examples/research_projects/stack_llama_2/scripts/README.md b/examples/research_projects/stack_llama_2/scripts/README.md index 46a9fe6f2cc..5b226c5d2f6 100644 --- a/examples/research_projects/stack_llama_2/scripts/README.md +++ b/examples/research_projects/stack_llama_2/scripts/README.md @@ -67,7 +67,7 @@ from peft import AutoPeftModelForCausalLM model = AutoPeftModelForCausalLM.from_pretrained( "dpo/final_checkpoint", low_cpu_mem_usage=True, - torch_dtype=torch.float16, + dtype=torch.float16, load_in_4bit=True, ) diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py index 43f9d35b3e4..deffc7b61b7 100644 --- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -148,16 +148,16 @@ def return_prompt_and_responses(samples) -> dict[str, str]: set_seed(script_args.seed) # 1. load a pretrained model - torch_dtype = torch.float + dtype = torch.float if script_args.model_dtype == "float16": - torch_dtype = torch.float16 + dtype = torch.float16 elif script_args.model_dtype == "bfloat16": - torch_dtype = torch.bfloat16 + dtype = torch.bfloat16 model = AutoModelForCausalLM.from_pretrained( script_args.model_name_or_path, low_cpu_mem_usage=True, - torch_dtype=torch_dtype, + dtype=dtype, load_in_4bit=script_args.load_in_4bit, device_map={"": Accelerator().local_process_index}, ) diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py index dff5b169e84..b50e0fab813 100644 --- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -205,7 +205,7 @@ def create_datasets(tokenizer, args, seed=None): else: torch.cuda.empty_cache() -model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16) +model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", dtype=torch.bfloat16) model = model.merge_and_unload() output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") diff --git a/examples/research_projects/toxicity/scripts/evaluate-toxicity.py b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py index 6a1913f3d61..1b85ceeeae0 100644 --- a/examples/research_projects/toxicity/scripts/evaluate-toxicity.py +++ b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py @@ -84,7 +84,7 @@ for model_id in tqdm(MODELS_TO_TEST): - model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" diff --git a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py index edab2a669b5..310fbbd1b77 100644 --- a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py +++ b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py @@ -155,7 +155,7 @@ def collator(data): # Now let's build the model, the reference model, and the tokenizer. We first load the model # in bfloat16 to save memory using `transformers`. -model = AutoModelForCausalLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16) +model = AutoModelForCausalLM.from_pretrained(config.model_name, dtype=torch.bfloat16) # And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`. model = AutoModelForCausalLMWithValueHead.from_pretrained(model) @@ -186,7 +186,7 @@ def collator(data): toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target" toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id) # We load the toxicity model in fp16 to save memory. -toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, torch_dtype=torch.float16).to( +toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, dtype=torch.float16).to( ppo_trainer.accelerator.device ) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index 958859ba013..03b304e65a3 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -141,7 +141,7 @@ def mean_pooling(model_output, attention_mask): "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=model_args.trust_remote_code, safe_serialization=True, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, device_map="auto", ) embedding_model = accelerator.prepare_model(embedding_model) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index a9f652e5efc..8cf9d0f3e6d 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -82,14 +82,12 @@ script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index b98eaeaf908..e73f94ab1d9 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -33,7 +33,7 @@ --gradient_accumulation_steps 32 \ --dataset_num_proc 32 \ --output_dir dpo_idefics_rlaif-v \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --gradient_checkpointing \ --use_peft \ --lora_target_modules all-linear @@ -51,7 +51,7 @@ --gradient_accumulation_steps 32 \ --dataset_num_proc 32 \ --output_dir dpo_idefics_rlaif-v \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --gradient_checkpointing \ --use_peft \ --lora_target_modules all-linear @@ -86,15 +86,13 @@ ################ # Model & Tokenizer ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 9be6fbdd350..2f060bb3cd5 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -85,7 +85,7 @@ revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, - torch_dtype=model_args.torch_dtype, + dtype=model_args.dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, @@ -96,7 +96,7 @@ revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, - torch_dtype=model_args.torch_dtype, + dtype=model_args.dtype, use_cache=True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, diff --git a/examples/scripts/grpo_vlm.py b/examples/scripts/grpo_vlm.py index 2925e77be6a..b663f91534e 100644 --- a/examples/scripts/grpo_vlm.py +++ b/examples/scripts/grpo_vlm.py @@ -35,7 +35,7 @@ --output_dir grpo-Qwen2.5-VL-3B-Instruct \ --learning_rate 1e-5 \ --gradient_checkpointing \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ @@ -53,7 +53,7 @@ --model_name_or_path HuggingFaceTB/SmolVLM2-2.2B-Instruct \ --output_dir grpo-SmolVLM2-2.2B-Instruct \ --learning_rate 1e-5 \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ @@ -95,14 +95,12 @@ ################ # Model & Processor ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) training_args.model_init_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/gspo.py b/examples/scripts/gspo.py index 59702bce91d..3a8aa6f625d 100644 --- a/examples/scripts/gspo.py +++ b/examples/scripts/gspo.py @@ -34,7 +34,7 @@ --model_name_or_path Qwen/Qwen3-0.6B \ --output_dir gspo-Qwen3-0.6B \ --learning_rate 1e-5 \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ @@ -81,14 +81,12 @@ ################ # Model & Processor ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) training_args.model_init_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/gspo_vlm.py b/examples/scripts/gspo_vlm.py index f099c1c5866..10e1831251d 100644 --- a/examples/scripts/gspo_vlm.py +++ b/examples/scripts/gspo_vlm.py @@ -34,7 +34,7 @@ --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \ --output_dir gspo-Qwen2.5-VL-3B-Instruct \ --learning_rate 1e-5 \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ @@ -82,14 +82,12 @@ ################ # Model & Processor ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) training_args.model_init_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/mpo_vlm.py b/examples/scripts/mpo_vlm.py index 510b8ae13c9..23e3a0192cf 100644 --- a/examples/scripts/mpo_vlm.py +++ b/examples/scripts/mpo_vlm.py @@ -32,7 +32,7 @@ --gradient_accumulation_steps 8 \ --dataset_num_proc 1 \ --output_dir dpo_idefics_rlaif-v \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --gradient_checkpointing \ --use_peft \ --lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj \ @@ -70,16 +70,14 @@ ################ # Model & Processor ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( trust_remote_code=model_args.trust_remote_code, revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 28b44089d58..f02a0de34c1 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -86,14 +86,12 @@ script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 75f52b20d1f..49c850e5fd6 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -89,14 +89,12 @@ ################ # Model & Tokenizer ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 633e26bb01c..bd9714f7b6a 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -96,14 +96,12 @@ ################ # Model & Tokenizer ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py index 7e5ecc4b2d6..72a6abd7c23 100644 --- a/examples/scripts/prm.py +++ b/examples/scripts/prm.py @@ -81,11 +81,7 @@ ################ # Model & Tokenizer ################ - torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) - ) + dtype = model_config.dtype if model_config.dtype in ["auto", None] else getattr(torch, model_config.dtype) quantization_config = get_quantization_config(model_config) model_kwargs = dict( revision=model_config.model_revision, diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 309a86c3a62..78abce6081c 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -84,16 +84,14 @@ ################ # Model & Tokenizer ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, use_cache=False if training_args.gradient_checkpointing else True, - torch_dtype=torch_dtype, + dtype=dtype, ) tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True diff --git a/examples/scripts/sft_gpt_oss.py b/examples/scripts/sft_gpt_oss.py index 509329dd5fd..55759692337 100644 --- a/examples/scripts/sft_gpt_oss.py +++ b/examples/scripts/sft_gpt_oss.py @@ -28,7 +28,7 @@ accelerate launch \ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/sft_gpt_oss.py \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --model_name_or_path openai/gpt-oss-20b \ --packing true packing_strategy wrapped \ --run_name 20b-full-eager \ @@ -67,7 +67,7 @@ def main(script_args, training_args, model_args): revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, - torch_dtype=model_args.torch_dtype, + dtype=model_args.dtype, use_cache=False if training_args.gradient_checkpointing else True, quantization_config=quantization_config, ) diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index d80c66c02a8..06364a5022c 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -46,7 +46,7 @@ --warmup_ratio 0.1 \ --lr_scheduler_type cosine \ --push_to_hub False \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --gradient_checkpointing True """ @@ -187,9 +187,7 @@ class CustomScriptArguments(ScriptArguments): dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") # Setup model - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) # Quantization configuration for 4-bit training bnb_config = BitsAndBytesConfig( @@ -203,7 +201,7 @@ class CustomScriptArguments(ScriptArguments): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map(), quantization_config=bnb_config, ) diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index 3bdb0b68ed1..af9c49be3f9 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -32,7 +32,7 @@ --model_name_or_path llava-hf/llava-1.5-7b-hf \ --gradient_accumulation_steps 8 \ --output_dir LLaVA-1.5-7B-SFT \ - --torch_dtype bfloat16 + --dtype bfloat16 For LLaVA-NeXT, use: (requires transformers>=4.45) --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf @@ -48,7 +48,7 @@ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 1 \ --output_dir SmolVLM-SFT \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --use_peft \ --lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj """ @@ -83,14 +83,12 @@ ################ # Model, Tokenizer & Processor ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/sft_vlm_gemma3.py b/examples/scripts/sft_vlm_gemma3.py index 4fa194b4597..82dfe59cd6f 100644 --- a/examples/scripts/sft_vlm_gemma3.py +++ b/examples/scripts/sft_vlm_gemma3.py @@ -31,7 +31,7 @@ --model_name_or_path google/gemma-3-4b-it \ --per_device_train_batch_size 1 \ --output_dir Gemma-3-4B-SFT-MMIU \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --use_peft \ --lora_target_modules all-linear \ --attn_implementation eager @@ -46,7 +46,7 @@ --model_name_or_path google/gemma-3-4b-it \ --per_device_train_batch_size 1 \ --output_dir Gemma-3-4B-SFT-MMIU \ - --torch_dtype bfloat16 \ + --dtype bfloat16 \ --use_peft \ --lora_target_modules all-linear \ --attn_implementation eager @@ -148,14 +148,12 @@ def main(): ################ # Model, Tokenizer & Processor ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 1486f00baad..ef62b0a52cb 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -71,14 +71,12 @@ script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, diff --git a/requirements.txt b/requirements.txt index 9e99bb9806a..7016360177d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ accelerate>=1.4.0 datasets>=3.0.0 -transformers>=4.55.0 \ No newline at end of file +transformers>=4.56.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 01554407f3c..21bc48d43dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,7 @@ include_package_data = True install_requires = accelerate>=1.4.0 datasets>=3.0.0 - transformers>=4.55.0 + transformers>=4.56.0 [options.packages.find] exclude = diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 57b68eebb1f..f81c6823716 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -263,7 +263,7 @@ def data_gen(num_samples): model = AutoModelForImageTextToText.from_pretrained( model_name, attn_implementation="flash_attention_2", - torch_dtype="bfloat16", + dtype="bfloat16", device_map=get_kbit_device_map(), quantization_config=quantization_config, ) @@ -421,7 +421,7 @@ def dummy_reward_func(completions, **kwargs): model = AutoModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", quantization_config=quantization_config, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, ) trainer = GRPOTrainer( diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index dd34a68e03b..49282e0bac2 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -970,15 +970,15 @@ def test_dpo_lora_force_use_ref(self): # train the model trainer.train() - def test_dpo_trainer_torch_dtype(self): + def test_dpo_trainer_dtype(self): # See https://github.com/huggingface/trl/issues/1751 dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") training_args = DPOConfig( output_dir=self.tmp_dir, per_device_train_batch_size=2, max_steps=1, - model_init_kwargs={"torch_dtype": "float16"}, - ref_model_init_kwargs={"torch_dtype": "float16"}, + model_init_kwargs={"dtype": "float16"}, + ref_model_init_kwargs={"dtype": "float16"}, report_to="none", ) @@ -989,15 +989,15 @@ def test_dpo_trainer_torch_dtype(self): args=training_args, train_dataset=dummy_dataset["train"], ) - self.assertEqual(trainer.model.config.torch_dtype, torch.float16) - self.assertEqual(trainer.ref_model.config.torch_dtype, torch.float16) + self.assertEqual(trainer.model.config.dtype, torch.float16) + self.assertEqual(trainer.ref_model.config.dtype, torch.float16) - # Now test when `torch_dtype` is provided but is wrong to either the model or the ref_model + # Now test when `dtype` is provided but is wrong to either the model or the ref_model training_args = DPOConfig( output_dir=self.tmp_dir, per_device_train_batch_size=2, max_steps=1, - model_init_kwargs={"torch_dtype": -1}, + model_init_kwargs={"dtype": -1}, report_to="none", ) @@ -1010,7 +1010,7 @@ def test_dpo_trainer_torch_dtype(self): ) self.assertIn( - "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing a `torch.dtype` (e.g., 'float32'), but got -1.", + "Invalid `dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing a `torch.dtype` (e.g., 'float32'), but got -1.", str(context.exception), ) @@ -1018,7 +1018,7 @@ def test_dpo_trainer_torch_dtype(self): output_dir=self.tmp_dir, per_device_train_batch_size=2, max_steps=1, - ref_model_init_kwargs={"torch_dtype": -1}, + ref_model_init_kwargs={"dtype": -1}, report_to="none", ) @@ -1032,7 +1032,7 @@ def test_dpo_trainer_torch_dtype(self): ) self.assertIn( - "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing a `torch.dtype` (e.g., 'float32'), but got -1.", + "Invalid `dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing a `torch.dtype` (e.g., 'float32'), but got -1.", str(context.exception), ) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index b8a8f7f530e..0c0f267e600 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -275,7 +275,7 @@ def test_training_peft_with_gradient_checkpointing(self): model = AutoModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - torch_dtype=torch.float32, # Use float32 for testing to avoid precision issues + dtype=torch.float32, # Use float32 for testing to avoid precision issues ) lora_config = LoraConfig( diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index a66b5a16e51..b0a75211175 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -267,7 +267,7 @@ def test_transformers_bf16_kwargs(self): run a dummy forward pass without any issue. """ for model_name in self.all_model_names: - trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(self.device) + trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device) lm_head_namings = ["lm_head", "embed_out", "output_layer"] @@ -404,7 +404,7 @@ def test_transformers_bf16_kwargs(self): run a dummy forward pass without any issue. """ for model_name in self.all_model_names: - trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(self.device) + trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device) lm_head_namings = self.trl_model_class.lm_head_namings diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 32dfa083635..4f3a934dbd1 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -167,7 +167,7 @@ def test_training_peft_with_gradient_checkpointing(self): model = AutoModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - torch_dtype=torch.float32, # Use float32 for testing to avoid precision issues + dtype=torch.float32, # Use float32 for testing to avoid precision issues ) lora_config = LoraConfig( diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index b76f8ef5233..18a55f172d2 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -358,14 +358,14 @@ def formatting_prompts_func(example): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") - def test_train_model_torch_dtype(self): + def test_train_model_dtype(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") # Initialize the trainer training_args = SFTConfig( output_dir=self.tmp_dir, - model_init_kwargs={"torch_dtype": torch.float16}, + model_init_kwargs={"dtype": torch.float16}, learning_rate=0.1, report_to="none", ) @@ -1294,7 +1294,7 @@ def test_train_vlm_gemma_3n(self): max_length=None, per_device_train_batch_size=1, gradient_checkpointing=True, - model_init_kwargs={"torch_dtype": "bfloat16"}, + model_init_kwargs={"dtype": "bfloat16"}, report_to="none", ) trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset) diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index c917a19e64c..d572f498c60 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -92,14 +92,12 @@ def main(script_args, training_args, model_args, dataset_args): ################ # Model & Tokenizer ################### - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, + dtype=dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index db7ce0679e9..73e0145a0fe 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -99,7 +99,7 @@ def main(script_args, training_args, model_args, dataset_args): revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, - torch_dtype=model_args.torch_dtype, + dtype=model_args.dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 17ee72ed12a..cf0cfdb5b59 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -369,16 +369,16 @@ def __init__( raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.") else: model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: + dtype = model_init_kwargs.get("dtype") + if dtype is not None: # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." ) - model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["dtype"] = dtype if args.ref_model_init_kwargs is None: ref_model_init_kwargs = {} @@ -388,16 +388,16 @@ def __init__( ) else: ref_model_init_kwargs = args.ref_model_init_kwargs - torch_dtype = ref_model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: + dtype = ref_model_init_kwargs.get("dtype") + if dtype is not None: # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." ) - ref_model_init_kwargs["torch_dtype"] = torch_dtype + ref_model_init_kwargs["dtype"] = dtype if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 0e7878dbb64..2d4272ae0bb 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -136,16 +136,16 @@ def __init__( raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.") else: model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: + dtype = model_init_kwargs.get("dtype") + if dtype is not None: # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." ) - model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["dtype"] = dtype if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 19afc4eed3e..2cacb26e3f6 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -528,16 +528,16 @@ def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool model_init_kwargs = args.ref_model_init_kwargs or {} # Handle torch dtype - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: - pass # torch_dtype is already a torch.dtype or "auto" or None - elif isinstance(torch_dtype, str): # it's a str, but not "auto" - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str): # it's a str, but not "auto" + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype else: raise ValueError( - "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + "Invalid `dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." ) # Create model diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 4e79fba05c4..9d17f32e417 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -127,10 +127,10 @@ def __init__( ) else: teacher_model_init_kwargs = args.teacher_model_init_kwargs - teacher_model_init_kwargs["torch_dtype"] = ( - teacher_model_init_kwargs["torch_dtype"] - if teacher_model_init_kwargs["torch_dtype"] in ["auto", None] - else getattr(torch, teacher_model_init_kwargs["torch_dtype"]) + teacher_model_init_kwargs["dtype"] = ( + teacher_model_init_kwargs["dtype"] + if teacher_model_init_kwargs["dtype"] in ["auto", None] + else getattr(torch, teacher_model_init_kwargs["dtype"]) ) if isinstance(teacher_model, str): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3bb1a0ab4d7..c5db5e2b559 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -229,16 +229,16 @@ def __init__( model_init_kwargs = args.model_init_kwargs or {} if isinstance(model, str): model_id = model - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: - pass # torch_dtype is already a torch.dtype or "auto" or None - elif isinstance(torch_dtype, str): # it's a str, but not "auto" - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str): # it's a str, but not "auto" + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype else: raise ValueError( - "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." ) # Disable caching if gradient checkpointing is enabled (not supported) config = AutoConfig.from_pretrained(model_id) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 040d7fd7f3d..017dcf74cb2 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -358,16 +358,16 @@ def __init__( raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.") else: model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: + dtype = model_init_kwargs.get("dtype") + if dtype is not None: # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." ) - model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["dtype"] = dtype if args.ref_model_init_kwargs is None: ref_model_init_kwargs = {} @@ -377,16 +377,16 @@ def __init__( ) else: ref_model_init_kwargs = args.ref_model_init_kwargs - torch_dtype = ref_model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: + dtype = ref_model_init_kwargs.get("dtype") + if dtype is not None: # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." ) - ref_model_init_kwargs["torch_dtype"] = torch_dtype + ref_model_init_kwargs["dtype"] = dtype if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py index dcaa8d0508f..7b818c6790e 100644 --- a/trl/trainer/model_config.py +++ b/trl/trainer/model_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Optional @@ -30,7 +31,7 @@ class ModelConfig: Model checkpoint for weights initialization. model_revision (`str`, *optional*, defaults to `"main"`): Specific model version to use. It can be a branch name, a tag name, or a commit id. - torch_dtype (`Literal["auto", "bfloat16", "float16", "float32"]` or `None`, *optional*, defaults to `None`): + dtype (`Literal["auto", "bfloat16", "float16", "float32"]` or `None`, *optional*, defaults to `None`): Override the default `torch.dtype` and load the model under this dtype. Possible values are - `"bfloat16"`: `torch.bfloat16` @@ -89,7 +90,7 @@ class ModelConfig: default="main", metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."}, ) - torch_dtype: Optional[str] = field( + dtype: Optional[str] = field( default=None, metadata={ "help": "Override the default `torch.dtype` and load the model under this dtype.", @@ -176,10 +177,25 @@ class ModelConfig: default=False, metadata={"help": "Whether to use nested quantization."}, ) + # Deprecated params + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": "Override the default `torch.dtype` and load the model under this dtype.", + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) def __post_init__(self): if self.load_in_8bit and self.load_in_4bit: raise ValueError("You can't use 8 bit and 4 bit precision at the same time") + if self.torch_dtype and not self.dtype: + warnings.warn( + "`torch_dtype` is deprecated and will be removed in version 0.27.0, please use `dtype` instead.", + DeprecationWarning, + ) + self.dtype = self.torch_dtype + if hasattr(self.lora_target_modules, "__len__") and len(self.lora_target_modules) == 1: self.lora_target_modules = self.lora_target_modules[0] diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index cf657d29651..f70624130c9 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -195,17 +195,17 @@ def __init__( if isinstance(model, str): model_id = model - # Handle torch_dtype in model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + # Handle dtype in model_init_kwargs + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: pass - elif isinstance(torch_dtype, str): - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype + elif isinstance(dtype, str): + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype else: raise ValueError( - "Invalid `torch_dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string " - f"representing a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + "Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string " + f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}." ) model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 92a1f77d5cf..1efbd86d110 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -140,16 +140,16 @@ def __init__( raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") else: model_init_kwargs = args.model_init_kwargs - torch_dtype = model_init_kwargs.get("torch_dtype") - if torch_dtype is not None: + dtype = model_init_kwargs.get("dtype") + if dtype is not None: # Convert to `torch.dtype` if an str is passed - if isinstance(torch_dtype, str) and torch_dtype != "auto": - torch_dtype = getattr(torch, torch_dtype) - if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): raise ValueError( - f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." ) - model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["dtype"] = dtype if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 2e126bff450..4c02b327cd7 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -299,16 +299,16 @@ def decode(example, tokenizer): model_init_kwargs = args.model_init_kwargs or {} if isinstance(model, str): model_id = model - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: - pass # torch_dtype is already a torch.dtype or "auto" or None - elif isinstance(torch_dtype, str): # it's a str, but not "auto" - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str): # it's a str, but not "auto" + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype else: raise ValueError( - "Invalid `torch_dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + "Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." ) # Disable caching if gradient checkpointing is enabled (not supported) config = AutoConfig.from_pretrained(model_id) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index c06911332fd..afcd7c4c8ff 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -586,16 +586,16 @@ def __init__( model_init_kwargs = args.model_init_kwargs or {} if isinstance(model, str): model_id = model - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: - pass # torch_dtype is already a torch.dtype or "auto" or None - elif isinstance(torch_dtype, str) and torch_dtype in ["bfloat16", "float16", "float32"]: - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype + dtype = model_init_kwargs.get("dtype") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype else: raise ValueError( - "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " - f"a valid `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + "Invalid `dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." ) config = AutoConfig.from_pretrained(model_id) architecture = getattr(transformers, config.architectures[0]) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 54cea9ea98b..14963851159 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -600,10 +600,10 @@ def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesCon if model_args.load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` + bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype` bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_args.torch_dtype, + bnb_4bit_quant_storage=model_args.dtype, ) elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig(