
fix llama model text generation error#1402

Merged
regisss merged 1 commit into huggingface:main from zongwave:main
Oct 10, 2024

Conversation

@zongwave
Contributor

@zongwave zongwave commented Oct 8, 2024

What does this PR do?

PR #1359 introduced a text generation error in some models:
meta-llama/Llama-2-7b-hf
mistralai/Mistral-7B-Instruct-v0.2
Qwen/Qwen2-7B
bigcode/starcoder2-15b

Reproduce command:
python examples/text-generation/run_generation.py --use_hpu_graphs --model_name_or_path meta-llama/Llama-2-7b-hf

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning frameworkЉЉЉЉЉЉЉЉЉЉЉЉЉЉЉ\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n.............................\n\n............\n......',)

Fixes # (issue)

Only apply the slicing operation to hidden_states for the logits computation when kv_cache and reuse_cache are enabled.
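A subtlety behind this fix is Python's negative slicing: since `-0 == 0`, the expression `hidden_states[:, -num_logits_to_keep:, :]` keeps the whole sequence when `num_logits_to_keep` is 0. A minimal sketch of the guarded slicing, using plain lists in place of tensors (the function name is illustrative, not the actual patch):

```python
def compute_logits_positions(hidden_states, num_logits_to_keep, reuse_cache):
    """Sketch: pick which sequence positions are fed to lm_head.

    hidden_states is a list of per-position vectors standing in for a
    [batch, seq, hidden] tensor sliced on the sequence dimension.
    """
    if reuse_cache and num_logits_to_keep:
        # Keep only the last num_logits_to_keep positions for the logits.
        return hidden_states[-num_logits_to_keep:]
    # Otherwise compute logits for every position.
    return hidden_states

seq = ["pos0", "pos1", "pos2", "pos3"]
assert compute_logits_positions(seq, 1, reuse_cache=True) == ["pos3"]
assert compute_logits_positions(seq, 1, reuse_cache=False) == seq
assert seq[-0:] == seq  # the -0 slicing pitfall: -0 == 0, so nothing is trimmed
```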

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@zongwave zongwave requested a review from a user October 8, 2024 04:30
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
if reuse_cache:
Contributor


Can you also check if mixtral and mistral need this code.
Prior to this fix, for the command pytest -s -v tests/transformers/tests/models/ -k "contrastive_generate_dynamic_shapes", we saw three failures:

FAILED tests/transformers/tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_contrastive_generate_dynamic_shapes - AssertionError: Lists differ: [[47, 95, 92, 35, 83, 82]] != [[47, 95, 92, 35, 66, 8]] 
FAILED tests/transformers/tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_contrastive_generate_dynamic_shapes - AssertionError: Lists differ: [[55, 43, 83, 75, 15, 60]] != [[55, 43, 83, 19, 15, 94]] 
FAILED tests/transformers/tests/models/mixtral/test_modeling_mixtral.py::MixtralModelTest::test_contrastive_generate_dynamic_shapes - AssertionError: Lists differ: [[41, 73, 42, 57, 65, 26]] != [[41, 73, 42, 17, 89, 3]] 
==================================================================== 3 failed, 7 passed, 1425 deselected, 15 warnings in 19.74s ==

They pass if I add the same change to mixtral and mistral:

========================================================================= 10 passed, 1425 deselected, 15 warnings in 22.51s =========================================================================

Contributor Author


Added the same change to mixtral, mistral, qwen, starcoder, and phi, which resolved text generation for these models:
mistralai/Mistral-7B-Instruct-v0.2
Qwen/Qwen2.5-Coder-1.5B
bigcode/starcoder2-15b

Contributor

@vidyasiv vidyasiv left a comment


Thanks for your fix @zongwave. Your work also fixes some tests, as mentioned in the comment. Please update mixtral and mistral too if applicable.

@vidyasiv
Contributor

vidyasiv commented Oct 9, 2024

@zongwave, thanks for the updates to the models that needed them. Could you run and paste the output of the transformers fast tests?
@jiminha could you also take a look at this PR?

@jiminha
Contributor

jiminha commented Oct 9, 2024

@ssarkar2 @libinta please review this as well.

@jiminha
Contributor

jiminha commented Oct 9, 2024

@vidyasiv @zongwave Are there more model files that need to be updated? Should we instead not set num_logits_to_keep in the generation/utils.py file when reuse_cache is not being used, so we don't need to update all the model files? (The default value is 0.)
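The suggestion above can be sketched as a small guard in generation/utils.py (function and argument names here are hypothetical; the real call path differs):

```python
def prepare_model_kwargs(model_kwargs, reuse_cache):
    """Sketch: only set num_logits_to_keep when reuse_cache is enabled.

    Models that never read the key then fall back to the default of 0,
    i.e. they compute logits for all positions, so no per-model edit is needed.
    """
    if reuse_cache:
        model_kwargs["num_logits_to_keep"] = 1
    return model_kwargs

assert "num_logits_to_keep" not in prepare_model_kwargs({}, reuse_cache=False)
assert prepare_model_kwargs({}, reuse_cache=True)["num_logits_to_keep"] == 1
```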

@vidyasiv
Contributor

vidyasiv commented Oct 9, 2024

@vidyasiv @zongwave Are there more model files that need to be updated? Should we instead not set num_logits_to_keep in the generation/utils.py file when reuse_cache is not being used, so we don't need to update all the model files? (The default value is 0.)

Not sure if this is the best way, but a quick search shows the following models have this line: logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

  1. gemma (maybe need to update)
  2. llama
  3. mistral
  4. mixtral
  5. phi
  6. qwen2
  7. starcoder2

But I am not sure whether other models like t5 or gpt2 need this; they use a different implementation, i.e. they don't index into anything.

Another list based on reuse_cache usage across models:

  1. falcon(?) : lm_logits = self.lm_head(hidden_states) at link
  2. gemma (maybe need to update)
  3. gptj (?) : lm_logits = self.lm_head(hidden_states).to(torch.float32) at link
  4. llama
  5. mistral
  6. mixtral
  7. phi
  8. qwen2
  9. qwen2_moe(?) : logits = self.lm_head(hidden_states) at link
  10. starcoder2

Does seem like gemma needs it but not sure about the others.

@jiminha
Contributor

jiminha commented Oct 9, 2024

@vidyasiv @zongwave: Discussed further with @ssarkar2, and this logic is actually exactly the same as our --trim_logits. We should remove num_logits_to_keep from HPU and go back to the old logic (logits = self.lm_head(hidden_states)), since this will interfere with our --trim_logits and also --use_hpu_graphs. We might not gain anything here from num_logits_to_keep as long as --trim_logits is on.

We can later work on combining trim_logits and num_logits_to_keep into one so it's easier for later transformers upmerges.
For now, I think we should just remove this num_logits_to_keep slicing logic here.

@regisss could you also check on my command above?
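To illustrate the equivalence claimed above (a sketch with plain lists standing in for tensors, not the actual Habana implementation): trimming the logits after the projection and slicing the hidden states with num_logits_to_keep=1 before the projection both end up keeping only the last position.

```python
def lm_head(vectors):
    """Stand-in projection: one scalar 'logit' per position."""
    return [sum(v) for v in vectors]

hidden_states = [[1, 2], [3, 4], [5, 6]]

# --trim_logits style: project every position, then keep only the last one.
trimmed = lm_head(hidden_states)[-1:]

# num_logits_to_keep=1 style: slice to the last position first, then project.
kept = lm_head(hidden_states[-1:])

assert trimmed == kept == [11]
```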

@zongwave
Contributor Author

@vidyasiv @zongwave: Discussed further with @ssarkar2, and this logic is actually exactly the same as our --trim_logits. We should remove num_logits_to_keep from HPU and go back to the old logic (logits = self.lm_head(hidden_states)), since this will interfere with our --trim_logits and also --use_hpu_graphs. We might not gain anything here from num_logits_to_keep as long as --trim_logits is on.

We can later work on combining trim_logits and num_logits_to_keep into one so it's easier for later transformers upmerges. For now, I think we should just remove this num_logits_to_keep slicing logic here.

@regisss could you also check on my command above?

I removed the "num_logits_to_keep" indexing of "hidden_states" from PR #1359,
but kept the "num_logits_to_keep" flag in the code; it seems this flag is not being used...

# No upscaling to float was ever done for Persimmon
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
logits = self.lm_head(hidden_states)

Contributor


This model doesn't seem to have trim_logits; we can notify the author to add trim_logits support.

use trim_logits in HPU to save memory (comment out the num_logits_to_keep in utils.py)
@jiminha
Contributor

jiminha commented Oct 10, 2024

For now, we decided to just modify the utils.py file so it does not set num_logits_to_keep to 1, so it has no effect on the run (the default is 0). We should revisit this to see if we can merge this new feature with trim_logits. It seems the HPU call flows all differ depending on the combination of these three arguments (use_hpu_graphs, trim_logits, reuse_cache), so we should test all these flows and make sure they are all working.

@vidyasiv
Contributor

For now, we decided to just modify the utils.py file so it does not set num_logits_to_keep to 1, so it has no effect on the run (the default is 0). We should revisit this to see if we can merge this new feature with trim_logits. It seems the HPU call flows all differ depending on the combination of these three arguments (use_hpu_graphs, trim_logits, reuse_cache), so we should test all these flows and make sure they are all working.

I verified the output of one model for these cases, but we should add a test covering the relevant model architectures for functional accuracy:

python examples/text-generation/run_generation.py --use_hpu_graphs --model_name_or_path meta-llama/Llama-2-7b-hf
python examples/text-generation/run_generation.py --use_hpu_graphs --model_name_or_path meta-llama/Llama-2-7b-hf --trim_logits
python examples/text-generation/run_generation.py --use_hpu_graphs --model_name_or_path meta-llama/Llama-2-7b-hf --trim_logits --reuse_cache
python examples/text-generation/run_generation.py --use_hpu_graphs --model_name_or_path meta-llama/Llama-2-7b-hf  --reuse_cache
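The four invocations above differ only in their flag combination; a small loop (a sketch using the same script and model as above) prints each command so they can be reviewed or piped to `sh` on a Gaudi machine:

```shell
# Print the four flag combinations verified above (running them requires
# an HPU and the optimum-habana examples checkout).
for flags in "" "--trim_logits" "--trim_logits --reuse_cache" "--reuse_cache"; do
  echo "python examples/text-generation/run_generation.py --use_hpu_graphs" \
       "--model_name_or_path meta-llama/Llama-2-7b-hf $flags"
done
```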

@libinta libinta added the run-test Run CI for PRs from external contributors label Oct 10, 2024
@vidyasiv
Contributor

vidyasiv commented Oct 10, 2024

@regisss, @jiminha, is it possible to add something like #1411? It is a mockup, but I am not sure of the variability of output across SW versions. If we are sure it should be the same, I can continue that PR.
We have seen many functional failures in text gen that the perf tests miss or that are caught only by slow tests, so it would be good to have a fast test that can do this.

@regisss
Collaborator

regisss commented Oct 10, 2024

@regisss, @jiminha, is it possible to add something like #1411? It is a mockup, but I am not sure of the variability of output across SW versions. If we are sure it should be the same, I can continue that PR. We have seen many functional failures in text gen that the perf tests miss or that are caught only by slow tests, so it would be good to have a fast test that can do this.

Definitely useful! I've had this in mind for a while, and I fully agree that testing the outputs is becoming more and more critical. Happy to review this PR when it's ready 🙂

@jiminha
Contributor

jiminha commented Oct 10, 2024

@regisss, @jiminha, is it possible to add something like #1411? It is a mockup, but I am not sure of the variability of output across SW versions. If we are sure it should be the same, I can continue that PR. We have seen many functional failures in text gen that the perf tests miss or that are caught only by slow tests, so it would be good to have a fast test that can do this.

@vidyasiv I haven't checked the details of the code, but this is absolutely needed. We are only testing the perf for text-gen, not the output token generation, and we missed the opportunity to catch the bug early on. I would prefer, though, to extend the current text-gen tests to check at least the first output token rather than adding a new one, unless there is a specific reason to have a separate test just for the accuracy check.

@vidyasiv
Contributor

@regisss, @jiminha, is it possible to add something like #1411? It is a mockup, but I am not sure of the variability of output across SW versions. If we are sure it should be the same, I can continue that PR. We have seen many functional failures in text gen that the perf tests miss or that are caught only by slow tests, so it would be good to have a fast test that can do this.

@vidyasiv I haven't checked the details of the code, but this is absolutely needed. We are only testing the perf for text-gen, not the output token generation, and we missed the opportunity to catch the bug early on. I would prefer, though, to extend the current text-gen tests to check at least the first output token rather than adding a new one, unless there is a specific reason to have a separate test just for the accuracy check.

The main reason to make it separate is that the current text-gen tests are slow tests and sometimes cover more than one case per model. We need a fast test that functionally checks select key models so it can run with every PR. It's not going to be exhaustive, but we have to start somewhere, so I thought of basing it on this particular failure; next time we fix a functional issue, we can update this test file with the relevant options, etc.
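A minimal shape for such a fast functional check might compare the first generated token against a stored per-model reference. Everything below is hypothetical (the helper, the reference table, and the token ids are placeholders, not taken from #1411):

```python
# Hypothetical reference table: model name -> expected first generated token id.
FIRST_TOKEN_REFERENCE = {
    "model-a": 7,  # placeholder id, not a real measurement
}

def first_generated_token(output_ids, prompt_len):
    """Return the first token id produced after the prompt.

    output_ids is the full sequence returned by generate(): the prompt
    tokens followed by the newly generated ones.
    """
    return output_ids[prompt_len]

# Pretend generate() returned prompt ids [1, 2, 3] followed by new tokens.
output_ids = [1, 2, 3, 7, 9]
assert first_generated_token(output_ids, prompt_len=3) == FIRST_TOKEN_REFERENCE["model-a"]
```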

@regisss regisss merged commit ff8ff43 into huggingface:main Oct 10, 2024
regisss pushed a commit that referenced this pull request Oct 10, 2024

Labels

run-test Run CI for PRs from external contributors
