
Conversation

@sgugger (Collaborator) commented Nov 18, 2020

What does this PR do?

As discovered since merging #8530, the new model outputs sometimes lose their type and become regular dictionaries (e.g. when using NVIDIA apex with the O2 optimization level). Since regular dictionaries can't be indexed with integers, some rework of the Trainer internals has become necessary.
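
For illustration (this example is not from the PR itself), losing the ModelOutput type breaks positional access:

```python
# A ModelOutput allows both outputs["loss"] and outputs[0]; once the type is
# lost (e.g. under apex O2), only key-based access is left:
outputs = {"loss": 0.42, "logits": [[0.1, 0.9]]}

print(outputs["loss"])  # works
# print(outputs[0])     # would raise KeyError: plain dicts have no positional indexing
```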

This PR:

  • fixes training by indexing into the outputs by string when they are dicts, and by integer otherwise, when grabbing the loss
  • fixes evaluation with the same dict-or-tuple indexing when grabbing the loss (see the sketch after this list)
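
A minimal sketch of that shared logic (a hypothetical helper, not the exact Trainer code):

```python
def get_loss(outputs):
    """Grab the loss whether the model returned a dict or a plain tuple."""
    if isinstance(outputs, dict):
        return outputs["loss"]
    # By convention, the loss comes first in tuple outputs.
    return outputs[0]
```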

Beyond these fixes, this PR also takes advantage of the new dict outputs to better filter the outputs at inference. We have had several issues recently when using models that output past states (such as Reformer, XLNet or GPT-2) during evaluation in Trainer. This PR introduces a new API that looks at a possible key in the model config to get attributes to ignore in the outputs during evaluation (these outputs are then discarded, both from the predictions returned by Trainer.predict and from the predictions passed along to metric computation in Trainer.evaluate). Since a user might want to ignore more keys, or keep some of those keys, a new ignore_keys argument is added to both Trainer.predict and Trainer.evaluate to fully control which keys are ignored in those dictionaries.

If the model outputs a tuple instead of a dict, all of this is ignored.
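
In practice, the new argument might be used like this (a sketch: the Trainer calls are shown in comments, and the filtering logic approximates what the PR does rather than quoting it):

```python
ignore_keys = ["past_key_values"]

# During evaluation/prediction, e.g.:
#   metrics = trainer.evaluate(ignore_keys=ignore_keys)
#   predictions = trainer.predict(test_dataset, ignore_keys=ignore_keys)

# Inside the prediction step, the filtering is roughly:
outputs = {"loss": 0.42, "logits": [[0.1, 0.9]], "past_key_values": object()}
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
# logits == ([[0.1, 0.9]],) -- past_key_values and the loss have been dropped
```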

Fixes #8523, among others.

@LysandreJik (Member) left a comment:

This is a very welcome change imo, and the implementation is clean. Thank you for implementing the last test, I think it's great.

@sgugger sgugger merged commit 4208f49 into master Nov 19, 2020
@sgugger sgugger deleted the trainer_outputs branch November 19, 2020 15:43

The review comment below is on this snippet from the Marian configuration:

```python
model_type = "marian"
keys_to_ignore_at_inference = ["past_key_values"]
```
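
On the Trainer side, the default set of keys to ignore is picked up from the config roughly like this (a sketch of the lookup, not the exact Trainer code):

```python
# When the caller doesn't pass ignore_keys explicitly, fall back to the
# attribute set on the model config (an empty list if it isn't defined):
if ignore_keys is None:
    ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
```
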
Contributor commented:

It's a bit late now, but to be honest I'm not a huge fan of the name -> this seems to be very specific to training, but one might now think that past_key_values can never be passed during inference in general. Why not call it keys_to_ignore_at_training?

sgugger (Collaborator, Author) replied:

No, this is not for training, only for inference; during training we only get the loss in the outputs.
And these keys are not ignored when passed to the model, but ignored because they are not part of the logits/scores/predictions we want to gather. Maybe output_keys_to_ignore_at_inference is clearer?

Contributor replied:

I see! Yeah I think output_keys_to_ignore_at_inference would be a bit clearer to me :-)

sgugger (Collaborator, Author) replied:

See #8857

Development

Successfully merging this pull request may close these issues.

Reformer model crashes during causal LM evaluation
