Commit b276598

Uses sample_weights in distributed learning and adds testing

1 parent 9bda122 · commit b276598
4 files changed: 128 additions & 20 deletions

mlforecast/distributed/forecast.py
Lines changed: 34 additions & 5 deletions
@@ -142,6 +142,7 @@ def _preprocess_partition(
     keep_last_n: Optional[int] = None,
     window_info: Optional[WindowInfo] = None,
     fit_ts_only: bool = False,
+    weight_col: str | None = None,
 ) -> List[List[Any]]:
     ts = copy.deepcopy(base_ts)
     if fit_ts_only:
@@ -152,6 +153,7 @@ def _preprocess_partition(
         target_col=target_col,
         static_features=static_features,
         keep_last_n=keep_last_n,
+        weight_col=weight_col,
     )
     core_tfms = ts._get_core_lag_tfms()
     if core_tfms:
@@ -195,6 +197,7 @@ def _preprocess_partition(
         static_features=static_features,
         dropna=dropna,
         keep_last_n=keep_last_n,
+        weight_col=weight_col,
     )
     return [
         [
@@ -220,6 +223,7 @@ def _preprocess_partitions(
     keep_last_n: Optional[int] = None,
     window_info: Optional[WindowInfo] = None,
     fit_ts_only: bool = False,
+    weight_col: str | None = None,
 ) -> List[Any]:
     if self.num_partitions:
         partition = dict(by=id_col, num=self.num_partitions, algo="coarse")
@@ -247,6 +251,7 @@ def _preprocess_partitions(
             "keep_last_n": keep_last_n,
             "window_info": window_info,
             "fit_ts_only": fit_ts_only,
+            "weight_col": weight_col,
         },
         schema="ts:binary,train:binary,valid:binary",
         engine=self.engine,
@@ -266,13 +271,15 @@ def _preprocess(
     dropna: bool = True,
     keep_last_n: Optional[int] = None,
     window_info: Optional[WindowInfo] = None,
+    weight_col: str | None = None,
 ) -> fugue.AnyDataFrame:
     self._base_ts.id_col = id_col
     self._base_ts.time_col = time_col
     self._base_ts.target_col = target_col
     self._base_ts.static_features = static_features
     self._base_ts.dropna = dropna
     self._base_ts.keep_last_n = keep_last_n
+    self._base_ts.weight_col = weight_col
     self._partition_results = self._preprocess_partitions(
         data=data,
         id_col=id_col,
@@ -282,6 +289,7 @@ def _preprocess(
         static_features=static_features,
         dropna=dropna,
         keep_last_n=keep_last_n,
+        weight_col=weight_col,
     )
     base_schema = fa.get_schema(data)
     features_schema = {
@@ -341,6 +349,7 @@ def _fit(
     dropna: bool = True,
     keep_last_n: Optional[int] = None,
     window_info: Optional[WindowInfo] = None,
+    weight_col: str | None = None,
 ) -> "DistributedMLForecast":
     prep = self._preprocess(
         data,
@@ -351,28 +360,41 @@ def _fit(
         dropna=dropna,
         keep_last_n=keep_last_n,
         window_info=window_info,
+        weight_col=weight_col,
     )
+    exclude_cols = {id_col, time_col, target_col}
+    if weight_col is not None:
+        exclude_cols.add(weight_col)
     features = [
         x
         for x in fa.get_column_names(prep)
-        if x not in {id_col, time_col, target_col}
+        if x not in exclude_cols
     ]
     self.models_ = {}
     if SPARK_INSTALLED and isinstance(data, SparkDataFrame):
         featurizer = VectorAssembler(
             inputCols=features, outputCol="features", handleInvalid="keep"
         )
-        train_data = featurizer.transform(prep)[target_col, "features"]
+        select_cols = [target_col, "features"]
+        if weight_col is not None:
+            select_cols.append(weight_col)
+        train_data = featurizer.transform(prep).select(*select_cols)
         for name, model in self.models.items():
-            trained_model = model._pre_fit(target_col).fit(train_data)
+            trained_model = model._pre_fit(target_col, weight_col).fit(train_data)
            self.models_[name] = model.extract_local_model(trained_model)
     elif DASK_INSTALLED and isinstance(data, dd.DataFrame):
         X, y = prep[features], prep[target_col]
+        if weights := weight_col:
+            weights = prep[weight_col]
         for name, model in self.models.items():
-            trained_model = clone(model).fit(X, y)
+            trained_model = clone(model).fit(X, y, sample_weight=weights)
             self.models_[name] = trained_model.model_
     elif RAY_INSTALLED and isinstance(data, RayDataset):
         # Need to materialize
+        if weight_col is not None:
+            raise NotImplementedError(
+                "Only spark and dask engines currently support sample weights."
+            )
         prep_selected = prep.select_columns(cols=features + [target_col]).materialize()
         X = RayDMatrix(
             prep_selected,
@@ -396,6 +418,7 @@ def fit(
     static_features: Optional[List[str]] = None,
     dropna: bool = True,
     keep_last_n: Optional[int] = None,
+    weight_col: str | None = None,
 ) -> "DistributedMLForecast":
     """Apply the feature engineering and train the models.
@@ -409,6 +432,7 @@ def fit(
         dropna (bool): Drop rows with missing values produced by the transformations. Defaults to True.
         keep_last_n (int, optional): Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it.
             Defaults to None.
+        weight_col (str, optional): Column that contains the sample weights. Defaults to None.

     Returns:
         (DistributedMLForecast): Forecast object with series values and trained models.
@@ -421,6 +445,7 @@ def fit(
         static_features=static_features,
         dropna=dropna,
         keep_last_n=keep_last_n,
+        weight_col=weight_col,
     )

     @staticmethod
@@ -548,6 +573,7 @@ def cross_validation(
     before_predict_callback: Optional[Callable] = None,
     after_predict_callback: Optional[Callable] = None,
     input_size: Optional[int] = None,
+    weight_col: str | None = None,
 ) -> fugue.AnyDataFrame:
     """Perform time series cross validation.
     Creates `n_windows` splits where each window has `h` test periods,
@@ -577,6 +603,7 @@ def cross_validation(
             The series identifier is on the index. Defaults to None.
         input_size (int, optional): Maximum training samples per serie in each window. If None, will use an expanding window.
             Defaults to None.
+        weight_col (str, optional): Column that contains the sample weights. Defaults to None.

     Returns:
         (dask, spark or ray DataFrame): Predictions for each window with the series id, timestamp, target value and predictions from each model.
@@ -595,6 +622,7 @@ def cross_validation(
         dropna=dropna,
         keep_last_n=keep_last_n,
         window_info=window_info,
+        weight_col=weight_col,
     )
     self.cv_models_.append(self.models_)
     partition_results = self._partition_results
@@ -608,6 +636,7 @@ def cross_validation(
         dropna=dropna,
         keep_last_n=keep_last_n,
         window_info=window_info,
+        weight_col=weight_col,
     )
     schema = self._get_predict_schema() + Schema(
         ("cutoff", "datetime"), (self._base_ts.target_col, "double")
@@ -846,4 +875,4 @@ def combine_core_lag_tfms(by_partition):
     fcst = MLForecast(models=self.models_, freq=ts.freq)
     fcst.ts = ts
     fcst.models_ = self.models_
-    return fcst
\ No newline at end of file
+    return fcst
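Taken together, the forecast.py changes thread weight_col from the public fit and cross_validation APIs down to each engine's training step: the weight column is excluded from the feature matrix, forwarded to Spark models through _pre_fit, passed as sample_weight on dask, and rejected with a NotImplementedError on ray. A minimal usage sketch for the dask engine (assuming a dask scheduler/client is available; the "weight" column name and the constant weights are illustrative):

import dask.dataframe as dd

from mlforecast.distributed import DistributedMLForecast
from mlforecast.distributed.models.dask.lgb import DaskLGBMForecast
from mlforecast.utils import generate_daily_series

# Build a small panel with an illustrative per-row weight column.
series = generate_daily_series(10, min_length=100, max_length=100)
series["weight"] = 1.0  # replace with real sample weights
partitioned = dd.from_pandas(series.set_index("unique_id"), npartitions=2)
partitioned = partitioned.map_partitions(lambda df: df.reset_index())
partitioned["unique_id"] = partitioned["unique_id"].astype(str)  # categoricals aren't supported

fcst = DistributedMLForecast(
    models=[DaskLGBMForecast(verbosity=-1)],
    freq="D",
    lags=[1, 7],
)
# weight_col is dropped from the features and reaches the model as
# sample_weight in the dask branch of _fit above.
fcst.fit(partitioned, static_features=[], dropna=False, weight_col="weight")
preds = fcst.predict(7).compute()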

mlforecast/distributed/models/spark/lgb.py
Lines changed: 4 additions & 2 deletions
@@ -23,10 +23,12 @@


 class SparkLGBMForecast(LightGBMRegressor):
-    def _pre_fit(self, target_col):
+    def _pre_fit(self, target_col, weight_col=None):
+        if weight_col is not None and hasattr(self, "setWeightCol"):
+            return self.setLabelCol(target_col).setWeightCol(weight_col)
         return self.setLabelCol(target_col)

     def extract_local_model(self, trained_model):
         model_str = trained_model.getNativeModel()
         local_model = lgb.Booster(model_str=model_str)
-        return local_model
\ No newline at end of file
+        return local_model
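One consequence of the hasattr guard above: on SynapseML builds whose LightGBMRegressor predates setWeightCol, a requested weight column is silently ignored rather than raising. A self-contained sketch of that fallback, with _StubRegressor as a hypothetical stand-in for the Spark estimator:

class _StubRegressor:
    # Deliberately lacks setWeightCol, like an older SynapseML release.
    def setLabelCol(self, col):
        self.label_col = col
        return self

def _pre_fit(model, target_col, weight_col=None):
    # Same guard as SparkLGBMForecast._pre_fit above.
    if weight_col is not None and hasattr(model, "setWeightCol"):
        return model.setLabelCol(target_col).setWeightCol(weight_col)
    return model.setLabelCol(target_col)

configured = _pre_fit(_StubRegressor(), "y", weight_col="weight")
assert configured.label_col == "y"  # weight column dropped, no error raised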

mlforecast/distributed/models/spark/xgb.py
Lines changed: 4 additions & 2 deletions
@@ -15,12 +15,14 @@


 class SparkXGBForecast(SparkXGBRegressor):
-    def _pre_fit(self, target_col):
+    def _pre_fit(self, target_col, weight_col=None):
         self.setParams(label_col=target_col)
+        if weight_col is not None:
+            self.setParams(weight_col=weight_col)
         return self

     def extract_local_model(self, trained_model):
         model_str = trained_model.get_booster().save_raw("ubj")
         local_model = xgb.XGBRegressor()
         local_model.load_model(model_str)
-        return local_model
\ No newline at end of file
+        return local_model
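For reference, xgboost's PySpark estimator accepts the weight column as a constructor parameter, which is why forwarding it through setParams works. A hedged equivalence sketch (assuming an xgboost release that ships the xgboost.spark module, i.e. 1.7+):

from xgboost.spark import SparkXGBRegressor

# End state equivalent to SparkXGBForecast()._pre_fit("y", weight_col="weight"):
model = SparkXGBRegressor(label_col="y", weight_col="weight")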

tests/test_distributed_forecast.py
Lines changed: 86 additions & 11 deletions
@@ -2,8 +2,10 @@
 import warnings

 import dask.dataframe as dd
+import numpy as np
 import pandas as pd
 import pytest
+from sklearn.base import BaseEstimator

 from mlforecast.distributed import DistributedMLForecast
 from mlforecast.distributed.models.dask.lgb import DaskLGBMForecast
@@ -12,21 +14,64 @@

 warnings.simplefilter("ignore", FutureWarning)

-@pytest.mark.skipif(sys.platform == "win32", reason="Distributed tests are not supported on Windows")
-@pytest.mark.skipif(sys.version_info <= (3, 9), reason="Distributed tests are not supported on Python < 3.10")
-def test_dask_distributed_forecast():
+
+def _reset_index_partition(partition: pd.DataFrame) -> pd.DataFrame:
+    return partition.reset_index()
+
+
+def _make_partitioned_series(df: pd.DataFrame, npartitions: int = 4) -> dd.DataFrame:
+    partitioned = dd.from_pandas(df.set_index("unique_id"), npartitions=npartitions)
+    partitioned = partitioned.map_partitions(_reset_index_partition)
+    partitioned["unique_id"] = partitioned["unique_id"].astype(str)
+    return partitioned
+
+
+@pytest.fixture(scope="module")
+def partitioned_series():
     series = generate_daily_series(
         100, equal_ends=True, min_length=500, max_length=1_000
     )
-    npartitions = 4
-    partitioned_series = dd.from_pandas(
-        series.set_index("unique_id"), npartitions=npartitions
-    )  # make sure we split by the id_col
-    partitioned_series = partitioned_series.map_partitions(lambda df: df.reset_index())
-    partitioned_series["unique_id"] = partitioned_series["unique_id"].astype(
-        str
-    )  # can't handle categoricals atm
+    return _make_partitioned_series(series)
+
+
+@pytest.fixture
+def small_ordered_series():
+    series = generate_daily_series(5, min_length=60, max_length=60)
+    return series.sort_values(["unique_id", "ds"]).reset_index(drop=True)
+
+
+class _RecordingLocalModel:
+    def __init__(self, sample_weight):
+        if sample_weight is None:
+            self.sample_weight_ = None
+            self.weight_mean_ = 0.0
+        else:
+            self.sample_weight_ = np.asarray(sample_weight, dtype=float)
+            self.weight_mean_ = float(self.sample_weight_.mean())
+
+    def predict(self, X):
+        length = X.shape[0] if hasattr(X, "shape") else len(X)
+        return np.full(length, self.weight_mean_, dtype=float)

+
+class _RecordingDaskRegressor(BaseEstimator):
+    def fit(self, X, y, sample_weight=None):  # noqa: ARG002, D401, N803
+        if sample_weight is None:
+            weights = None
+        else:
+            if hasattr(sample_weight, "compute"):
+                sample_weight = sample_weight.compute()
+            weights = (
+                sample_weight.to_numpy()
+                if hasattr(sample_weight, "to_numpy")
+                else np.asarray(sample_weight, dtype=float)
+            )
+        self.model_ = _RecordingLocalModel(weights)
+        return self
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Distributed tests are not supported on Windows")
+@pytest.mark.skipif(sys.version_info <= (3, 9), reason="Distributed tests are not supported on Python < 3.10")
+def test_dask_distributed_forecast(partitioned_series):
     # test existing features provide the same result
     fcst = DistributedMLForecast(
         models=[DaskLGBMForecast(verbosity=-1, random_state=0)],
@@ -49,3 +94,33 @@ def test_dask_distributed_forecast():
     fcst.preprocess(partitioned_series, static_features=[], dropna=False)
     preds2 = fcst.predict(10).compute()
     pd.testing.assert_frame_equal(preds1, preds2)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Distributed tests are not supported on Windows")
+@pytest.mark.skipif(sys.version_info <= (3, 9), reason="Distributed tests are not supported on Python < 3.10")
+def test_dask_distributed_weight_col_affects_predictions(small_ordered_series):
+    def _fit_and_forecast(weights):
+        weighted = small_ordered_series.copy()
+        weighted["weight"] = weights
+        partitioned = _make_partitioned_series(weighted, npartitions=2)
+        fcst = DistributedMLForecast(
+            models={"stub": _RecordingDaskRegressor()},
+            freq="D",
+            lags=[1],
+            date_features=["dayofweek"],
+        )
+        fcst.fit(
+            partitioned,
+            static_features=[],
+            dropna=False,
+            weight_col="weight",
+        )
+        return fcst.predict(5).compute()
+
+    uniform_weights = np.ones(len(small_ordered_series))
+    skewed_weights = np.arange(len(small_ordered_series), dtype=float)
+
+    preds_uniform = _fit_and_forecast(uniform_weights)
+    preds_skewed = _fit_and_forecast(skewed_weights)
+
+    assert not np.allclose(preds_uniform["stub"], preds_skewed["stub"])
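Why the final assertion is meaningful: the stub regressor records the weights it receives and predicts a constant equal to their mean, so the two runs can only produce different forecasts if the weight column actually reached fit's sample_weight. A quick standalone check of that arithmetic, reusing _RecordingLocalModel from the test above:

import numpy as np

# Uniform weights -> mean 1.0; weights 0..5 -> mean 2.5.
uniform = _RecordingLocalModel(np.ones(4))
skewed = _RecordingLocalModel(np.arange(6, dtype=float))
assert uniform.predict(np.zeros((3, 2))).tolist() == [1.0, 1.0, 1.0]
assert skewed.predict(np.zeros((2, 2))).tolist() == [2.5, 2.5]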
