
[generate] Increasing length_penalty makes generations longer #4915

@sshleifer

In `generate`, we document:

```
length_penalty: Exponential penalty to the length. Default to 1.
```

Given the name and the docstring, you might expect that increasing `length_penalty` would make your model produce, on average, shorter generations.
You would be wrong! (at least for `bart-large-xsum`)

When we compute the score of a finished hypothesis here, we calculate

```python
score = sum_logprobs / len(hyp) ** self.length_penalty
```

The issue is that the numerator, `sum_logprobs`, is negative (it is the output of `F.log_softmax`), while the denominator, `len(hyp) ** self.length_penalty`, is positive. Increasing `length_penalty` therefore increases the denominator (and its derivative w.r.t. length), which makes the score *less* negative, i.e. greater, and longer hypotheses benefit the most.
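To make the sign problem concrete, here is a toy calculation (illustrative numbers only, not from any real model):

```python
# sum_logprobs is negative, as log_softmax outputs always are.
sum_logprobs = -8.0
for hyp_len in (10, 20):
    for lp in (1.0, 2.0):
        score = sum_logprobs / hyp_len ** lp
        print(f"len={hyp_len}  length_penalty={lp}: score={score:.4f}")
# len=10  length_penalty=1.0: score=-0.8000
# len=10  length_penalty=2.0: score=-0.0800
# len=20  length_penalty=1.0: score=-0.4000
# len=20  length_penalty=2.0: score=-0.0200
```

Raising the exponent moves every score toward zero, and the denominator grows faster with length, so longer hypotheses gain the most.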

Fairseq has the same logic.

I can think of two groups of solutions:

1. Keep the name and change the code so that length is actually penalized (a runnable sketch follows this list):

   ```python
   denominator = len(hyp) ** self.length_penalty
   if numerator < 0:
       denominator *= -1
   ```

2. Change the name/docstring to something like `len_adjustment` and explain that increasing it is likely to make generations longer.
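As referenced in option 1, here is that sign flip as a minimal standalone sketch; `score_hypothesis` is a hypothetical helper for illustration, not an existing transformers function:

```python
def score_hypothesis(sum_logprobs: float, hyp_len: int, length_penalty: float) -> float:
    # Option 1: flip the denominator's sign when the numerator is negative,
    # so that a larger exponent actually lowers the score.
    denominator = hyp_len ** length_penalty
    if sum_logprobs < 0:
        denominator *= -1
    return sum_logprobs / denominator

print(score_hypothesis(-8.0, 20, 1.0))  # 0.4
print(score_hypothesis(-8.0, 20, 2.0))  # 0.02 -- higher penalty, lower score
```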

@yjernite @patrickvonplaten @LysandreJik @thomwolf, have you seen this, and do you think it's worth fixing or redocumenting?

Empirical Evidence

```python
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum")
tok = BartTokenizer.from_pretrained("facebook/bart-large")

PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
batch = tok.batch_encode_plus([PGE_ARTICLE], max_length=1024, pad_to_max_length=True, return_tensors="pt")

# Same input, generated with two values of length_penalty
ids_lp1 = model.generate(**batch, length_penalty=1.0)
ids_lp2 = model.generate(**batch, length_penalty=2.0)
text_a, text_b = [tok.batch_decode(x, skip_special_tokens=True)[0] for x in [ids_lp1, ids_lp2]]
```

text_a:

"California's largest power company, PG&E, has shut off power to tens of thousands of customers across the state."

text_b:

"California's largest power company, PG&E, has shut off power to tens of thousands of homes and businesses in the north-east of the state."

I found similar results for `bart-large-cnn`.
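For a more systematic check, one could sweep `length_penalty` and count output tokens; here is a sketch reusing `model` and `batch` from above (exact counts will vary by model and input):

```python
for lp in (0.5, 1.0, 2.0, 3.0):
    ids = model.generate(**batch, length_penalty=lp)
    # ids has shape (batch_size, seq_len); seq_len tracks generation length
    print(f"length_penalty={lp}: {ids.shape[1]} output tokens")
```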
