diff --git a/nbs/nixtla_client.ipynb b/nbs/nixtla_client.ipynb index efbb062a3..714172098 100644 --- a/nbs/nixtla_client.ipynb +++ b/nbs/nixtla_client.ipynb @@ -57,6 +57,7 @@ " retry_if_exception, \n", " retry_if_not_exception_type,\n", ")\n", + "from utilsforecast.preprocessing import fill_gaps\n", "from utilsforecast.processing import (\n", " backtest_splits,\n", " drop_index_if_pandas,\n", @@ -405,10 +406,24 @@ " self.freq = inferred_freq\n", "\n", " def resample_dataframe(self, df: pd.DataFrame):\n", - " df = df.copy()\n", - " df['ds'] = pd.to_datetime(df['ds'])\n", - " resampled_df = df.set_index('ds').groupby('unique_id').resample(self.freq).bfill()\n", - " resampled_df = resampled_df.drop(columns='unique_id').reset_index()\n", + " if not pd.api.types.is_datetime64_any_dtype(df['ds'].dtype):\n", + " df = df.copy(deep=False)\n", + " df['ds'] = pd.to_datetime(df['ds'])\n", + " resampled_df = fill_gaps(\n", + " df,\n", + " freq=self.freq,\n", + " start='per_serie',\n", + " end='per_serie',\n", + " id_col='unique_id',\n", + " time_col='ds',\n", + " )\n", + " numeric_cols = resampled_df.columns.drop(['unique_id', 'ds'])\n", + " resampled_df[numeric_cols] = (\n", + " resampled_df\n", + " .groupby('unique_id', observed=True)\n", + " [numeric_cols]\n", + " .bfill()\n", + " )\n", " resampled_df['ds'] = resampled_df['ds'].astype(str)\n", " return resampled_df\n", "\n", @@ -525,7 +540,7 @@ " Y_df = self.resample_dataframe(Y_df)\n", " x_cols = []\n", " if X_df is not None:\n", - " x_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()\n", + " x_cols = X_df.columns.drop(['unique_id', 'ds']).to_list()\n", " if not all(col in df.columns for col in x_cols):\n", " raise Exception(\n", " 'You must include the exogenous variables in the `df` object, '\n", @@ -1961,6 +1976,39 @@ " )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test resample with timestamps at non standard cuts\n", + "custom_dates = pd.date_range('2000-01-01 00:04:00', freq='5min', periods=100)\n", + "custom_df = pd.DataFrame(\n", + " {\n", + " 'unique_id': np.repeat(np.array([0, 1]), 50),\n", + " 'ds': custom_dates,\n", + " 'y': np.arange(100),\n", + " }\n", + ")\n", + "# drop second row from each serie\n", + "custom_df = custom_df.drop([1, 51])\n", + "model = _NixtlaClientModel(\n", + " client=nixtla_client,\n", + " h=1,\n", + " freq='5min'\n", + ")\n", + "resampled_df = model.resample_dataframe(custom_df)\n", + "# we do a backfill so the second row must've got the value of the third row\n", + "assert resampled_df.loc[1, 'y'] == resampled_df.loc[2, 'y']\n", + "assert resampled_df.loc[51, 'y'] == resampled_df.loc[52, 'y']\n", + "pd.testing.assert_series_equal(\n", + " resampled_df['ds'],\n", + " custom_dates.to_series(index=resampled_df.index, name='ds').astype(str),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -2269,7 +2317,7 @@ " min_length=500 if freq != '15T' else 1_200, \n", " max_length=550 if freq != '15T' else 2_000,\n", " )\n", - " df_freq['ds'] = df_freq.groupby('unique_id')['ds'].transform(\n", + " df_freq['ds'] = df_freq.groupby('unique_id', observed=True)['ds'].transform(\n", " lambda x: pd.date_range(periods=len(x), freq=freq, end='2023-01-01')\n", " )\n", " kwargs = dict(\n", diff --git a/nixtlats/nixtla_client.py b/nixtlats/nixtla_client.py index b404086d3..6b9ffe552 100644 --- a/nixtlats/nixtla_client.py +++ b/nixtlats/nixtla_client.py @@ -29,6 +29,7 @@ retry_if_exception, retry_if_not_exception_type, ) +from utilsforecast.preprocessing import fill_gaps from utilsforecast.processing import ( backtest_splits, drop_index_if_pandas, @@ -335,12 +336,21 @@ def infer_freq(self, df: pd.DataFrame): self.freq = inferred_freq def resample_dataframe(self, df: pd.DataFrame): - df = df.copy() - df["ds"] = pd.to_datetime(df["ds"]) - resampled_df = ( - df.set_index("ds").groupby("unique_id").resample(self.freq).bfill() + if not pd.api.types.is_datetime64_any_dtype(df["ds"].dtype): + df = df.copy(deep=False) + df["ds"] = pd.to_datetime(df["ds"]) + resampled_df = fill_gaps( + df, + freq=self.freq, + start="per_serie", + end="per_serie", + id_col="unique_id", + time_col="ds", ) - resampled_df = resampled_df.drop(columns="unique_id").reset_index() + numeric_cols = resampled_df.columns.drop(["unique_id", "ds"]) + resampled_df[numeric_cols] = resampled_df.groupby("unique_id", observed=True)[ + numeric_cols + ].bfill() resampled_df["ds"] = resampled_df["ds"].astype(str) return resampled_df @@ -469,7 +479,7 @@ def preprocess_dataframes( Y_df = self.resample_dataframe(Y_df) x_cols = [] if X_df is not None: - x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list() + x_cols = X_df.columns.drop(["unique_id", "ds"]).to_list() if not all(col in df.columns for col in x_cols): raise Exception( "You must include the exogenous variables in the `df` object, " diff --git a/setup.py b/setup.py index fdb0b7fe0..cf6ef704b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ "statsforecast", ] distributed = ["dask[dataframe]", "fugue[ray]>=0.8.7", "pyspark", "ray[serve-grpc]"] -plotting = ["utilsforecast[plotting]>=0.0.5"] +plotting = ["utilsforecast[plotting]>=0.1.7"] date_extras = ["holidays"] setuptools.setup( @@ -36,7 +36,7 @@ "pydantic<2", "requests", "tenacity", - "utilsforecast>=0.0.13", + "utilsforecast>=0.1.7", ], extras_require={ "dev": dev + distributed + plotting + date_extras,