
Commit 7a0c547

Improve error handling in _can_use_flash_attention for better debugging
Enhanced the _can_use_flash_attention function to provide more detailed error messages when flash attention compatibility checks fail.

Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise the original exceptions from the check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)

This improves the debugging experience by surfacing specific technical details about tensor layout incompatibilities, cuDNN version requirements, and other flash attention compatibility issues.

Relates to keras-hub PR #2257 and addresses flash attention debugging needs.
1 parent 579cc11 commit 7a0c547


keras/src/backend/jax/nn.py

Lines changed: 2 additions & 2 deletions
@@ -1072,9 +1072,9 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
             is_training=False,
         )
         return True
-    except:
+    except Exception as e:
         if raise_error:
-            raise
+            raise e
         return False
 
 
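As a rough usage sketch (not part of this commit), the helper below shows how a caller might surface the new diagnostics. The helper name and import alias are hypothetical; only _can_use_flash_attention and its raise_error flag come from the patched keras/src/backend/jax/nn.py.

from keras.src.backend.jax import nn as jax_nn

def describe_flash_attention_support(query, key, value, bias=None):
    # Hypothetical debugging helper, not part of this commit.
    # Quiet check first (raise_error=False, the default): it simply
    # returns True or False, exactly as before the change.
    if jax_nn._can_use_flash_attention(query, key, value, bias):
        return "flash attention available"
    # Re-run the check with raise_error=True: with this commit, the
    # exception raised inside check_layout() / check_is_flash_attention()
    # (layout mismatch, cuDNN version, etc.) propagates to the caller
    # instead of being swallowed.
    try:
        jax_nn._can_use_flash_attention(
            query, key, value, bias, raise_error=True
        )
    except Exception as e:
        return f"flash attention unavailable: {e}"
    return "flash attention unavailable for an unknown reason"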
