Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions nbs/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
" retry_if_exception, \n",
" retry_if_not_exception_type,\n",
")\n",
"from utilsforecast.preprocessing import fill_gaps\n",
"from utilsforecast.processing import (\n",
" backtest_splits,\n",
" drop_index_if_pandas,\n",
Expand Down Expand Up @@ -405,10 +406,24 @@
" self.freq = inferred_freq\n",
"\n",
" def resample_dataframe(self, df: pd.DataFrame):\n",
" df = df.copy()\n",
" df['ds'] = pd.to_datetime(df['ds'])\n",
" resampled_df = df.set_index('ds').groupby('unique_id').resample(self.freq).bfill()\n",
" resampled_df = resampled_df.drop(columns='unique_id').reset_index()\n",
" if not pd.api.types.is_datetime64_any_dtype(df['ds'].dtype):\n",
" df = df.copy(deep=False)\n",
" df['ds'] = pd.to_datetime(df['ds'])\n",
" resampled_df = fill_gaps(\n",
" df,\n",
" freq=self.freq,\n",
" start='per_serie',\n",
" end='per_serie',\n",
" id_col='unique_id',\n",
" time_col='ds',\n",
" )\n",
" numeric_cols = resampled_df.columns.drop(['unique_id', 'ds'])\n",
" resampled_df[numeric_cols] = (\n",
" resampled_df\n",
" .groupby('unique_id', observed=True)\n",
" [numeric_cols]\n",
" .bfill()\n",
" )\n",
" resampled_df['ds'] = resampled_df['ds'].astype(str)\n",
" return resampled_df\n",
"\n",
Expand Down Expand Up @@ -525,7 +540,7 @@
" Y_df = self.resample_dataframe(Y_df)\n",
" x_cols = []\n",
" if X_df is not None:\n",
" x_cols = X_df.drop(columns=['unique_id', 'ds']).columns.to_list()\n",
" x_cols = X_df.columns.drop(['unique_id', 'ds']).to_list()\n",
" if not all(col in df.columns for col in x_cols):\n",
" raise Exception(\n",
" 'You must include the exogenous variables in the `df` object, '\n",
Expand Down Expand Up @@ -1961,6 +1976,39 @@
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test resample with timestamps at non standard cuts\n",
"custom_dates = pd.date_range('2000-01-01 00:04:00', freq='5min', periods=100)\n",
"custom_df = pd.DataFrame(\n",
" {\n",
" 'unique_id': np.repeat(np.array([0, 1]), 50),\n",
" 'ds': custom_dates,\n",
" 'y': np.arange(100),\n",
" }\n",
")\n",
"# drop the second row from each series\n",
"custom_df = custom_df.drop([1, 51])\n",
"model = _NixtlaClientModel(\n",
" client=nixtla_client,\n",
" h=1,\n",
" freq='5min'\n",
")\n",
"resampled_df = model.resample_dataframe(custom_df)\n",
"# we do a backfill, so the second row must have gotten the value of the third row\n",
"assert resampled_df.loc[1, 'y'] == resampled_df.loc[2, 'y']\n",
"assert resampled_df.loc[51, 'y'] == resampled_df.loc[52, 'y']\n",
"pd.testing.assert_series_equal(\n",
" resampled_df['ds'],\n",
" custom_dates.to_series(index=resampled_df.index, name='ds').astype(str),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -2269,7 +2317,7 @@
" min_length=500 if freq != '15T' else 1_200, \n",
" max_length=550 if freq != '15T' else 2_000,\n",
" )\n",
" df_freq['ds'] = df_freq.groupby('unique_id')['ds'].transform(\n",
" df_freq['ds'] = df_freq.groupby('unique_id', observed=True)['ds'].transform(\n",
" lambda x: pd.date_range(periods=len(x), freq=freq, end='2023-01-01')\n",
" )\n",
" kwargs = dict(\n",
Expand Down
22 changes: 16 additions & 6 deletions nixtlats/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
retry_if_exception,
retry_if_not_exception_type,
)
from utilsforecast.preprocessing import fill_gaps
from utilsforecast.processing import (
backtest_splits,
drop_index_if_pandas,
Expand Down Expand Up @@ -335,12 +336,21 @@ def infer_freq(self, df: pd.DataFrame):
self.freq = inferred_freq

def resample_dataframe(self, df: pd.DataFrame):
df = df.copy()
df["ds"] = pd.to_datetime(df["ds"])
resampled_df = (
df.set_index("ds").groupby("unique_id").resample(self.freq).bfill()
if not pd.api.types.is_datetime64_any_dtype(df["ds"].dtype):
df = df.copy(deep=False)
df["ds"] = pd.to_datetime(df["ds"])
resampled_df = fill_gaps(
df,
freq=self.freq,
start="per_serie",
end="per_serie",
id_col="unique_id",
time_col="ds",
)
resampled_df = resampled_df.drop(columns="unique_id").reset_index()
numeric_cols = resampled_df.columns.drop(["unique_id", "ds"])
resampled_df[numeric_cols] = resampled_df.groupby("unique_id", observed=True)[
numeric_cols
].bfill()
resampled_df["ds"] = resampled_df["ds"].astype(str)
return resampled_df

Expand Down Expand Up @@ -469,7 +479,7 @@ def preprocess_dataframes(
Y_df = self.resample_dataframe(Y_df)
x_cols = []
if X_df is not None:
x_cols = X_df.drop(columns=["unique_id", "ds"]).columns.to_list()
x_cols = X_df.columns.drop(["unique_id", "ds"]).to_list()
if not all(col in df.columns for col in x_cols):
raise Exception(
"You must include the exogenous variables in the `df` object, "
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"statsforecast",
]
distributed = ["dask[dataframe]", "fugue[ray]>=0.8.7", "pyspark", "ray[serve-grpc]"]
plotting = ["utilsforecast[plotting]>=0.0.5"]
plotting = ["utilsforecast[plotting]>=0.1.7"]
date_extras = ["holidays"]

setuptools.setup(
Expand All @@ -36,7 +36,7 @@
"pydantic<2",
"requests",
"tenacity",
"utilsforecast>=0.0.13",
"utilsforecast>=0.1.7",
],
extras_require={
"dev": dev + distributed + plotting + date_extras,
Expand Down