Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions src/chronos/chronos2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,29 @@ def validate_and_prepare_single_dict_task(
f"Found invalid type for `past_covariates` in element at index {idx}. "
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_past_covariates)}'
)

# gather keys and ensure known-future keys come last to match downstream assumptions
task_covariates_keys = sorted(task_past_covariates.keys())

task_future_covariates = task.get("future_covariates", {})
if not isinstance(task_future_covariates, dict):
raise ValueError(
f"Found invalid type for `future_covariates` in element at index {idx}. "
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}'
)
task_future_covariates_keys = sorted(task_future_covariates.keys())
if not set(task_future_covariates_keys).issubset(task_covariates_keys):
raise ValueError(
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, "
f"but found {task_future_covariates_keys} in element at index {idx}"
)

# create ordered keys: past-only first, then known-future (so known-future are the last rows)
task_past_only_keys = [k for k in task_covariates_keys if k not in task_future_covariates_keys] # past_only_keys
task_ordered_covariate_keys = task_past_only_keys + task_future_covariates_keys

task_past_covariates_list: list[torch.Tensor] = []
for key in task_covariates_keys:
for key in task_ordered_covariate_keys:
tensor = task_past_covariates[key]
if isinstance(tensor, np.ndarray):
# apply encoding to categorical variates
Expand Down Expand Up @@ -140,21 +160,10 @@ def validate_and_prepare_single_dict_task(
if task_past_covariates_list
else torch.zeros((0, history_length), device=task_target.device)
)
# validate future_covariates
task_future_covariates = task.get("future_covariates", {})
if not isinstance(task_future_covariates, dict):
raise ValueError(
f"Found invalid type for `future_covariates` in element at index {idx}. "
f'Expected dict with {{"feat_1": tensor_1, "feat_2": tensor_2, ...}}, but found {type(task_future_covariates)}'
)
task_future_covariates_keys = sorted(task_future_covariates.keys())
if not set(task_future_covariates_keys).issubset(task_covariates_keys):
raise ValueError(
f"Expected keys in `future_covariates` to be a subset of `past_covariates` {task_covariates_keys}, "
f"but found {task_future_covariates_keys} in element at index {idx}"
)

# validate future_covariates (build rows in the same task_ordered_covariate_keys order)
task_future_covariates_list: list[torch.Tensor] = []
for key in task_covariates_keys:
for key in task_ordered_covariate_keys:
# future values of past-only covariates are filled with NaNs
tensor = task_future_covariates.get(key, torch.full((prediction_length,), fill_value=torch.nan))
if isinstance(tensor, np.ndarray):
Expand Down Expand Up @@ -186,7 +195,8 @@ def validate_and_prepare_single_dict_task(
).to(dtype=torch.float32)
task_n_targets = task_target.shape[0]
task_n_covariates = task_past_covariates_tensor.shape[0]
task_n_future_covariates = len(task_future_covariates_list)
# number of known-future covariates
task_n_future_covariates = len(task_future_covariates_keys)

return (
task_context_tensor,
Expand Down