Commit 9b861cd
bugfix: fix merge_attention_state in BatchAttention w/ gqa-group-size in Qwen family (#1614)
## 📌 Description
This PR fixes precision issues in BatchAttention (Persistent FA2, from
#1137) that occur when
`CTA_TILE_Q` is not a multiple of `gqa_group_size` (e.g., Qwen family
models). The prior implementation assumed that, for a given token, all
`qo_heads` of a `kv_head` are either all split-kv or all non-split-kv.
However, when `gqa_group_size == 7`, some of those `qo_heads` can be
non-split while the rest are split.
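
To make the failure mode concrete, here is a minimal sketch, not the kernel code. It assumes query rows are packed per `kv_head` as `token * gqa_group_size + head_in_group` and tiled in chunks of `CTA_TILE_Q`. The helper names, the tile size of 16, and the choice of which tiles are split-kv are all hypothetical and chosen only for illustration.

```python
def rows_in_tile(tile_idx, cta_tile_q, gqa_group_size, qo_len):
    """Return the (token, head_in_group) pairs covered by one query tile.

    Assumes the packed row index is token * gqa_group_size + head_in_group.
    """
    total_rows = qo_len * gqa_group_size
    start = tile_idx * cta_tile_q
    end = min(start + cta_tile_q, total_rows)
    return [divmod(packed, gqa_group_size) for packed in range(start, end)]


def tokens_with_mixed_split(cta_tile_q, gqa_group_size, qo_len, split_tiles):
    """Return tokens whose qo_heads land in both split and non-split tiles.

    split_tiles is a hypothetical set of tile indices processed as split-kv.
    """
    total_rows = qo_len * gqa_group_size
    num_tiles = -(-total_rows // cta_tile_q)  # ceiling division
    seen = {}  # token -> set of split flags observed across its heads
    for tile in range(num_tiles):
        is_split = tile in split_tiles
        for token, _head in rows_in_tile(tile, cta_tile_q, gqa_group_size, qo_len):
            seen.setdefault(token, set()).add(is_split)
    return sorted(token for token, flags in seen.items() if len(flags) == 2)


# Group size 7 with a tile of 16 rows: the tile boundary cuts through token 2.
print(tokens_with_mixed_split(16, 7, 5, split_tiles={0}))
# Group size 8 divides the tile size evenly, so no token is cut.
print(tokens_with_mixed_split(16, 8, 5, split_tiles={0}))
```

Under these assumptions, a group size that divides the tile size keeps every token's heads inside one tile, while a group size of 7 lets a tile boundary fall inside a token's head group, which is the case the merge step has to handle per head rather than per token.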
## 🔍 Related Issues
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
cc @Edenzzzz

1 parent 8e926de commit 9b861cd
3 files changed (+30, -16 lines):
- include/flashinfer/attention
- tests