Skip to content

Commit 928f589

Browse files
authored
Fix: add_generation_prompt=True for conversational only (#4362)
1 parent b0889d2 commit 928f589

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,13 +1300,16 @@ def _generate_single_turn(self, prompts: list):
13001300
"padding": True,
13011301
"padding_side": "left",
13021302
"max_length": self.max_prompt_length,
1303-
"add_generation_prompt": True,
13041303
"truncation": True,
13051304
"add_special_tokens": False,
13061305
}
13071306
if is_conversational({"prompt": prompts[0]}):
13081307
generate_inputs = self.processing_class.apply_chat_template(
1309-
conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
1308+
conversation=prompts,
1309+
**processor_kwargs,
1310+
add_generation_prompt=True,
1311+
tokenize=True,
1312+
return_dict=True,
13101313
)
13111314
else:
13121315
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)

trl/trainer/rloo_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,13 +1126,16 @@ def _generate_single_turn(self, prompts: list):
11261126
"padding": True,
11271127
"padding_side": "left",
11281128
"max_length": self.max_prompt_length,
1129-
"add_generation_prompt": True,
11301129
"truncation": True,
11311130
"add_special_tokens": False,
11321131
}
11331132
if is_conversational({"prompt": prompts[0]}):
11341133
generate_inputs = self.processing_class.apply_chat_template(
1135-
conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
1134+
conversation=prompts,
1135+
**processor_kwargs,
1136+
add_generation_kwargs=True,
1137+
tokenize=True,
1138+
return_dict=True,
11361139
)
11371140
else:
11381141
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)

0 commit comments

Comments
 (0)