-
Notifications
You must be signed in to change notification settings - Fork 593
unittest: Add head dim 256 test cases and mark as xfail #1999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughTests updated to parameterize batch-decode tests with a new Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Comment |
| head_dim, | ||
| ): | ||
| pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256") | ||
| test_trtllm_batch_decode( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unusual practice (but it's totally okay to do so and it's not introduced in this PR) to call one test_* function instead another one, all top-level functions with prefix test_ will be treated as standalone unittests.
Can we create a function _test_trtllm_batch_decode as the common body of these unittests, instead of calling another top-level test_trtllm_batch_decode function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @yzh119, I think this PR is a good opportunity to do make this change.
I have:
- renamed
test_trtllm_batch_decodeto_test_trtllm_batch_decodeas a base function. - There are test functions that call
_test_trtllm_batch_decodewith a group of parameter combinations:test_trtllm_batch_decode--> 1632 existing parameter combinationstest_trtllm_batch_decode_bs1--> 1 xfail case with batch size 1test_trtllm_batch_decode_head_dim_256--> 40 xfail cases with head_dim=256.test_trtllm_batch_decode_long_sequence_length--> 48 cases of long seqlen.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The long seqlen was added because I saw #1968 and tested what happens if try testing long seqlens. We start to see failures starting from 4k
1993033 to
097308f
Compare
yzh119
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
/bot run |
|
[SUCCESS] Pipeline #37485772: 13/17 passed |
📌 Description
Adding unit test for
head_dim=256cases for trtllm-gen decode and marking them as xfail.Renames
test_trtllm_batch_decodeto_test_trtllm_batch_decodeas a base function. Test functions now call _test_trtllm_batch_decode with a group of parameter combinations:test_trtllm_batch_decode--> 1632 existing parameter combinationstest_trtllm_batch_decode_bs1--> 1 xfail case with batch size 1test_trtllm_batch_decode_head_dim_256--> 40 xfail cases with head_dim=256.test_trtllm_batch_decode_long_sequence_length--> 48 cases of long seqlen.🔍 Related Issues
#1993
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit