Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlforecast/lag_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def take(self, idxs: np.ndarray) -> "Combine":

@staticmethod
def stack(transforms: Sequence["Combine"]) -> "Combine":
out = copy.deepcopy(transforms[0])
out = copy.copy(transforms[0])
out.tfm1 = transforms[0].tfm1.stack([tfm.tfm1 for tfm in transforms])
out.tfm2 = transforms[0].tfm2.stack([tfm.tfm2 for tfm in transforms])
return out
25 changes: 25 additions & 0 deletions tests/test_lag_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,26 @@ def test_nested_combine_take(grouped_array):
assert subset_tfm.operator == operator.sub
assert subset_tfm.tfm1.operator == operator.add

# Numerical correctness: subset update() matches a fresh fit on the same 3 groups
indptr = grouped_array.indptr
parts = [grouped_array.data[indptr[i]:indptr[i + 1]] for i in idxs]
new_indptr = np.zeros(len(idxs) + 1, dtype=indptr.dtype)
for j, part in enumerate(parts):
new_indptr[j + 1] = new_indptr[j] + len(part)
subset_ga = CoreGroupedArray(np.concatenate(parts), new_indptr)

fresh_tfm = Combine(
Combine(
RollingMean(window_size=7, min_samples=1),
RollingMean(window_size=5, min_samples=1),
operator.add
),
RollingMean(window_size=3, min_samples=1),
operator.sub
)._set_core_tfm(1)
fresh_tfm.transform(subset_ga)
np.testing.assert_allclose(subset_tfm.update(subset_ga), fresh_tfm.update(subset_ga))

def test_combine_stack(grouped_array):
tfm1 = Combine(
RollingMean(window_size=7, min_samples=1),
Expand All @@ -112,6 +132,11 @@ def test_combine_stack(grouped_array):
assert isinstance(stacked_tfm, Combine)
assert stacked_tfm.operator == operator.add

# Numerical correctness: stacking a single fitted transform should reproduce its update()
single_stacked = Combine.stack([tfm1])
tfm1.transform(grouped_array) # reset internal state
np.testing.assert_allclose(single_stacked.update(grouped_array), tfm1.update(grouped_array))


def test_combine_stack_behavioral(grouped_array):
"""Verify that Combine.stack() doesn't just return first partition"""
Expand Down