|
57 | 57 | " retry_if_exception, \n", |
58 | 58 | " retry_if_not_exception_type,\n", |
59 | 59 | ")\n", |
| 60 | + "from utilsforecast.preprocessing import fill_gaps\n", |
60 | 61 | "from utilsforecast.processing import (\n", |
61 | 62 | " backtest_splits,\n", |
62 | 63 | " drop_index_if_pandas,\n", |
|
405 | 406 | " self.freq = inferred_freq\n", |
406 | 407 | "\n", |
407 | 408 | " def resample_dataframe(self, df: pd.DataFrame):\n", |
408 | | - " df = df.copy()\n", |
409 | | - " df['ds'] = pd.to_datetime(df['ds'])\n", |
410 | | - " resampled_df = df.set_index('ds').groupby('unique_id').resample(self.freq).bfill()\n", |
411 | | - " resampled_df = resampled_df.drop(columns='unique_id').reset_index()\n", |
| 409 | + " if not pd.api.types.is_datetime64_any_dtype(df['ds'].dtype):\n", |
| 410 | + " df = df.copy(deep=False)\n", |
| 411 | + " df['ds'] = pd.to_datetime(df['ds'])\n", |
| 412 | + " resampled_df = fill_gaps(\n", |
| 413 | + " df,\n", |
| 414 | + " freq=self.freq,\n", |
| 415 | + " start='per_serie',\n", |
| 416 | + " end='per_serie',\n", |
| 417 | + " id_col='unique_id',\n", |
| 418 | + " time_col='ds',\n", |
| 419 | + " )\n", |
| 420 | + " numeric_cols = resampled_df.columns.drop(['unique_id', 'ds'])\n", |
| 421 | + " resampled_df[numeric_cols] = (\n", |
| 422 | + " resampled_df\n", |
| 423 | + " .groupby('unique_id', observed=True)\n", |
| 424 | + " [numeric_cols]\n", |
| 425 | + " .bfill()\n", |
| 426 | + " )\n", |
412 | 427 | " resampled_df['ds'] = resampled_df['ds'].astype(str)\n", |
413 | 428 | " return resampled_df\n", |
414 | 429 | "\n", |
|
525 | 540 | " Y_df = self.resample_dataframe(Y_df)\n", |
526 | 541 | " x_cols = []\n", |
527 | 542 | " if X_df is not None:\n", |
528 | | - " x_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()\n", |
| 543 | + " x_cols = X_df.columns.drop(['unique_id', 'ds']).to_list()\n", |
529 | 544 | " if not all(col in df.columns for col in x_cols):\n", |
530 | 545 | " raise Exception(\n", |
531 | 546 | " 'You must include the exogenous variables in the `df` object, '\n", |
|
1961 | 1976 | " )" |
1962 | 1977 | ] |
1963 | 1978 | }, |
| 1979 | + { |
| 1980 | + "cell_type": "code", |
| 1981 | + "execution_count": null, |
| 1982 | + "metadata": {}, |
| 1983 | + "outputs": [], |
| 1984 | + "source": [ |
| 1985 | + "#| hide\n", |
| 1986 | + "# test resample with timestamps at non standard cuts\n", |
| 1987 | + "custom_dates = pd.date_range('2000-01-01 00:04:00', freq='5min', periods=100)\n", |
| 1988 | + "custom_df = pd.DataFrame(\n", |
| 1989 | + " {\n", |
| 1990 | + " 'unique_id': np.repeat(np.array([0, 1]), 50),\n", |
| 1991 | + " 'ds': custom_dates,\n", |
| 1992 | + " 'y': np.arange(100),\n", |
| 1993 | + " }\n", |
| 1994 | + ")\n", |
| 1995 | + "# drop second row from each serie\n", |
| 1996 | + "custom_df = custom_df.drop([1, 51])\n", |
| 1997 | + "model = _NixtlaClientModel(\n", |
| 1998 | + " client=nixtla_client,\n", |
| 1999 | + " h=1,\n", |
| 2000 | + " freq='5min'\n", |
| 2001 | + ")\n", |
| 2002 | + "resampled_df = model.resample_dataframe(custom_df)\n", |
| 2003 | + "# we do a backfill so the second row must've got the value of the third row\n", |
| 2004 | + "assert resampled_df.loc[1, 'y'] == resampled_df.loc[2, 'y']\n", |
| 2005 | + "assert resampled_df.loc[51, 'y'] == resampled_df.loc[52, 'y']\n", |
| 2006 | + "pd.testing.assert_series_equal(\n", |
| 2007 | + " resampled_df['ds'],\n", |
| 2008 | + " custom_dates.to_series(index=resampled_df.index, name='ds').astype(str),\n", |
| 2009 | + ")" |
| 2010 | + ] |
| 2011 | + }, |
1964 | 2012 | { |
1965 | 2013 | "cell_type": "code", |
1966 | 2014 | "execution_count": null, |
|
2269 | 2317 | " min_length=500 if freq != '15T' else 1_200, \n", |
2270 | 2318 | " max_length=550 if freq != '15T' else 2_000,\n", |
2271 | 2319 | " )\n", |
2272 | | - " df_freq['ds'] = df_freq.groupby('unique_id')['ds'].transform(\n", |
| 2320 | + " df_freq['ds'] = df_freq.groupby('unique_id', observed=True)['ds'].transform(\n", |
2273 | 2321 | " lambda x: pd.date_range(periods=len(x), freq=freq, end='2023-01-01')\n", |
2274 | 2322 | " )\n", |
2275 | 2323 | " kwargs = dict(\n", |
|
0 commit comments