
Commit ca9c327

kashif, shchur, and abdulfatir authored
[chronos-2] add support for SDPA (#331)
This pull request introduces configurable attention backends to the Chronos-2 model, allowing users to select between eager, SDPA, and FlashAttention-2 implementations.

Co-authored-by: Oleksandr Shchur <oleks.shchur@gmail.com>
Co-authored-by: Abdul Fatir <Abdulfatirs@gmail.com>
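For orientation, a minimal usage sketch of the new option; the checkpoint id and dtype below are placeholders, and only the attn_implementation argument is introduced by this change:

```python
import torch
from chronos import BaseChronosPipeline

# Hypothetical checkpoint id; any Chronos-2 checkpoint directory would work here.
pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-2",
    device_map="cpu",
    dtype=torch.float32,
    attn_implementation="sdpa",  # or "eager"; defaults to "sdpa" when not specified
)

# Predict 7 steps ahead for 2 univariate series of length 16
forecasts = pipeline.predict(torch.rand(2, 1, 16), prediction_length=7)
```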
1 parent 0c51188 commit ca9c327

6 files changed

Lines changed: 212 additions & 15 deletions


scripts/training/train.py

Lines changed: 0 additions & 1 deletion
@@ -663,7 +663,6 @@ def main(
         lr_scheduler_type=lr_scheduler_type,
         warmup_ratio=warmup_ratio,
         optim=optim,
-        logging_dir=str(output_dir / "logs"),
         logging_strategy="steps",
         logging_steps=log_steps,
         save_strategy="steps",

src/chronos/chronos2/config.py

Lines changed: 11 additions & 2 deletions
@@ -4,7 +4,7 @@
 # Authors: Abdul Fatir Ansari <ansarnd@amazon.com>
 
 from dataclasses import dataclass
-from typing import List
+from typing import List, Literal
 
 from transformers.configuration_utils import PretrainedConfig
 
@@ -39,6 +39,8 @@ class Chronos2CoreConfig(PretrainedConfig):
         Token ID for padding/missing value token, by default 0
     rope_theta
         The base theta for rotary position embedding (RoPE), by default 10000.0
+    attn_implementation
+        The attention implementation to use. Options: "eager" or "sdpa", by default None (uses "sdpa")
     """
 
     model_type = "t5"
@@ -63,6 +65,7 @@ def __init__(
         vocab_size: int = 2,
         pad_token_id: int = 0,
         rope_theta: float = 10000.0,
+        attn_implementation: Literal["eager", "sdpa"] | None = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -83,11 +86,17 @@ def __init__(
 
         assert not self.is_gated_act, "gated activation is not supported"
 
+        # Attention implementation - default to "sdpa" if not specified
+        attn_implementation = attn_implementation or "sdpa"
+        assert attn_implementation in ["eager", "sdpa"], f"attn_implementation {attn_implementation} not supported"
+
         # unused
         kwargs.pop("is_encoder_decoder", None)
         kwargs.pop("eos_token_id", None)
 
-        super().__init__(pad_token_id=pad_token_id, is_encoder_decoder=False, **kwargs)
+        super().__init__(
+            pad_token_id=pad_token_id, is_encoder_decoder=False, attn_implementation=attn_implementation, **kwargs
+        )
 
 
 @dataclass
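A short sketch of how the new config option resolves, assuming the toy hyperparameters used in the tests below; `_attn_implementation` is the attribute the tests check after loading a pipeline, and it is assumed here that the base `PretrainedConfig` stores the forwarded value there:

```python
from chronos.chronos2.config import Chronos2CoreConfig

# attn_implementation defaults to None, which __init__ resolves to "sdpa"
cfg = Chronos2CoreConfig(d_model=128, d_kv=32, num_heads=4)
assert cfg._attn_implementation == "sdpa"

# Explicitly request the eager path
eager_cfg = Chronos2CoreConfig(d_model=128, d_kv=32, num_heads=4, attn_implementation="eager")
assert eager_cfg._attn_implementation == "eager"

# Any other value fails the assert in __init__, e.g.
# Chronos2CoreConfig(attn_implementation="flash_attention_2")  # -> AssertionError
```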

src/chronos/chronos2/layers.py

Lines changed: 68 additions & 6 deletions
@@ -155,6 +155,7 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True):
         self.n_heads: int = config.num_heads
         self.dropout: float = config.dropout_rate
         self.inner_dim: int = self.n_heads * self.kv_proj_dim
+        self.config = config
 
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -165,6 +166,64 @@ def __init__(self, config: Chronos2CoreConfig, use_rope: bool = True):
         if use_rope:
             self.rope_embed = RoPE(dim=self.kv_proj_dim, base=config.rope_theta)
 
+    def _eager_attention(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Eager attention implementation using manual matmul.
+
+        Args:
+            query_states: [batch, n_heads, seq_len, kv_proj_dim]
+            key_states: [batch, n_heads, seq_len, kv_proj_dim]
+            value_states: [batch, n_heads, seq_len, kv_proj_dim]
+            mask: [batch, n_heads, q_len, kv_len]
+
+        Returns:
+            attn_output: [batch, n_heads, seq_len, kv_proj_dim]
+            attn_weights: [batch, n_heads, q_len, kv_len]
+        """
+        # Compute attention weights (no scaling - this is the original Chronos-2 implementation)
+        scores = torch.matmul(query_states, key_states.transpose(3, 2))  # "bnqd,bnkd->bnqk"
+        scores += mask
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        return attn_output, attn_weights
+
+    def _sdpa_attention(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, None]:
+        """SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention.
+
+        Args:
+            query_states: [batch, n_heads, seq_len, kv_proj_dim]
+            key_states: [batch, n_heads, seq_len, kv_proj_dim]
+            value_states: [batch, n_heads, seq_len, kv_proj_dim]
+            mask: [batch, n_heads, q_len, kv_len] - additive mask (0 for valid, -inf for invalid)
+
+        Returns:
+            attn_output: [batch, n_heads, seq_len, kv_proj_dim]
+            attn_weights: None (SDPA doesn't return weights)
+        """
+        attn_output = nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            scale=1.0,  # Match eager implementation (no scaling)
+        )
+
+        return attn_output, None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -190,6 +249,11 @@ def forward(
         if self.use_rope:
             assert position_ids is not None, "position_ids must be provided when self.use_rope=True"
 
+        # Force eager attention if output_attentions is True (only eager returns weights)
+        attn_implementation = self.config._attn_implementation
+        if output_attentions:
+            attn_implementation = "eager"
+
         seq_length = hidden_states.shape[1]
 
         def shape(states: torch.Tensor) -> torch.Tensor:
@@ -215,12 +279,10 @@ def unshape(states: torch.Tensor) -> torch.Tensor:
             cos, sin = self.rope_embed(value_states, position_ids)
             query_states, key_states = RoPE.apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # Compute attention weights
-        scores = torch.matmul(query_states, key_states.transpose(3, 2))  # "bnqd,bnkd->bnqk"
-        scores += mask
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+        if attn_implementation == "sdpa":
+            attn_output, attn_weights = self._sdpa_attention(query_states, key_states, value_states, mask)
+        else:  # eager
+            attn_output, attn_weights = self._eager_attention(query_states, key_states, value_states, mask)
 
         # Project attention output
         attn_output = unshape(attn_output)
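As a standalone sanity check (not part of this diff), the `scale=1.0` choice together with an additive mask makes SDPA reproduce the manual matmul/softmax path; this assumes PyTorch 2.1+, where `scaled_dot_product_attention` accepts a `scale` argument:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 10, 32)  # [batch, n_heads, seq_len, head_dim]
k = torch.randn(2, 4, 10, 32)
v = torch.randn(2, 4, 10, 32)
mask = torch.zeros(2, 4, 10, 10)  # additive mask: 0 = attend, -inf = ignore

# Eager-style attention without the usual 1/sqrt(d) scaling
scores = torch.matmul(q, k.transpose(3, 2)) + mask
eager_out = torch.matmul(F.softmax(scores, dim=-1), v)

# SDPA with scale=1.0 to match the unscaled scores
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)

assert torch.allclose(eager_out, sdpa_out, atol=1e-5)
```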

src/chronos/chronos2/model.py

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ class Chronos2Model(PreTrainedModel):
     config_class = Chronos2CoreConfig  # type: ignore[assignment]
     _supports_long_horizon: bool = True
    _supports_future_covariates: bool = True
+    _supports_sdpa: bool = True
 
     def __init__(self, config: Chronos2CoreConfig):
         assert hasattr(config, "chronos_config"), "Not a valid Chronos config"

src/chronos/chronos2/pipeline.py

Lines changed: 0 additions & 1 deletion
@@ -211,7 +211,6 @@ def fit(
             lr_scheduler_type="linear",
             warmup_ratio=0.0,
             optim="adamw_torch_fused",
-            logging_dir=str(output_dir / "logs"),
             logging_strategy="steps",
             logging_steps=100,
             disable_tqdm=False,

test/test_chronos2.py

Lines changed: 132 additions & 5 deletions
@@ -14,6 +14,9 @@
 
 from chronos import BaseChronosPipeline, Chronos2Pipeline
 from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input
+from chronos.chronos2.config import Chronos2CoreConfig
+from chronos.chronos2.layers import MHA
+
 from test.util import validate_tensor
 
 DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model"
@@ -317,13 +320,11 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs,
     _ = pipeline.predict(inputs, prediction_length=10)
 
 
-@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
-def test_pipeline_predict_can_handle_different_model_and_input_dtypes(
-    torch_dtype: torch.dtype, input_dtype: torch.dtype
-):
+def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype):
     pipeline = BaseChronosPipeline.from_pretrained(
-        Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=torch_dtype
+        Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype
     )
     context = 10 * torch.rand(size=(4, 3, 16)) + 10
     context = context.to(dtype=input_dtype)
@@ -936,3 +937,129 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future
 
     # Check predictions from the fine-tuned model are different from the original predictions
     assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy())
+
+
+@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
+def test_pipeline_works_with_different_attention_implementations(attn_implementation):
+    """Test that the pipeline works with different attention implementations."""
+    # Load the dummy model
+    model_path = Path(__file__).parent / "dummy-chronos2-model"
+
+    # Load with specified attention implementation
+    pipeline = BaseChronosPipeline.from_pretrained(
+        model_path, device_map="cpu", attn_implementation=attn_implementation
+    )
+
+    # Verify the config has the correct attention implementation
+    assert pipeline.model.config._attn_implementation == attn_implementation
+
+    # Test prediction with simple input
+    inputs = torch.rand(2, 1, 16)
+    prediction_length = 7
+
+    outputs = pipeline.predict(inputs, prediction_length=prediction_length)
+
+    # Check outputs are valid
+    assert isinstance(outputs, list) and len(outputs) == 2
+    for out in outputs:
+        validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, 7), dtype=torch.float32)
+
+
+@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
+@pytest.mark.parametrize("output_attentions", [False, True])
+def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions):
+    """Test that attention implementations handle output_attentions correctly."""
+    # Create config with specified attention implementation
+    config = Chronos2CoreConfig(
+        d_model=128,
+        d_kv=32,
+        num_heads=4,
+        dropout_rate=0.1,
+        attn_implementation=attn_implementation,
+    )
+
+    # Create MHA layer
+    mha = MHA(config, use_rope=True)
+    mha.eval()
+
+    # Create dummy inputs
+    batch_size = 2
+    seq_len = 10
+    hidden_states = torch.randn(batch_size, seq_len, config.d_model)
+    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
+    mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len)
+
+    # Test forward pass
+    output = mha(
+        hidden_states=hidden_states,
+        mask=mask,
+        position_ids=position_ids,
+        output_attentions=output_attentions,
+    )
+
+    # Check output shape
+    assert output.hidden_states.shape == (batch_size, seq_len, config.d_model)
+
+    # Check attention weights - should only be returned when output_attentions=True
+    if output_attentions:
+        assert output.attn_weights is not None
+        assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len)
+    else:
+        # SDPA doesn't return weights
+        if attn_implementation == "sdpa":
+            assert output.attn_weights is None
+
+
+def test_eager_and_sdpa_produce_identical_outputs(pipeline):
+    """Test that eager and SDPA implementations produce identical outputs on full pipeline."""
+    # Reload pipeline with SDPA
+    model_path = Path(__file__).parent / "dummy-chronos2-model"
+    pipeline_sdpa = BaseChronosPipeline.from_pretrained(
+        model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32
+    )
+
+    # Note: the original pipeline fixture uses default attn_implementation which should be sdpa
+    # Force eager for comparison
+    pipeline_eager = BaseChronosPipeline.from_pretrained(
+        model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32
+    )
+
+    # Test 1: Simple univariate input
+    inputs_simple = torch.rand(2, 1, 16)
+    prediction_length = 7
+
+    with torch.no_grad():
+        outputs_eager = pipeline_eager.predict(inputs_simple, prediction_length=prediction_length)
+        outputs_sdpa = pipeline_sdpa.predict(inputs_simple, prediction_length=prediction_length)
+
+    # Verify outputs match exactly
+    assert len(outputs_eager) == len(outputs_sdpa)
+    for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa):
+        # Should match exactly or very close (numerical precision)
+        assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
+
+    # Test 2: Multivariate inputs with covariates to test group attention
+    inputs_grouped = [
+        {
+            "target": np.random.randn(2, 36),
+            "past_covariates": {
+                "temperature": np.random.randn(36),
+                "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=36),
+            },
+            "future_covariates": {
+                "temperature": np.random.randn(prediction_length),
+                "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=prediction_length),
+            },
+        }
+        for _ in range(5)
+    ]
+
+    with torch.no_grad():
+        outputs_eager_grouped = pipeline_eager.predict(inputs_grouped, prediction_length=prediction_length)
+        outputs_sdpa_grouped = pipeline_sdpa.predict(inputs_grouped, prediction_length=prediction_length)
+
+    # Verify outputs match for grouped inputs
+    assert len(outputs_eager_grouped) == len(outputs_sdpa_grouped)
+    for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
+        # Should match exactly or very close (numerical precision)
+        assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
