|
14 | 14 |
|
15 | 15 | from chronos import BaseChronosPipeline, Chronos2Pipeline |
16 | 16 | from chronos.chronos2.dataset import convert_df_input_to_list_of_dicts_input |
| 17 | +from chronos.chronos2.config import Chronos2CoreConfig |
| 18 | +from chronos.chronos2.layers import MHA |
| 19 | + |
17 | 20 | from test.util import validate_tensor |
18 | 21 |
|
19 | 22 | DUMMY_MODEL_PATH = Path(__file__).parent / "dummy-chronos2-model" |
@@ -317,13 +320,11 @@ def test_when_input_is_invalid_then_predict_raises_value_error(pipeline, inputs, |
317 | 320 | _ = pipeline.predict(inputs, prediction_length=10) |
318 | 321 |
|
319 | 322 |
|
320 | | -@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) |
| 323 | +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) |
321 | 324 | @pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64]) |
322 | | -def test_pipeline_predict_can_handle_different_model_and_input_dtypes( |
323 | | - torch_dtype: torch.dtype, input_dtype: torch.dtype |
324 | | -): |
| 325 | +def test_pipeline_predict_can_handle_different_model_and_input_dtypes(dtype: torch.dtype, input_dtype: torch.dtype): |
325 | 326 | pipeline = BaseChronosPipeline.from_pretrained( |
326 | | - Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", torch_dtype=torch_dtype |
| 327 | + Path(__file__).parent / "dummy-chronos2-model", device_map="cpu", dtype=dtype |
327 | 328 | ) |
328 | 329 | context = 10 * torch.rand(size=(4, 3, 16)) + 10 |
329 | 330 | context = context.to(dtype=input_dtype) |
@@ -936,3 +937,129 @@ def test_two_step_finetuning_with_df_input_works(pipeline, context_setup, future |
936 | 937 |
|
937 | 938 | # Check predictions from the fine-tuned model are different from the original predictions |
938 | 939 | assert not np.allclose(orig_result_before["predictions"].to_numpy(), result["predictions"].to_numpy()) |
| 940 | + |
| 941 | + |
| 942 | +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) |
| 943 | +def test_pipeline_works_with_different_attention_implementations(attn_implementation): |
| 944 | + """Test that the pipeline works with different attention implementations.""" |
| 945 | + # Load the dummy model |
| 946 | + model_path = Path(__file__).parent / "dummy-chronos2-model" |
| 947 | + |
| 948 | + # Load with specified attention implementation |
| 949 | + pipeline = BaseChronosPipeline.from_pretrained( |
| 950 | + model_path, device_map="cpu", attn_implementation=attn_implementation |
| 951 | + ) |
| 952 | + |
| 953 | + # Verify the config has the correct attention implementation |
| 954 | + assert pipeline.model.config._attn_implementation == attn_implementation |
| 955 | + |
| 956 | + # Test prediction with simple input |
| 957 | + inputs = torch.rand(2, 1, 16) |
| 958 | + prediction_length = 7 |
| 959 | + |
| 960 | + outputs = pipeline.predict(inputs, prediction_length=prediction_length) |
| 961 | + |
| 962 | + # Check outputs are valid |
| 963 | + assert isinstance(outputs, list) and len(outputs) == 2 |
| 964 | + for out in outputs: |
 | 965 | + validate_tensor(out, (1, DEFAULT_MODEL_NUM_QUANTILES, prediction_length), dtype=torch.float32)
