Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions nbs/distributed.timegpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,14 @@
" X_df: Optional[fugue.AnyDataFrame] = None,\n",
" level: Optional[List[Union[int, float]]] = None,\n",
" quantiles: Optional[List[float]] = None,\n",
" fewshot_steps: int = 0,\n",
" fewshot_loss: str = 'default',\n",
" finetune_steps: int = 0,\n",
" finetune_loss: str = 'default',\n",
" clean_ex_first: bool = True,\n",
" validate_token: bool = False,\n",
" add_history: bool = False,\n",
" date_features: Union[bool, List[str]] = False,\n",
" date_features_to_one_hot: Union[bool, List[str]] = True,\n",
" model: str = 'short-horizon',\n",
" model: str = 'timegpt-1',\n",
" num_partitions: Optional[int] = None,\n",
" ) -> fugue.AnyDataFrame:\n",
" kwargs = dict(\n",
Expand All @@ -179,8 +179,8 @@
" target_col=target_col,\n",
" level=level,\n",
" quantiles=quantiles,\n",
" fewshot_steps=fewshot_steps,\n",
" fewshot_loss=fewshot_loss,\n",
" finetune_steps=finetune_steps,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_token=validate_token,\n",
" add_history=add_history,\n",
Expand Down Expand Up @@ -217,7 +217,7 @@
" validate_token: bool = False,\n",
" date_features: Union[bool, List[str]] = False,\n",
" date_features_to_one_hot: Union[bool, List[str]] = True,\n",
" model: str = 'short-horizon',\n",
" model: str = 'timegpt-1',\n",
" num_partitions: Optional[int] = None,\n",
" ) -> fugue.AnyDataFrame:\n",
" kwargs = dict(\n",
Expand Down Expand Up @@ -254,13 +254,13 @@
" target_col: str = 'y',\n",
" level: Optional[List[Union[int, float]]] = None,\n",
" quantiles: Optional[List[float]] = None,\n",
" fewshot_steps: int = 0,\n",
" fewshot_loss: str = 'default',\n",
" finetune_steps: int = 0,\n",
" finetune_loss: str = 'default',\n",
" clean_ex_first: bool = True,\n",
" validate_token: bool = False,\n",
" date_features: Union[bool, List[str]] = False,\n",
" date_features_to_one_hot: Union[bool, List[str]] = True,\n",
" model: str = 'short-horizon',\n",
" model: str = 'timegpt-1',\n",
" n_windows: int = 1,\n",
" step_size: Optional[int] = None,\n",
" num_partitions: Optional[int] = None,\n",
Expand All @@ -273,8 +273,8 @@
" target_col=target_col,\n",
" level=level,\n",
" quantiles=quantiles,\n",
" fewshot_steps=fewshot_steps,\n",
" fewshot_loss=fewshot_loss,\n",
" finetune_steps=finetune_steps,\n",
" finetune_loss=finetune_loss,\n",
" clean_ex_first=clean_ex_first,\n",
" validate_token=validate_token,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -448,7 +448,7 @@
" num_partitions=1,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" model='short-horizon',\n",
" model='timegpt-1',\n",
" **fcst_kwargs\n",
" )\n",
" fcst_df = fa.as_pandas(fcst_df)\n",
Expand Down Expand Up @@ -771,7 +771,7 @@
" num_partitions=1,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" model='short-horizon',\n",
" model='timegpt-1',\n",
" **anomalies_kwargs\n",
" )\n",
" anomalies_df = fa.as_pandas(anomalies_df)\n",
Expand Down
55 changes: 23 additions & 32 deletions nbs/docs/getting-started/1_getting_started_short.ipynb

Large diffs are not rendered by default.

69 changes: 33 additions & 36 deletions nbs/docs/how-to-guides/0_distributed_fcst_spark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,7 @@
"execution_count": null,
"id": "fcf6004b-ebd0-4a3c-8c02-d5463c62f79e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/miniconda/envs/nixtlats/lib/python3.11/site-packages/statsforecast/core.py:25: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from tqdm.autonotebook import tqdm\n"
]
}
],
"outputs": [],
"source": [
"from nixtlats import TimeGPT"
]
Expand Down Expand Up @@ -176,8 +167,7 @@
"text": [
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"23/11/09 17:49:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
"23/11/09 17:49:02 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
"24/04/01 03:34:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
Expand Down Expand Up @@ -242,10 +232,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:nixtlats.timegpt:Validating inputs... (4 + 16) / 20]\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n"
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n"
]
},
{
Expand Down Expand Up @@ -302,42 +292,42 @@
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint... (36 + 60) / 96]\n",
"INFO:nixtlats.timegpt:Validating inputs... (54 + 42) / 96]\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...========> (71 + 25) / 96]\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (92 + 4) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs... \n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...===> (76 + 20) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (92 + 4) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...==============> (93 + 3) / 96]\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
" \r"
]
Expand All @@ -347,7 +337,7 @@
"#| hide\n",
"# test different results for different models\n",
"fcst_df_1 = fcst_df.toPandas()\n",
"fcst_df_2 = timegpt.forecast(spark_df, h=12, model='long-horizon')\n",
"fcst_df_2 = timegpt.forecast(spark_df, h=12, model='timegpt-1-long-horizon')\n",
"fcst_df_2 = fcst_df_2.toPandas()\n",
"test_fail(\n",
" lambda: pd.testing.assert_frame_equal(fcst_df_1[['TimeGPT']], fcst_df_2[['TimeGPT']]),\n",
Expand Down Expand Up @@ -464,24 +454,31 @@
"output_type": "stream",
"text": [
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Validating inputs...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n"
"INFO:nixtlats.timegpt:Inferred freq: H\n",
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
"[Stage 33:=====================================================> (19 + 1) / 20]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
"|unique_id| ds| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
"| FR|2016-12-31 00:00:00| 64.97691027939692|60.056473801735784|61.71575274765864|68.23806781113521| 69.89734675705805|\n",
"| FR|2016-12-31 01:00:00| 60.14365519077404| 56.12626745731457|56.73784790927991|63.54946247226818| 64.16104292423351|\n",
"| FR|2016-12-31 02:00:00| 59.42375860682185| 54.84932824030574|56.52975776758845|62.31775944605525| 63.99818897333796|\n",
"| FR|2016-12-31 03:00:00| 55.11264928302748| 47.59671153125746|51.95117842731459|58.27412013874037| 62.6285870347975|\n",
"| FR|2016-12-31 04:00:00|54.400922806813526|44.925772896840385|49.65213255412798|59.14971305949907|63.876072716786666|\n",
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
"|unique_id| ds| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
"| FR|2016-12-31 00:00:00| 59.39155162090687| 54.47111514324573| 56.13039408916859|62.65270915264515| 64.311988098568|\n",
"| FR|2016-12-31 01:00:00| 60.1843929541434|56.167005220683926|56.778585672649264|63.59020023563754|64.20178068760288|\n",
"| FR|2016-12-31 02:00:00| 58.12912691907976| 53.55469655256365| 55.23512607984636|61.02312775831316|62.70355728559587|\n",
"| FR|2016-12-31 03:00:00|53.825965179940155| 46.31002742817014| 50.66449432422726|56.98743603565305|61.34190293171017|\n",
"| FR|2016-12-31 04:00:00| 47.6941769331486| 38.21902702317546| 42.94538668046305|52.44296718583414|57.16932684312174|\n",
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
"only showing top 5 rows\n",
"\n"
]
Expand Down
Loading