@@ -142,6 +142,7 @@ def _preprocess_partition(
142142 keep_last_n : Optional [int ] = None ,
143143 window_info : Optional [WindowInfo ] = None ,
144144 fit_ts_only : bool = False ,
145+ weight_col : str | None = None ,
145146 ) -> List [List [Any ]]:
146147 ts = copy .deepcopy (base_ts )
147148 if fit_ts_only :
@@ -152,6 +153,7 @@ def _preprocess_partition(
152153 target_col = target_col ,
153154 static_features = static_features ,
154155 keep_last_n = keep_last_n ,
156+ weight_col = weight_col ,
155157 )
156158 core_tfms = ts ._get_core_lag_tfms ()
157159 if core_tfms :
@@ -195,6 +197,7 @@ def _preprocess_partition(
195197 static_features = static_features ,
196198 dropna = dropna ,
197199 keep_last_n = keep_last_n ,
200+ weight_col = weight_col ,
198201 )
199202 return [
200203 [
@@ -220,6 +223,7 @@ def _preprocess_partitions(
220223 keep_last_n : Optional [int ] = None ,
221224 window_info : Optional [WindowInfo ] = None ,
222225 fit_ts_only : bool = False ,
226+ weight_col : str | None = None ,
223227 ) -> List [Any ]:
224228 if self .num_partitions :
225229 partition = dict (by = id_col , num = self .num_partitions , algo = "coarse" )
@@ -247,6 +251,7 @@ def _preprocess_partitions(
247251 "keep_last_n" : keep_last_n ,
248252 "window_info" : window_info ,
249253 "fit_ts_only" : fit_ts_only ,
254+ "weight_col" : weight_col ,
250255 },
251256 schema = "ts:binary,train:binary,valid:binary" ,
252257 engine = self .engine ,
@@ -266,13 +271,15 @@ def _preprocess(
266271 dropna : bool = True ,
267272 keep_last_n : Optional [int ] = None ,
268273 window_info : Optional [WindowInfo ] = None ,
274+ weight_col : str | None = None ,
269275 ) -> fugue .AnyDataFrame :
270276 self ._base_ts .id_col = id_col
271277 self ._base_ts .time_col = time_col
272278 self ._base_ts .target_col = target_col
273279 self ._base_ts .static_features = static_features
274280 self ._base_ts .dropna = dropna
275281 self ._base_ts .keep_last_n = keep_last_n
282+ self ._base_ts .weight_col = weight_col
276283 self ._partition_results = self ._preprocess_partitions (
277284 data = data ,
278285 id_col = id_col ,
@@ -282,6 +289,7 @@ def _preprocess(
282289 dropna = dropna ,
283290 keep_last_n = keep_last_n ,
284291 window_info = window_info ,
292+ weight_col = weight_col ,
285293 )
286294 base_schema = fa .get_schema (data )
287295 features_schema = {
@@ -341,6 +349,7 @@ def _fit(
341349 dropna : bool = True ,
342350 keep_last_n : Optional [int ] = None ,
343351 window_info : Optional [WindowInfo ] = None ,
352+ weight_col : str | None = None ,
344353 ) -> "DistributedMLForecast" :
345354 prep = self ._preprocess (
346355 data ,
@@ -351,28 +360,41 @@ def _fit(
351360 dropna = dropna ,
352361 keep_last_n = keep_last_n ,
353362 window_info = window_info ,
363+ weight_col = weight_col ,
354364 )
365+ exclude_cols = {id_col , time_col , target_col }
366+ if weight_col is not None :
367+ exclude_cols .add (weight_col )
355368 features = [
356369 x
357370 for x in fa .get_column_names (prep )
358- if x not in { id_col , time_col , target_col }
371+ if x not in exclude_cols
359372 ]
360373 self .models_ = {}
361374 if SPARK_INSTALLED and isinstance (data , SparkDataFrame ):
362375 featurizer = VectorAssembler (
363376 inputCols = features , outputCol = "features" , handleInvalid = "keep"
364377 )
365- train_data = featurizer .transform (prep )[target_col , "features" ]
378+ select_cols = [target_col , "features" ]
379+ if weight_col is not None :
380+ select_cols .append (weight_col )
381+ train_data = featurizer .transform (prep ).select (* select_cols )
366382 for name , model in self .models .items ():
367- trained_model = model ._pre_fit (target_col ).fit (train_data )
383+ trained_model = model ._pre_fit (target_col , weight_col ).fit (train_data )
368384 self .models_ [name ] = model .extract_local_model (trained_model )
369385 elif DASK_INSTALLED and isinstance (data , dd .DataFrame ):
370386 X , y = prep [features ], prep [target_col ]
387+ if weights := weight_col :
388+ weights = prep [weight_col ]
371389 for name , model in self .models .items ():
372- trained_model = clone (model ).fit (X , y )
390+ trained_model = clone (model ).fit (X , y , sample_weight = weights )
373391 self .models_ [name ] = trained_model .model_
374392 elif RAY_INSTALLED and isinstance (data , RayDataset ):
375393 # Need to materialize
394+ if weight_col is not None :
395+ raise NotImplementedError (
396+ "Only spark and dask engines currently support sample weights."
397+ )
376398 prep_selected = prep .select_columns (cols = features + [target_col ]).materialize ()
377399 X = RayDMatrix (
378400 prep_selected ,
@@ -396,6 +418,7 @@ def fit(
396418 static_features : Optional [List [str ]] = None ,
397419 dropna : bool = True ,
398420 keep_last_n : Optional [int ] = None ,
421+ weight_col : str | None = None ,
399422 ) -> "DistributedMLForecast" :
400423 """Apply the feature engineering and train the models.
401424
@@ -409,6 +432,7 @@ def fit(
409432 dropna (bool): Drop rows with missing values produced by the transformations. Defaults to True.
410433 keep_last_n (int, optional): Keep only this many records from each series for the forecasting step. Can save time and memory if your features allow it.
411434 Defaults to None.
435+ weight_col (str, optional): Column that contains the sample weights. Defaults to None.
412436
413437 Returns:
414438 (DistributedMLForecast): Forecast object with series values and trained models.
@@ -421,6 +445,7 @@ def fit(
421445 static_features = static_features ,
422446 dropna = dropna ,
423447 keep_last_n = keep_last_n ,
448+ weight_col = weight_col ,
424449 )
425450
426451 @staticmethod
@@ -548,6 +573,7 @@ def cross_validation(
548573 before_predict_callback : Optional [Callable ] = None ,
549574 after_predict_callback : Optional [Callable ] = None ,
550575 input_size : Optional [int ] = None ,
576+ weight_col : str | None = None ,
551577 ) -> fugue .AnyDataFrame :
552578 """Perform time series cross validation.
553579 Creates `n_windows` splits where each window has `h` test periods,
@@ -577,6 +603,7 @@ def cross_validation(
577603 The series identifier is on the index. Defaults to None.
578604 input_size (int, optional): Maximum training samples per series in each window. If None, will use an expanding window.
579605 Defaults to None.
606+ weight_col (str, optional): Column that contains the sample weights. Defaults to None.
580607
581608 Returns:
582609 (dask, spark or ray DataFrame): Predictions for each window with the series id, timestamp, target value and predictions from each model.
@@ -595,6 +622,7 @@ def cross_validation(
595622 dropna = dropna ,
596623 keep_last_n = keep_last_n ,
597624 window_info = window_info ,
625+ weight_col = weight_col ,
598626 )
599627 self .cv_models_ .append (self .models_ )
600628 partition_results = self ._partition_results
@@ -608,6 +636,7 @@ def cross_validation(
608636 dropna = dropna ,
609637 keep_last_n = keep_last_n ,
610638 window_info = window_info ,
639+ weight_col = weight_col ,
611640 )
612641 schema = self ._get_predict_schema () + Schema (
613642 ("cutoff" , "datetime" ), (self ._base_ts .target_col , "double" )
@@ -846,4 +875,4 @@ def combine_core_lag_tfms(by_partition):
846875 fcst = MLForecast (models = self .models_ , freq = ts .freq )
847876 fcst .ts = ts
848877 fcst .models_ = self .models_
849- return fcst
878+ return fcst
0 commit comments