Description
In `generate`, we document

```
length_penalty: Exponential penalty to the length. Default to 1.
```

Given the name and the docstring, you might expect that if you increase `length_penalty`, your model will, on average, produce shorter generations.
You would be wrong! (at least for `bart-large-xsum`)
When we decide the score of a hypothesis here, we calculate

```python
score = sum_logprobs / len(hyp) ** self.length_penalty
```

The issue is that the numerator, `sum_logprobs`, is negative (the result of `F.log_softmax`), while the denominator, `len(hyp) ** self.length_penalty`, is positive. If we increase `length_penalty`, we increase the denominator (and the derivative of the denominator w.r.t. length), which makes the score less negative, i.e. greater.
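To make the effect concrete, here is a toy sketch (the log-probabilities and lengths are made-up numbers, not taken from a real model) showing that raising `length_penalty` can flip the ranking in favor of the *longer* hypothesis:

```python
# Toy illustration: sum_logprobs is negative, so a larger exponent on the
# length makes the denominator bigger and the score *less* negative.
def beam_score(sum_logprobs, length, length_penalty):
    return sum_logprobs / length ** length_penalty

short_hyp = (-2.0, 4)   # (sum of log-probs, length) -- made-up numbers
long_hyp = (-6.0, 10)

for lp in (1.0, 2.0):
    print(lp, beam_score(*short_hyp, lp), beam_score(*long_hyp, lp))

# lp=1.0: short=-0.500, long=-0.600  -> the shorter hypothesis wins
# lp=2.0: short=-0.125, long=-0.060  -> the longer hypothesis now wins
```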
Fairseq has the same logic.
I can think of two groups of solutions:

- Keep the name and change the code so that length is actually penalized (a sketch follows this list):

  ```python
  denominator = len(hyp) ** self.length_penalty
  if numerator < 0:
      denominator *= -1
  ```

- Change the name/docstring to something like `len_adjustment` and explain that increasing it is likely to make generations shorter.
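For the first option, here is a minimal sketch of what sign-aware normalization could look like (illustrative only, not the actual `transformers` code; the names mirror the snippet above):

```python
def length_penalized_score(sum_logprobs: float, hyp_len: int, length_penalty: float) -> float:
    """Sketch: flip the denominator's sign when the numerator is negative,
    so that a larger length_penalty always pushes the score of longer
    hypotheses down relative to shorter ones."""
    denominator = hyp_len ** length_penalty
    if sum_logprobs < 0:
        denominator *= -1
    return sum_logprobs / denominator

# With the toy numbers from earlier:
# length_penalty=1.0 -> short: 0.500, long: 0.600 (longer ranks higher)
# length_penalty=2.0 -> short: 0.125, long: 0.060 (shorter ranks higher)
```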
@yjernite @patrickvonplaten @LysandreJik @thomwolf, have you guys seen this/do you think it's worth fixing or redocumenting?
Empirical Evidence
```python
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum")
tok = BartTokenizer.from_pretrained("facebook/bart-large")

PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""

batch = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt")
ids_lp1 = model.generate(**batch, length_penalty=1.0)
ids_lp2 = model.generate(**batch, length_penalty=2.0)
text_a, text_b = [tok.batch_decode(x, skip_special_tokens=True)[0] for x in [ids_lp1, ids_lp2]]
```

text_a:
"California's largest power company, PG&E, has shut off power to tens of thousands of customers across the state."
text_b:
"California's largest power company, PG&E, has shut off power to tens of thousands of homes and businesses in the north-east of the state."
I found similar results for `bart-large-cnn`.