Commit 23c0062

Hotfix: Fall back to config.text_config._name_or_path when config._name_or_path is missing (#4324)
1 parent 47b1aa7 commit 23c0062

9 files changed (+46 / -24 lines)
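The motivation, in short: for some multimodal configs (Qwen2-VL and Qwen2.5-VL, see GH-4323) the checkpoint id lives only on the nested text_config, so reading config._name_or_path directly yields an empty string and can break downstream lookups such as AutoProcessor.from_pretrained. Below is a minimal sketch of the failure mode and the fallback, using hand-built stand-in configs (the checkpoint id is illustrative, not taken verbatim from the issue):

from transformers import PretrainedConfig

text_config = PretrainedConfig()
text_config._name_or_path = "Qwen/Qwen2.5-VL-3B-Instruct"  # illustrative id

config = PretrainedConfig()         # top-level _name_or_path defaults to ""
config.text_config = text_config    # the id is only on the nested text config

# Old behavior: trainers read config._name_or_path and get "".
# New behavior, as implemented by get_config_model_id() in trl/trainer/utils.py below:
model_id = getattr(config, "_name_or_path", "") or getattr(
    getattr(config, "text_config", None), "_name_or_path", ""
)
print(model_id)  # Qwen/Qwen2.5-VL-3B-Instruct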

trl/trainer/base_trainer.py

Lines changed: 4 additions & 3 deletions

@@ -17,7 +17,7 @@
 
 from transformers import Trainer, is_wandb_available
 
-from .utils import generate_model_card, get_comet_experiment_url
+from .utils import generate_model_card, get_comet_experiment_url, get_config_model_id
 
 
 if is_wandb_available():

@@ -50,8 +50,9 @@ def create_model_card(
         if not self.is_world_process_zero():
             return
 
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
+        model_name_or_path = get_config_model_id(self.model.config)
+        if model_name_or_path and not os.path.isdir(model_name_or_path):
+            base_model = model_name_or_path
         else:
             base_model = None

trl/trainer/callbacks.py

Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@
 from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
-from .utils import log_table_to_comet_experiment
+from .utils import get_config_model_id, log_table_to_comet_experiment
 
 
 if is_rich_available():

@@ -821,7 +821,7 @@ def _merge_and_maybe_push(self, output_dir, global_step, model):
         checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
         self.merge_config.policy_model_path = checkpoint_path
         if self.merge_config.target_model_path is None:
-            self.merge_config.target_model_path = model.config._name_or_path
+            self.merge_config.target_model_path = get_config_model_id(model.config)
         merge_path = os.path.join(checkpoint_path, "merged")
 
         merge_models(self.merge_config.create(), merge_path)

trl/trainer/dpo_trainer.py

Lines changed: 3 additions & 2 deletions

@@ -65,6 +65,7 @@
     empty_cache,
     flush_left,
     flush_right,
+    get_config_model_id,
     log_table_to_comet_experiment,
     pad,
     pad_to_length,

@@ -286,7 +287,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = DPOConfig(f"{model_name}-DPO")

@@ -299,7 +300,7 @@ def __init__(
                     "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                     "The `model_init_kwargs` will be ignored."
                 )
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
         if isinstance(ref_model, str):
             ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {})
         else:

trl/trainer/grpo_trainer.py

Lines changed: 6 additions & 5 deletions

@@ -66,6 +66,7 @@
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
+    get_config_model_id,
     identity,
     nanmax,
     nanmin,

@@ -245,7 +246,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = GRPOConfig(f"{model_name}-GRPO")

@@ -270,7 +271,7 @@ def __init__(
             architecture = getattr(transformers, config.architectures[0])
             model = architecture.from_pretrained(model_id, **model_init_kwargs)
         else:
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "

@@ -290,7 +291,7 @@ def __init__(
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
+            processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):

@@ -317,7 +318,7 @@ def __init__(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
             if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
-                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+                self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
             else:
                 self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs

@@ -347,7 +348,7 @@ def __init__(
         for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
             if isinstance(reward_func, PreTrainedModel):
                 if reward_processing_class is None:
-                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                    reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config))
                 if reward_processing_class.pad_token_id is None:
                     reward_processing_class.pad_token = reward_processing_class.eos_token
                 # The reward model computes the reward for the latest non-padded token in the input sequence.

trl/trainer/online_dpo_trainer.py

Lines changed: 2 additions & 1 deletion

@@ -74,6 +74,7 @@
     disable_dropout_in_model,
     empty_cache,
     ensure_master_addr_port,
+    get_config_model_id,
     pad,
     truncate_right,
 )

@@ -243,7 +244,7 @@ def __init__(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
             if isinstance(reward_funcs[i], nn.Module):
-                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+                self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
             else:
                 self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs

trl/trainer/reward_trainer.py

Lines changed: 3 additions & 3 deletions

@@ -44,7 +44,7 @@
 from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
 from .base_trainer import BaseTrainer
 from .reward_config import RewardConfig
-from .utils import disable_dropout_in_model, pad, remove_none_values
+from .utils import disable_dropout_in_model, get_config_model_id, pad, remove_none_values
 
 
 if is_peft_available():

@@ -273,7 +273,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = RewardConfig(f"{model_name}-Reward")

@@ -294,7 +294,7 @@ def __init__(
             with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
                 model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
         else:
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "

trl/trainer/rloo_trainer.py

Lines changed: 6 additions & 5 deletions

@@ -65,6 +65,7 @@
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
+    get_config_model_id,
     identity,
     nanmax,
     nanmin,

@@ -240,7 +241,7 @@ def __init__(
 
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = RLOOConfig(f"{model_name}-RLOO")

@@ -265,7 +266,7 @@ def __init__(
             architecture = getattr(transformers, config.architectures[0])
             model = architecture.from_pretrained(model_id, **model_init_kwargs)
         else:
-            model_id = model.config._name_or_path
+            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. "

@@ -285,7 +286,7 @@ def __init__(
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
+            processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):

@@ -312,7 +313,7 @@ def __init__(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
             if isinstance(reward_funcs[i], nn.Module):  # Use Module over PretrainedModel for compat w/ compiled models
-                self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+                self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
             else:
                 self.reward_func_names.append(reward_funcs[i].__name__)
         self.reward_funcs = reward_funcs

@@ -342,7 +343,7 @@ def __init__(
         for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
             if isinstance(reward_func, PreTrainedModel):
                 if reward_processing_class is None:
-                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                    reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config))
                 if reward_processing_class.pad_token_id is None:
                     reward_processing_class.pad_token = reward_processing_class.eos_token
                 # The reward model computes the reward for the latest non-padded token in the input sequence.

trl/trainer/sft_trainer.py

Lines changed: 3 additions & 3 deletions

@@ -54,6 +54,7 @@
     create_model_from_path,
     entropy_from_logits,
     flush_left,
+    get_config_model_id,
     pad,
     remove_none_values,
     selective_log_softmax,

@@ -590,7 +591,7 @@ def __init__(
     ):
         # Args
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model if isinstance(model, str) else get_config_model_id(model.config)
             model_name = model_name.split("/")[-1]
             args = SFTConfig(f"{model_name}-SFT")
         elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):

@@ -608,11 +609,10 @@ def __init__(
                     "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
                     "The `model_init_kwargs` will be ignored."
                 )
-            model_id = model.config._name_or_path
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model_id)
+            processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config))
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):

trl/trainer/utils.py

Lines changed: 17 additions & 0 deletions

@@ -41,6 +41,7 @@
     BitsAndBytesConfig,
     EvalPrediction,
     GenerationConfig,
+    PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
     TrainerState,

@@ -1962,3 +1963,19 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     architecture = getattr(transformers, config.architectures[0])
     model = architecture.from_pretrained(model_id, **kwargs)
     return model
+
+
+def get_config_model_id(config: PretrainedConfig) -> str:
+    """
+    Retrieve the model identifier from a given model configuration.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            Configuration from which to extract the model identifier.
+
+    Returns:
+        `str`:
+            The model identifier associated with the model configuration.
+    """
+    # Fall back to `config.text_config._name_or_path` if `config._name_or_path` is missing: Qwen2-VL and Qwen2.5-VL. See GH-4323
+    return getattr(config, "_name_or_path", "") or getattr(getattr(config, "text_config", None), "_name_or_path", "")
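
A quick usage sketch of the new helper, again with hand-built stand-in configs (the checkpoint id is illustrative). The top-level id wins when present; the text_config fallback only kicks in when it is empty or absent:

from transformers import PretrainedConfig
from trl.trainer.utils import get_config_model_id

plain = PretrainedConfig()
plain._name_or_path = "Qwen/Qwen2.5-0.5B"  # illustrative id
assert get_config_model_id(plain) == "Qwen/Qwen2.5-0.5B"

composite = PretrainedConfig()   # top-level _name_or_path stays ""
composite.text_config = plain    # only the nested text config carries an id
assert get_config_model_id(composite) == "Qwen/Qwen2.5-0.5B"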
