-
Notifications
You must be signed in to change notification settings - Fork 558
[chronos-2] add support for SDPA #331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
7bf335a
1cf40ac
5bab24d
362513a
06a1aad
661679c
8fca666
5b4a90e
23ebe35
70e1b10
63a8a3c
f5540d2
011181a
ff2515c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -155,6 +155,7 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True): | |
| self.n_heads: int = config.num_heads | ||
| self.dropout: float = config.dropout_rate | ||
| self.inner_dim: int = self.n_heads * self.kv_proj_dim | ||
| self.config = config | ||
|
|
||
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) | ||
| self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) | ||
|
|
@@ -165,6 +166,123 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True): | |
| if use_rope: | ||
| self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta) | ||
|
|
||
| def _eager_attention( | ||
| self, | ||
| query_states: torch.Tensor, | ||
| key_states: torch.Tensor, | ||
| value_states: torch.Tensor, | ||
| mask: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Eager attention implementation using manual matmul. | ||
|
|
||
| Args: | ||
| query_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| key_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| value_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| mask: [batch, n_heads, q_len, kv_len] | ||
|
|
||
| Returns: | ||
| attn_output: [batch, n_heads, seq_len, kv_proj_dim] | ||
| attn_weights: [batch, n_heads, q_len, kv_len] | ||
| """ | ||
| # Compute attention weights (no scaling - this is the original Chronos-2 implementation) | ||
| scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" | ||
| scores += mask | ||
| attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) | ||
| attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) | ||
| attn_output = torch.matmul(attn_weights, value_states) | ||
|
|
||
| return attn_output, attn_weights | ||
|
|
||
| def _sdpa_attention( | ||
| self, | ||
| query_states: torch.Tensor, | ||
| key_states: torch.Tensor, | ||
| value_states: torch.Tensor, | ||
| mask: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, None]: | ||
| """SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention. | ||
|
|
||
| Args: | ||
| query_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| key_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| value_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid) | ||
|
|
||
| Returns: | ||
| attn_output: [batch, n_heads, seq_len, kv_proj_dim] | ||
| attn_weights: None (SDPA doesn't return weights) | ||
| """ | ||
| attn_output = nn.functional.scaled_dot_product_attention( | ||
| query_states, | ||
| key_states, | ||
| value_states, | ||
| attn_mask=mask, | ||
| dropout_p=self.dropout if self.training else 0.0, | ||
| scale=1.0, # Match eager implementation (no scaling) | ||
| ) | ||
|
|
||
| return attn_output, None | ||
|
|
||
| def _flash_attention_2( | ||
kashif marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self, | ||
| query_states: torch.Tensor, | ||
| key_states: torch.Tensor, | ||
| value_states: torch.Tensor, | ||
| mask: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, None]: | ||
| """FlashAttention-2 implementation. | ||
|
|
||
| Args: | ||
| query_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| key_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| value_states: [batch, n_heads, seq_len, kv_proj_dim] | ||
| mask: [batch, n_heads, q_len, kv_len] | ||
|
|
||
| Returns: | ||
| attn_output: [batch, n_heads, seq_len, kv_proj_dim] | ||
| attn_weights: None (FlashAttention doesn't return weights) | ||
| """ | ||
| try: | ||
| from flash_attn import flash_attn_func | ||
| except ImportError: | ||
| raise ImportError( | ||
| "FlashAttention-2 is not installed. Please install it with: " | ||
| "pip install flash-attn --no-build-isolation" | ||
| ) | ||
|
|
||
| # FlashAttention expects inputs in shape [batch, seq_len, n_heads, head_dim] | ||
| # We have [batch, n_heads, seq_len, head_dim], so we need to transpose | ||
| query_states = query_states.transpose(1, 2) | ||
| key_states = key_states.transpose(1, 2) | ||
| value_states = value_states.transpose(1, 2) | ||
|
|
||
| # FlashAttention only supports fp16 and bf16 | ||
| input_dtype = query_states.dtype | ||
| if input_dtype not in [torch.float16, torch.bfloat16]: | ||
| target_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 | ||
| query_states = query_states.to(target_dtype) | ||
| key_states = key_states.to(target_dtype) | ||
| value_states = value_states.to(target_dtype) | ||
|
|
||
| attn_output = flash_attn_func( | ||
| query_states, | ||
| key_states, | ||
| value_states, | ||
| dropout_p=self.dropout if self.training else 0.0, | ||
| softmax_scale=1.0, # Match eager implementation (no scaling) | ||
| causal=False, # Chronos uses bidirectional attention by default | ||
kashif marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
|
|
||
| # Convert back to original dtype if needed | ||
| if input_dtype not in [torch.float16, torch.bfloat16]: | ||
| attn_output = attn_output.to(input_dtype) | ||
|
|
||
| # Transpose back to [batch, n_heads, seq_len, head_dim] | ||
| attn_output = attn_output.transpose(1, 2) | ||
|
|
||
| return attn_output, None | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
|
|
@@ -190,6 +308,11 @@ def forward( | |
| if self.use_rope: | ||
| assert position_ids is not None, "position_ids must be provided when self.use_rope=True" | ||
|
|
||
| # Force eager attention if output_attentions is True (only eager returns weights) | ||
| attn_implementation = self.config._attn_implementation | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to access the private
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems for now this is the convention see e.g. https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L277-L278 |
||
| if output_attentions and attn_implementation != "eager": | ||
kashif marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| attn_implementation = "eager" | ||
|
|
||
| seq_length = hidden_states.shape[1] | ||
|
|
||
| def shape(states: torch.Tensor) -> torch.Tensor: | ||
|
|
@@ -215,12 +338,13 @@ def unshape(states: torch.Tensor) -> torch.Tensor: | |
| cos, sin = self.rope_embed(value_states, position_ids) | ||
| query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin) | ||
|
|
||
| # Compute attention weights | ||
| scores = torch.matmul(query_states, key_states.transpose(3, 2)) # "bnqd,bnkd->bnqk" | ||
| scores += mask | ||
| attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) | ||
| attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) | ||
| attn_output = torch.matmul(attn_weights, value_states) | ||
| # Dispatch to appropriate attention implementation | ||
| if attn_implementation == "sdpa": | ||
| attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask) | ||
| elif attn_implementation == "flash_attention_2": | ||
| attn_output, attn_weights = self._flash_attention_2(query_states, key_states, value_states, mask) | ||
| else: # eager or default | ||
| attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask) | ||
|
|
||
| # Project attention output | ||
| attn_output = unshape(attn_output) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -211,7 +211,6 @@ def fit( | |
| lr_scheduler_type="linear", | ||
| warmup_ratio=0.0, | ||
| optim="adamw_torch_fused", | ||
| logging_dir=str(output_dir / "logs"), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this removed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logging_dir has been removed from the training arguments and the different report_to backends handle the logging within the output dir
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this work fine for the older
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes we can safely not set this argument and it will work for older transformer versions... (as well as newer) |
||
| logging_strategy="steps", | ||
| logging_steps=100, | ||
| disable_tqdm=False, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.