Commit 7a0c547
Improve error handling in _can_use_flash_attention for better debugging
Enhanced the _can_use_flash_attention function to provide more detailed
error messages when flash attention compatibility checks fail.
Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise the original exceptions from
check_layout() and check_is_flash_attention()
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)
This improves the debugging experience by surfacing specific technical details
about tensor layout incompatibilities, cuDNN version requirements, and
other flash attention compatibility issues.
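A minimal sketch of the re-raise pattern described above, assuming hypothetical stand-ins for the validation helpers (the check_layout and check_is_flash_attention stubs below are illustrative only; the real function delegates to JAX's internal checks and is not reproduced here):

```python
import numpy as np


def check_layout(query, key, value):
    """Hypothetical layout check: raises with a detailed message."""
    if query.ndim != 4:
        raise ValueError(
            f"Flash attention expects 4D query tensors, got ndim={query.ndim}."
        )


def check_is_flash_attention(query, key, value):
    """Hypothetical capability check (e.g. head size / cuDNN constraints)."""
    head_dim = query.shape[-1]
    if head_dim % 8 != 0:
        raise ValueError(
            f"Flash attention requires head_dim divisible by 8, got {head_dim}."
        )


def _can_use_flash_attention(query, key, value, raise_error=False):
    """Return True if flash attention can be used.

    When raise_error=True, the original exception from the validation
    helpers is re-raised, so its detailed context (layout mismatch,
    cuDNN requirements, ...) reaches the caller instead of a generic
    "flash attention is not supported" message.
    """
    try:
        check_layout(query, key, value)
        check_is_flash_attention(query, key, value)
        return True
    except Exception:
        if raise_error:
            raise  # surface the specific validation error unchanged
        return False


# Usage: head_dim=17 fails the hypothetical capability check.
q = k = v = np.zeros((2, 8, 128, 17))
print(_can_use_flash_attention(q, k, v))  # False (silent)
# _can_use_flash_attention(q, k, v, raise_error=True)  # raises ValueError with details
```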
Relates to keras-hub PR #2257 and addresses flash attention debugging needs.
1 parent 579cc11 · commit 7a0c547
1 file changed: +2 −2 lines changed (lines 1075 and 1077 modified; diff content not shown)