Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2348,7 +2348,7 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:

@staticmethod
def _collate_fn(
batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]]
batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]],
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
"""
Collate function to combine items into mini-batch for dataloader.
Expand Down
2 changes: 1 addition & 1 deletion pytorch_forecasting/models/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _concatenate_output(
str,
List[Union[List[torch.Tensor], torch.Tensor, bool, int, str, np.ndarray]],
]
]
],
) -> Dict[
str, Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, int, bool, str]]]
]:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_forecasting/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def autocorrelation(input, dim=0):


def unpack_sequence(
sequence: Union[torch.Tensor, rnn.PackedSequence]
sequence: Union[torch.Tensor, rnn.PackedSequence],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Unpack RNN sequence.
Expand All @@ -257,7 +257,7 @@ def unpack_sequence(


def concat_sequences(
sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]]
sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]],
) -> Union[torch.Tensor, rnn.PackedSequence]:
"""
Concatenate RNN sequences.
Expand All @@ -272,7 +272,7 @@ def concat_sequences(
if isinstance(sequences[0], rnn.PackedSequence):
return rnn.pack_sequence(sequences, enforce_sorted=False)
elif isinstance(sequences[0], torch.Tensor):
return torch.cat(sequences, dim=1)
return torch.cat(sequences, dim=0)
elif isinstance(sequences[0], (tuple, list)):
return tuple(
concat_sequences([sequences[ii][i] for ii in range(len(sequences))])
Expand Down
48 changes: 47 additions & 1 deletion tests/test_models/test_temporal_fusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import pytest
import torch

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting import Baseline, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.encoders import GroupNormalizer, MultiNormalizer
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.metrics import (
CrossEntropy,
MQF2DistributionLoss,
Expand Down Expand Up @@ -521,3 +522,48 @@ def test_no_exogenous_variable():
return_y=True,
return_index=True,
)


def test_correct_prediction_concatenation():
    """Regression test: batch-wise predictions concatenate along dim 0.

    ``concat_sequences`` must join per-batch tensors along the batch
    dimension (``dim=0``), so the collated prediction output and the
    collated target from ``Baseline.predict`` over several dataloader
    batches end up with identical shapes.
    """
    data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42)
    data["static"] = 2
    data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")

    # create dataset and dataloaders
    max_encoder_length = 20
    max_prediction_length = 5

    training_cutoff = data["time_idx"].max() - max_prediction_length

    training = TimeSeriesDataSet(
        data[lambda x: x.time_idx <= training_cutoff],
        time_idx="time_idx",
        target="value",
        categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
        group_ids=["series"],
        # "value" is the only covariate the Baseline model needs
        time_varying_unknown_reals=["value"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
    )

    # odd batch size so the final batch is smaller than the rest,
    # exercising concatenation of unequally sized batches
    batch_size = 71
    train_dataloader = training.to_dataloader(
        train=True, batch_size=batch_size, num_workers=0
    )

    baseline_model = Baseline()
    predictions = baseline_model.predict(
        train_dataloader,
        return_x=True,
        return_y=True,
        trainer_kwargs=dict(logger=None, accelerator="cpu"),
    )

    # The predicted output and the target should have the same size.
    assert predictions.output.size() == predictions.y[0].size()
Loading