Conversation

@mapmeld (Contributor) commented May 12, 2022

What does this PR do?

When calling model.generate(content, log_decoder=True), this PR logs which decoder and warper(s) are actually used during generation.

I have a demo where I show text generated with different options (top_k, typical_p, repetition_penalty, num_beams, etc.). The final chosen decoding strategy is not obvious. It is tricky to test by comparing outputs because a generative model often returns different text on multiple runs.
By design the function tolerates mistakes: if an argument is missing its companion (typical_p=0.5 but no do_sample=True), has an unusable value (typical_p=3), or is misspelled (numBeams=2), the function silently chooses another decoding strategy. The code does not flag these because the remaining **kwargs are passed through to the model.
I believe the logger is the best place to check whether decoding actually happened as expected.
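To illustrate the failure mode with a minimal toy sketch (this is not the actual transformers code, just a stand-in showing why **kwargs swallow mistakes): a misspelled keyword never raises a TypeError, it simply lands in the catch-all dict, and the strategy selection quietly ignores the argument the user thought they set.

```python
def generate(input_text, num_beams=1, do_sample=False, typical_p=None, **model_kwargs):
    """Toy stand-in for model.generate(): unknown kwargs are swallowed silently."""
    if typical_p is not None and not do_sample:
        strategy = "greedy_search"  # typical_p is silently ignored without do_sample=True
    elif num_beams > 1:
        strategy = "beam_search"
    elif do_sample:
        strategy = "sample"
    else:
        strategy = "greedy_search"
    return strategy, model_kwargs  # any typo'd kwarg ends up here, unflagged

# A typo'd argument (numBeams instead of num_beams) raises no error:
strategy, leftovers = generate("hello", numBeams=4)
print(strategy)   # greedy_search, not beam search
print(leftovers)  # {'numBeams': 4}
```

The same silent fallback happens for typical_p=0.5 without do_sample=True: the call succeeds, but the output was produced by plain greedy search.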

Example usage: https://colab.research.google.com/drive/1DpMnZkSCtZIiaONoxfzYxYI4vgiTNYLN?usp=sharing

  • The first commit is unnecessary thanks to docs for typical decoding #17186; I rebased on that PR and added one additional section to the documentation about typical decoding
  • I'm open to renaming or removing log_decoder to always do logger.info in these places
  • If we always do logger.info, I could move logger calls into BeamSearchScorer. Trying to avoid adding too many args
  • Could use logger.warn if these issues warrant it

Discussion: https://discuss.huggingface.co/t/logging-which-decoder-selected-in-generation/18133

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev commented May 12, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor)

Hey @mapmeld,

Thanks for the PR.

To me, that is a bit too much of an edge case, and I'm not very happy with cluttering the generation code with if/else statements.

@gante @patil-suraj what do you think?

@gante (Contributor) commented Jun 3, 2022

Hey @mapmeld! Thank you for the PR 👍

I'm also not a fan of all the if statements on a function whose complexity is already over the top. Perhaps we could remove all the if branches, keep the logging statements, but lower their logging level to debug. That way, a user could get all those values by setting the appropriate logging level, and it would be invisible in the vast majority of cases.

WDYT?
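The suggestion can be sketched with the standard logging module (the logger name below is illustrative, not necessarily the exact one transformers uses): a debug-level message is invisible under the default WARNING level and appears only when the user opts in.

```python
import logging

# Illustrative logger name; transformers namespaces its loggers per module.
logger = logging.getLogger("transformers.generation_utils")

def pick_strategy(num_beams=1, do_sample=False):
    """Toy stand-in for the strategy-selection step inside generate()."""
    if num_beams > 1:
        strategy = "beam_search"
    elif do_sample:
        strategy = "sample"
    else:
        strategy = "greedy_search"
    # Hidden under the default WARNING level; visible only after opting in to DEBUG.
    logger.debug("Decoding strategy selected: %s", strategy)
    return strategy

pick_strategy(num_beams=4)               # default level: no log output
logging.basicConfig(level=logging.DEBUG)
pick_strategy(num_beams=4)               # now the debug line is emitted
```

In transformers itself, the same opt-in would go through the library's central verbosity controls (e.g. transformers.logging.set_verbosity_debug()) rather than per-logger configuration.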

@mapmeld (Contributor, Author) commented Jun 3, 2022

@patrickvonplaten @gante That makes sense to me: logger.debug level, no extra argument. I've made a commit for that.

@patil-suraj (Contributor)

Agree with @gante's comment: using logger.debug and getting rid of those if-else statements sounds good to me.
I'm okay with having this logging to make it more obvious which method is being used; it will be useful for debugging IMO.

@gante (Contributor) left a comment

LGTM 👍

(to get rid of CI errors, rebase with main -- some issues were fixed since you opened the PR)

@patrickvonplaten (Contributor)

Sorry, I think I wasn't super clear in my last message.

Personally, I would prefer to not merge this PR because:

  • the generation code is already very complex and hard to read (talking about the code-reading part here, not what's displayed to the user); I don't think adding 5-6 new logger statements helps here
  • How would users know that generate should be run in debug mode to display the logging statements? I don't think many users will realize this
  • If the decoding strategy is not obvious, we should improve the docs IMO
  • If the user doesn't know what top_p does, I don't think they would know what a TopPLogitsWarper is, so I don't see the added value of displaying the names in a logger here
  • It's also not in line with how we use the logger in other places across the library

@mapmeld (Contributor, Author) commented Jun 10, 2022

OK, will close then.

If I can suggest changes beyond logging to this section, here are some ideas:

  • throwing exceptions in the current code if a decoding argument (typical_p) is ignored because of an unusable value or a missing companion argument (do_sample=True)
  • adding an argument to generate() naming the intended decoder, so the intent is clear in end-user code and transformers can throw an exception for calls which don't go down the expected path for whatever reason
  • specific decoding functions to replace the general generate(); these functions could throw exceptions, use Python type hints, and be more useful in code auto-complete tools
  • implementing typical decoding in TensorFlow so there's more parity between the Torch and TensorFlow code
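The first suggestion above could be sketched as a small pre-flight validation helper (entirely hypothetical, not part of transformers; the argument names mirror generate()'s): raise instead of silently ignoring an argument.

```python
def validate_generation_args(num_beams=1, do_sample=False, typical_p=None, top_k=None):
    """Hypothetical pre-flight check: raise if an argument would be silently ignored."""
    if typical_p is not None:
        if not do_sample:
            raise ValueError("typical_p was set but do_sample=False; it would be ignored.")
        if not 0.0 < typical_p <= 1.0:
            raise ValueError(f"typical_p must be in (0, 1], got {typical_p}.")
    if top_k is not None and not do_sample:
        raise ValueError("top_k was set but do_sample=False; it would be ignored.")

validate_generation_args(do_sample=True, typical_p=0.5)  # OK, valid combination
# validate_generation_args(typical_p=0.5)                # raises ValueError
# validate_generation_args(do_sample=True, typical_p=3)  # raises ValueError
```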

@mapmeld mapmeld closed this Jun 10, 2022
@patrickvonplaten (Contributor)

Thanks a lot @mapmeld - those are really nice suggestions! Also after some discussion we think it could make a lot of sense to do maybe the following:

  • If kwargs are passed to generate that don't exist, then we throw a warning so a user is well aware if something is misspelled.
  • Really like the idea of warning the user if an argument is used that cannot be activated. Wondering if there is a good approach that would not force us to add a lot of if ... statements in generate. Any ideas how this could be checked in a very concise way?

@patrickvonplaten (Contributor)

Also keen to hear suggestions from @gante :-)

@gante (Contributor) commented Jun 14, 2022

implementing typical decoding in TensorFlow so there's more similarity between Torch and TensorFlow code

(@mapmeld) Yeah, we are working on it :D TF generate should have a big release soon.

Really like the idea of warning the user if an argument is used that cannot be activated - wondering if there is a good approach that would not force us to make a lot of if .... statements in generate. Any ideas how this could be checked in a very concise way?

(@patrickvonplaten) Without ifs and elses, the cleanest solution would possibly be to hold a dictionary with all passed arguments, plus a set of accepted arguments for each generation type, and raise an exception listing all unexpected arguments (e.g. The passed arguments triggered greedy_search. However, for greedy_search, the following arguments are not accepted: top_p. Please check the documentation here [link]). We can actually implement it with a small effort -- the dictionary with all arguments is locals() at the start of the specific generation functions (e.g. greedy_search()), and the set of accepted arguments is the function signature except **model_kwargs. We can get the accepted model_kwargs from the model forward signature (it's not quite the same, but should be close enough) -- everything else that remains in **model_kwargs is an unused parameter and should raise an exception.

WDYT?
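A minimal sketch of this idea using inspect (all function names and signatures here are illustrative stand-ins; the real generation sub-methods take many more parameters): treat the model forward signature as the set of accepted model_kwargs and raise with every unexpected argument at once.

```python
import inspect

def model_forward(input_ids, attention_mask=None, position_ids=None):
    """Stand-in for the model's forward(); its signature defines accepted model_kwargs."""
    return input_ids

def _validate_kwargs(forward_fn, model_kwargs, caller):
    """Raise one exception naming every leftover kwarg the forward pass can't accept."""
    accepted = set(inspect.signature(forward_fn).parameters)
    unexpected = sorted(set(model_kwargs) - accepted)
    if unexpected:
        raise ValueError(
            f"The passed arguments triggered {caller}. However, for {caller}, "
            f"the following arguments are not accepted: {', '.join(unexpected)}."
        )

def greedy_search(input_ids, max_length=20, pad_token_id=None, **model_kwargs):
    """Stand-in for a generation sub-method; validates leftover model_kwargs up front."""
    _validate_kwargs(model_forward, model_kwargs, caller="greedy_search")
    return input_ids  # the decoding loop itself is elided

greedy_search([1, 2, 3], attention_mask=[1, 1, 1])  # OK: accepted by forward()
# greedy_search([1, 2, 3], top_p=0.9)               # raises ValueError naming top_p
```

The appeal of this shape is that nothing is hardcoded per strategy: the accepted sets fall out of the existing signatures, so they stay in sync as the functions evolve.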

@patrickvonplaten (Contributor)

In a first step I was rather thinking about just warning the user if parameters are passed in kwargs that are not used (probably misspelled).

@patrickvonplaten (Contributor)

Adding sub-generation-specific logging logic sounds very complex. I would be open to it if we find a clean, concise solution, but at the moment I'd like to avoid adding hardcoded lists of which generation parameter is relevant for which sub-generation method (also hard to maintain).

@patrickvonplaten (Contributor)

@gante the solution sounds interesting; I would need to see a PR for it to fully understand it. The problem I see is that we won't detect unnecessary generation parameters, since they are inside logits_processor and logits_warper.

@patrickvonplaten (Contributor)

Overall, also just want to say here that IMO two mistakes were made a while back:

  • We've set defaults for some values, which we should never have done IMO: max_length and top_k have defaults, which is quite counterproductive for good logging
  • We have allowed people to set generation parameters inside the config, which the method then defaults to. In hindsight this was too much "black magic" and not at all visible/understandable for (new) users.

It will be very hard to remedy these things without breaking backward compatibility, but I'm open to suggestions / comments!

@mapmeld (Contributor, Author) commented Jun 17, 2022

Would it be possible for us to talk about it in the HF Slack? I would be interested in finding a part of this where I can contribute.

@patrickvonplaten (Contributor)

Invited you :-) Let's chat on Slack

5 participants