Skip to content

Commit 1feeffa

Browse files
authored
Fix #606: Follow-ups from Combine distributed fix (yherin/fix-distrib... (#624)
1 parent 6e66d2f commit 1feeffa

2 files changed

Lines changed: 26 additions & 1 deletion

File tree

mlforecast/lag_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def take(self, idxs: np.ndarray) -> "Combine":
463463

464464
@staticmethod
465465
def stack(transforms: Sequence["Combine"]) -> "Combine":
466-
out = copy.deepcopy(transforms[0])
466+
out = copy.copy(transforms[0])
467467
out.tfm1 = transforms[0].tfm1.stack([tfm.tfm1 for tfm in transforms])
468468
out.tfm2 = transforms[0].tfm2.stack([tfm.tfm2 for tfm in transforms])
469469
return out

tests/test_lag_transforms.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,26 @@ def test_nested_combine_take(grouped_array):
9292
assert subset_tfm.operator == operator.sub
9393
assert subset_tfm.tfm1.operator == operator.add
9494

95+
# Numerical correctness: subset update() matches a fresh fit on the same 3 groups
96+
indptr = grouped_array.indptr
97+
parts = [grouped_array.data[indptr[i]:indptr[i + 1]] for i in idxs]
98+
new_indptr = np.zeros(len(idxs) + 1, dtype=indptr.dtype)
99+
for j, part in enumerate(parts):
100+
new_indptr[j + 1] = new_indptr[j] + len(part)
101+
subset_ga = CoreGroupedArray(np.concatenate(parts), new_indptr)
102+
103+
fresh_tfm = Combine(
104+
Combine(
105+
RollingMean(window_size=7, min_samples=1),
106+
RollingMean(window_size=5, min_samples=1),
107+
operator.add
108+
),
109+
RollingMean(window_size=3, min_samples=1),
110+
operator.sub
111+
)._set_core_tfm(1)
112+
fresh_tfm.transform(subset_ga)
113+
np.testing.assert_allclose(subset_tfm.update(subset_ga), fresh_tfm.update(subset_ga))
114+
95115
def test_combine_stack(grouped_array):
96116
tfm1 = Combine(
97117
RollingMean(window_size=7, min_samples=1),
@@ -112,6 +132,11 @@ def test_combine_stack(grouped_array):
112132
assert isinstance(stacked_tfm, Combine)
113133
assert stacked_tfm.operator == operator.add
114134

135+
# Numerical correctness: stacking a single fitted transform should reproduce its update()
136+
single_stacked = Combine.stack([tfm1])
137+
tfm1.transform(grouped_array) # reset internal state
138+
np.testing.assert_allclose(single_stacked.update(grouped_array), tfm1.update(grouped_array))
139+
115140

116141
def test_combine_stack_behavioral(grouped_array):
117142
"""Verify that Combine.stack() doesn't just return first partition"""

0 commit comments

Comments
 (0)