Commit b673c16
Fix mask slicing for models with HybridCache (#35681)
* correctly slice
* check mask
* Update modular_gemma2.py
* fix
* add tests
* fix typo
* finally fix mask slicing
* Finally correctly slice in all cases!!
* add test for all attention functions
* small fix in tests
* trick around dynamo tracing issue
* last update
* more robust
* kwargs propagation
* make it explicit for checkpointing
* apply modular

1 parent aa3e590
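The diff body did not survive on this page, so the following is only a minimal sketch of the kind of slicing the commit title and bullets describe ("correctly slice", "Finally correctly slice in all cases"). It assumes a mask whose last dimension spans a fixed-size cache and a layer that keeps only a sliding window of keys. The function name, its parameters, and the 1/0 mask encoding are hypothetical and are not taken from the commit.

```python
def slice_sliding_window_mask(mask, query_len, last_cache_position, sliding_window):
    """Hypothetical helper, not the code from this commit.

    mask: list of rows (one per query token), each of length max_cache_len,
          with 1 for a visible key position and 0 for a masked one.
    last_cache_position: number of tokens processed so far, including the
          current query tokens (an assumption of this sketch).
    Returns each row restricted to the key columns a sliding-window layer keeps.
    """
    # During prefill the query can be longer than the window, so keep the larger.
    effective_len = max(query_len, sliding_window)
    # Shift the slice only once generation has moved beyond the window.
    offset = max(0, last_cache_position - effective_len)
    return [row[offset:offset + effective_len] for row in mask]


# Prefill of 3 tokens with a window of 4 and a cache of 8: no offset is needed.
prefill_mask = [
    [1, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0, 0, 0],
]
prefill_slice = slice_sliding_window_mask(prefill_mask, 3, 3, 4)

# Decoding the 6th token: the slice starts at column 2 and covers columns 2..5.
decode_mask = [[1, 1, 1, 1, 1, 1, 0, 0]]
decode_slice = slice_sliding_window_mask(decode_mask, 1, 6, 4)
```

Clamping the offset at zero is what lets the same expression serve both the prefill case and decoding past the window. Whether the actual change is structured this way cannot be confirmed from this page.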
File tree
6 files changed (+232, -22 lines changed)
- src/transformers/models
  - cohere2
  - gemma2
- tests/models
  - cohere2
  - gemma2