diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 30fe9e0bb..f384367aa 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -2348,7 +2348,7 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]: @staticmethod def _collate_fn( - batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]] + batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]], ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: """ Collate function to combine items into mini-batch for dataloader. diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index f7b14488f..1aa865dff 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -133,7 +133,7 @@ def _concatenate_output( str, List[Union[List[torch.Tensor], torch.Tensor, bool, int, str, np.ndarray]], ] - ] + ], ) -> Dict[ str, Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, int, bool, str]]] ]: diff --git a/pytorch_forecasting/utils/_utils.py b/pytorch_forecasting/utils/_utils.py index af93006cf..eb850b1e7 100644 --- a/pytorch_forecasting/utils/_utils.py +++ b/pytorch_forecasting/utils/_utils.py @@ -233,7 +233,7 @@ def autocorrelation(input, dim=0): def unpack_sequence( - sequence: Union[torch.Tensor, rnn.PackedSequence] + sequence: Union[torch.Tensor, rnn.PackedSequence], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Unpack RNN sequence. @@ -257,7 +257,7 @@ def unpack_sequence( def concat_sequences( - sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]] + sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]], ) -> Union[torch.Tensor, rnn.PackedSequence]: """ Concatenate RNN sequences. 
@@ -272,7 +272,7 @@ def concat_sequences( if isinstance(sequences[0], rnn.PackedSequence): return rnn.pack_sequence(sequences, enforce_sorted=False) elif isinstance(sequences[0], torch.Tensor): - return torch.cat(sequences, dim=1) + return torch.cat(sequences, dim=0) elif isinstance(sequences[0], (tuple, list)): return tuple( concat_sequences([sequences[ii][i] for ii in range(len(sequences))]) diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 24c249bd5..f0eab8671 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -10,9 +10,10 @@ import pytest import torch -from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting import Baseline, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder from pytorch_forecasting.data.encoders import GroupNormalizer, MultiNormalizer +from pytorch_forecasting.data.examples import generate_ar_data from pytorch_forecasting.metrics import ( CrossEntropy, MQF2DistributionLoss, @@ -521,3 +522,48 @@ def test_no_exogenous_variable(): return_y=True, return_index=True, ) + + +def test_correct_prediction_concatenation(): +    data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) +    data["static"] = 2 +    data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") +    data.head() + +    # create dataset and dataloaders +    max_encoder_length = 20 +    max_prediction_length = 5 + +    training_cutoff = data["time_idx"].max() - max_prediction_length + +    context_length = max_encoder_length +    prediction_length = max_prediction_length + +    training = TimeSeriesDataSet( +        data[lambda x: x.time_idx <= training_cutoff], +        time_idx="time_idx", +        target="value", +        categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, +        group_ids=["series"], +        # only unknown variable is "value" +        # and the Baseline model used here does not take any additional 
variables +        time_varying_unknown_reals=["value"], +        max_encoder_length=context_length, +        max_prediction_length=prediction_length, +    ) + +    batch_size = 71 +    train_dataloader = training.to_dataloader( +        train=True, batch_size=batch_size, num_workers=0 +    ) + +    baseline_model = Baseline() +    predictions = baseline_model.predict( +        train_dataloader, +        return_x=True, +        return_y=True, +        trainer_kwargs=dict(logger=None, accelerator="cpu"), +    ) + +    # The predicted output and the target should have the same size. +    assert predictions.output.size() == predictions.y[0].size()