Skip to content

Commit b962e24

Browse files
authored
use fill_gaps in resample_dataframe (#267)
1 parent 80330c7 commit b962e24

File tree

3 files changed

+72
-14
lines changed

3 files changed

+72
-14
lines changed

nbs/nixtla_client.ipynb

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
" retry_if_exception, \n",
5858
" retry_if_not_exception_type,\n",
5959
")\n",
60+
"from utilsforecast.preprocessing import fill_gaps\n",
6061
"from utilsforecast.processing import (\n",
6162
" backtest_splits,\n",
6263
" drop_index_if_pandas,\n",
@@ -405,10 +406,24 @@
405406
" self.freq = inferred_freq\n",
406407
"\n",
407408
" def resample_dataframe(self, df: pd.DataFrame):\n",
408-
" df = df.copy()\n",
409-
" df['ds'] = pd.to_datetime(df['ds'])\n",
410-
" resampled_df = df.set_index('ds').groupby('unique_id').resample(self.freq).bfill()\n",
411-
" resampled_df = resampled_df.drop(columns='unique_id').reset_index()\n",
409+
" if not pd.api.types.is_datetime64_any_dtype(df['ds'].dtype):\n",
410+
" df = df.copy(deep=False)\n",
411+
" df['ds'] = pd.to_datetime(df['ds'])\n",
412+
" resampled_df = fill_gaps(\n",
413+
" df,\n",
414+
" freq=self.freq,\n",
415+
" start='per_serie',\n",
416+
" end='per_serie',\n",
417+
" id_col='unique_id',\n",
418+
" time_col='ds',\n",
419+
" )\n",
420+
" numeric_cols = resampled_df.columns.drop(['unique_id', 'ds'])\n",
421+
" resampled_df[numeric_cols] = (\n",
422+
" resampled_df\n",
423+
" .groupby('unique_id', observed=True)\n",
424+
" [numeric_cols]\n",
425+
" .bfill()\n",
426+
" )\n",
412427
" resampled_df['ds'] = resampled_df['ds'].astype(str)\n",
413428
" return resampled_df\n",
414429
"\n",
@@ -525,7 +540,7 @@
525540
" Y_df = self.resample_dataframe(Y_df)\n",
526541
" x_cols = []\n",
527542
" if X_df is not None:\n",
528-
" x_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()\n",
543+
" x_cols = X_df.columns.drop(['unique_id', 'ds']).to_list()\n",
529544
" if not all(col in df.columns for col in x_cols):\n",
530545
" raise Exception(\n",
531546
" 'You must include the exogenous variables in the `df` object, '\n",
@@ -1961,6 +1976,39 @@
19611976
" )"
19621977
]
19631978
},
1979+
{
1980+
"cell_type": "code",
1981+
"execution_count": null,
1982+
"metadata": {},
1983+
"outputs": [],
1984+
"source": [
1985+
"#| hide\n",
1986+
"# test resample with timestamps at non standard cuts\n",
1987+
"custom_dates = pd.date_range('2000-01-01 00:04:00', freq='5min', periods=100)\n",
1988+
"custom_df = pd.DataFrame(\n",
1989+
" {\n",
1990+
" 'unique_id': np.repeat(np.array([0, 1]), 50),\n",
1991+
" 'ds': custom_dates,\n",
1992+
" 'y': np.arange(100),\n",
1993+
" }\n",
1994+
")\n",
1995+
"# drop second row from each serie\n",
1996+
"custom_df = custom_df.drop([1, 51])\n",
1997+
"model = _NixtlaClientModel(\n",
1998+
" client=nixtla_client,\n",
1999+
" h=1,\n",
2000+
" freq='5min'\n",
2001+
")\n",
2002+
"resampled_df = model.resample_dataframe(custom_df)\n",
2003+
"# we do a backfill so the second row must've got the value of the third row\n",
2004+
"assert resampled_df.loc[1, 'y'] == resampled_df.loc[2, 'y']\n",
2005+
"assert resampled_df.loc[51, 'y'] == resampled_df.loc[52, 'y']\n",
2006+
"pd.testing.assert_series_equal(\n",
2007+
" resampled_df['ds'],\n",
2008+
" custom_dates.to_series(index=resampled_df.index, name='ds').astype(str),\n",
2009+
")"
2010+
]
2011+
},
19642012
{
19652013
"cell_type": "code",
19662014
"execution_count": null,
@@ -2269,7 +2317,7 @@
22692317
" min_length=500 if freq != '15T' else 1_200, \n",
22702318
" max_length=550 if freq != '15T' else 2_000,\n",
22712319
" )\n",
2272-
" df_freq['ds'] = df_freq.groupby('unique_id')['ds'].transform(\n",
2320+
" df_freq['ds'] = df_freq.groupby('unique_id', observed=True)['ds'].transform(\n",
22732321
" lambda x: pd.date_range(periods=len(x), freq=freq, end='2023-01-01')\n",
22742322
" )\n",
22752323
" kwargs = dict(\n",

nixtlats/nixtla_client.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
retry_if_exception,
3030
retry_if_not_exception_type,
3131
)
32+
from utilsforecast.preprocessing import fill_gaps
3233
from utilsforecast.processing import (
3334
backtest_splits,
3435
drop_index_if_pandas,
@@ -335,12 +336,21 @@ def infer_freq(self, df: pd.DataFrame):
335336
self.freq = inferred_freq
336337

337338
def resample_dataframe(self, df: pd.DataFrame):
338-
df = df.copy()
339-
df["ds"] = pd.to_datetime(df["ds"])
340-
resampled_df = (
341-
df.set_index("ds").groupby("unique_id").resample(self.freq).bfill()
339+
if not pd.api.types.is_datetime64_any_dtype(df["ds"].dtype):
340+
df = df.copy(deep=False)
341+
df["ds"] = pd.to_datetime(df["ds"])
342+
resampled_df = fill_gaps(
343+
df,
344+
freq=self.freq,
345+
start="per_serie",
346+
end="per_serie",
347+
id_col="unique_id",
348+
time_col="ds",
342349
)
343-
resampled_df = resampled_df.drop(columns="unique_id").reset_index()
350+
numeric_cols = resampled_df.columns.drop(["unique_id", "ds"])
351+
resampled_df[numeric_cols] = resampled_df.groupby("unique_id", observed=True)[
352+
numeric_cols
353+
].bfill()
344354
resampled_df["ds"] = resampled_df["ds"].astype(str)
345355
return resampled_df
346356

@@ -469,7 +479,7 @@ def preprocess_dataframes(
469479
Y_df = self.resample_dataframe(Y_df)
470480
x_cols = []
471481
if X_df is not None:
472-
x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list()
482+
x_cols = X_df.columns.drop(["unique_id", "ds"]).to_list()
473483
if not all(col in df.columns for col in x_cols):
474484
raise Exception(
475485
"You must include the exogenous variables in the `df` object, "

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"statsforecast",
1414
]
1515
distributed = ["dask[dataframe]", "fugue[ray]>=0.8.7", "pyspark", "ray[serve-grpc]"]
16-
plotting = ["utilsforecast[plotting]>=0.0.5"]
16+
plotting = ["utilsforecast[plotting]>=0.1.7"]
1717
date_extras = ["holidays"]
1818

1919
setuptools.setup(
@@ -36,7 +36,7 @@
3636
"pydantic<2",
3737
"requests",
3838
"tenacity",
39-
"utilsforecast>=0.0.13",
39+
"utilsforecast>=0.1.7",
4040
],
4141
extras_require={
4242
"dev": dev + distributed + plotting + date_extras,

0 commit comments

Comments
 (0)