| 966 | + |
| 967 | + |
| 968 | +@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"]) |
| 969 | +@pytest.mark.parametrize("output_attentions", [False, True]) |
| 970 | +def test_attention_implementations_with_output_attentions(attn_implementation, output_attentions): |
| 971 | + """Test that attention implementations handle output_attentions correctly.""" |
| 972 | + # Create config with specified attention implementation |
| 973 | + config = Chronos2CoreConfig( |
| 974 | + d_model=128, |
| 975 | + d_kv=32, |
| 976 | + num_heads=4, |
| 977 | + dropout_rate=0.1, |
| 978 | + attn_implementation=attn_implementation, |
| 979 | + ) |
| 980 | + |
| 981 | + # Create MHA layer |
| 982 | + mha = MHA(config, use_rope=True) |
| 983 | + mha.eval() |
| 984 | + |
| 985 | + # Create dummy inputs |
| 986 | + batch_size = 2 |
| 987 | + seq_len = 10 |
| 988 | + hidden_states = torch.randn(batch_size, seq_len, config.d_model) |
| 989 | + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) |
| 990 | + mask = torch.zeros(batch_size, config.num_heads, seq_len, seq_len) |
| 991 | + |
| 992 | + # Test forward pass |
| 993 | + output = mha( |
| 994 | + hidden_states=hidden_states, |
| 995 | + mask=mask, |
| 996 | + position_ids=position_ids, |
| 997 | + output_attentions=output_attentions, |
| 998 | + ) |
| 999 | + |
| 1000 | + # Check output shape |
| 1001 | + assert output.hidden_states.shape == (batch_size, seq_len, config.d_model) |
| 1002 | + |
| 1003 | + # Check attention weights - should only be returned when output_attentions=True |
| 1004 | + if output_attentions: |
| 1005 | + assert output.attn_weights is not None |
| 1006 | + assert output.attn_weights.shape == (batch_size, config.num_heads, seq_len, seq_len) |
| 1007 | + else: |
 | 1008 | + # The SDPA path returns no weights when they are not requested
| 1009 | + if attn_implementation == "sdpa": |
| 1010 | + assert output.attn_weights is None |
| 1011 | + |
| 1012 | + |
 | 1013 | +def test_eager_and_sdpa_produce_identical_outputs():
 | 1014 | + """Test that the eager and SDPA attention implementations produce matching outputs on the full pipeline."""
 | 1015 | + # Load the model with SDPA attention in float32
| 1016 | + model_path = Path(__file__).parent / "dummy-chronos2-model" |
| 1017 | + pipeline_sdpa = BaseChronosPipeline.from_pretrained( |
| 1018 | + model_path, device_map="cpu", attn_implementation="sdpa", dtype=torch.float32 |
| 1019 | + ) |
| 1020 | + |
 | 1021 | + # Load the same model with eager attention, also in float32,
 | 1022 | + # so the two implementations can be compared directly
| 1023 | + pipeline_eager = BaseChronosPipeline.from_pretrained( |
| 1024 | + model_path, device_map="cpu", attn_implementation="eager", dtype=torch.float32 |
| 1025 | + ) |
| 1026 | + |
| 1027 | + # Test 1: Simple univariate input |
| 1028 | + inputs_simple = torch.rand(2, 1, 16) |
| 1029 | + prediction_length = 7 |
| 1030 | + |
| 1031 | + with torch.no_grad(): |
| 1032 | + outputs_eager = pipeline_eager.predict(inputs_simple, prediction_length=prediction_length) |
| 1033 | + outputs_sdpa = pipeline_sdpa.predict(inputs_simple, prediction_length=prediction_length) |
| 1034 | + |
| 1035 | + # Verify outputs match exactly |
| 1036 | + assert len(outputs_eager) == len(outputs_sdpa) |
| 1037 | + for out_eager, out_sdpa in zip(outputs_eager, outputs_sdpa): |
 | 1038 | + # Outputs should agree up to floating-point precision
| 1039 | + assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) |
| 1040 | + |
| 1041 | + # Test 2: Multivariate inputs with covariates to test group attention |
| 1042 | + inputs_grouped = [ |
| 1043 | + { |
| 1044 | + "target": np.random.randn(2, 36), |
| 1045 | + "past_covariates": { |
| 1046 | + "temperature": np.random.randn(36), |
| 1047 | + "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=36), |
| 1048 | + }, |
| 1049 | + "future_covariates": { |
| 1050 | + "temperature": np.random.randn(prediction_length), |
| 1051 | + "weather_type": np.random.choice(["sunny", "cloudy", "rainy"], size=prediction_length), |
| 1052 | + }, |
| 1053 | + } |
| 1054 | + for _ in range(5) |
| 1055 | + ] |
| 1056 | + |
| 1057 | + with torch.no_grad(): |
| 1058 | + outputs_eager_grouped = pipeline_eager.predict(inputs_grouped, prediction_length=prediction_length) |
| 1059 | + outputs_sdpa_grouped = pipeline_sdpa.predict(inputs_grouped, prediction_length=prediction_length) |
| 1060 | + |
| 1061 | + # Verify outputs match for grouped inputs |
| 1062 | + assert len(outputs_eager_grouped) == len(outputs_sdpa_grouped) |
| 1063 | + for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped): |
 | 1064 | + # Outputs should agree up to floating-point precision
| 1065 | + assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4) |