Skip to content

Commit b85230b

Browse files
authored
fix: recover finetune and deprecate fewshot (#272)
1 parent 6dfba26 commit b85230b

19 files changed

+435
-495
lines changed

nbs/distributed.timegpt.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@
161161
" X_df: Optional[fugue.AnyDataFrame] = None,\n",
162162
" level: Optional[List[Union[int, float]]] = None,\n",
163163
" quantiles: Optional[List[float]] = None,\n",
164-
" fewshot_steps: int = 0,\n",
165-
" fewshot_loss: str = 'default',\n",
164+
" finetune_steps: int = 0,\n",
165+
" finetune_loss: str = 'default',\n",
166166
" clean_ex_first: bool = True,\n",
167167
" validate_token: bool = False,\n",
168168
" add_history: bool = False,\n",
@@ -179,8 +179,8 @@
179179
" target_col=target_col,\n",
180180
" level=level,\n",
181181
" quantiles=quantiles,\n",
182-
" fewshot_steps=fewshot_steps,\n",
183-
" fewshot_loss=fewshot_loss,\n",
182+
" finetune_steps=finetune_steps,\n",
183+
" finetune_loss=finetune_loss,\n",
184184
" clean_ex_first=clean_ex_first,\n",
185185
" validate_token=validate_token,\n",
186186
" add_history=add_history,\n",
@@ -254,8 +254,8 @@
254254
" target_col: str = 'y',\n",
255255
" level: Optional[List[Union[int, float]]] = None,\n",
256256
" quantiles: Optional[List[float]] = None,\n",
257-
" fewshot_steps: int = 0,\n",
258-
" fewshot_loss: str = 'default',\n",
257+
" finetune_steps: int = 0,\n",
258+
" finetune_loss: str = 'default',\n",
259259
" clean_ex_first: bool = True,\n",
260260
" validate_token: bool = False,\n",
261261
" date_features: Union[bool, List[str]] = False,\n",
@@ -273,8 +273,8 @@
273273
" target_col=target_col,\n",
274274
" level=level,\n",
275275
" quantiles=quantiles,\n",
276-
" fewshot_steps=fewshot_steps,\n",
277-
" fewshot_loss=fewshot_loss,\n",
276+
" finetune_steps=finetune_steps,\n",
277+
" finetune_loss=finetune_loss,\n",
278278
" clean_ex_first=clean_ex_first,\n",
279279
" validate_token=validate_token,\n",
280280
" date_features=date_features,\n",

nbs/docs/getting-started/1_getting_started_short.ipynb

Lines changed: 23 additions & 32 deletions
Large diffs are not rendered by default.

nbs/docs/how-to-guides/0_distributed_fcst_spark.ipynb

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,7 @@
118118
"execution_count": null,
119119
"id": "fcf6004b-ebd0-4a3c-8c02-d5463c62f79e",
120120
"metadata": {},
121-
"outputs": [
122-
{
123-
"name": "stderr",
124-
"output_type": "stream",
125-
"text": [
126-
"/home/ubuntu/miniconda/envs/nixtlats/lib/python3.11/site-packages/statsforecast/core.py:25: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
127-
" from tqdm.autonotebook import tqdm\n"
128-
]
129-
}
130-
],
121+
"outputs": [],
131122
"source": [
132123
"from nixtlats import TimeGPT"
133124
]
@@ -176,8 +167,7 @@
176167
"text": [
177168
"Setting default log level to \"WARN\".\n",
178169
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
179-
"23/11/09 17:49:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
180-
"23/11/09 17:49:02 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
170+
"24/04/01 03:34:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
181171
]
182172
}
183173
],
@@ -242,10 +232,10 @@
242232
"name": "stderr",
243233
"output_type": "stream",
244234
"text": [
245-
"INFO:nixtlats.timegpt:Validating inputs... (4 + 16) / 20]\n",
235+
"INFO:nixtlats.timegpt:Validating inputs...\n",
246236
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
247237
"INFO:nixtlats.timegpt:Inferred freq: H\n",
248-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n"
238+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n"
249239
]
250240
},
251241
{
@@ -302,42 +292,42 @@
302292
"INFO:nixtlats.timegpt:Validating inputs...\n",
303293
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
304294
"INFO:nixtlats.timegpt:Inferred freq: H\n",
305-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint... (36 + 60) / 96]\n",
306-
"INFO:nixtlats.timegpt:Validating inputs... (54 + 42) / 96]\n",
295+
"INFO:nixtlats.timegpt:Validating inputs...\n",
307296
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
308-
"INFO:nixtlats.timegpt:Inferred freq: H\n",
309297
"INFO:nixtlats.timegpt:Validating inputs...\n",
310298
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
311299
"INFO:nixtlats.timegpt:Inferred freq: H\n",
300+
"INFO:nixtlats.timegpt:Inferred freq: H\n",
301+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
312302
"INFO:nixtlats.timegpt:Validating inputs...\n",
313303
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
314304
"INFO:nixtlats.timegpt:Inferred freq: H\n",
315-
"INFO:nixtlats.timegpt:Validating inputs...========> (71 + 25) / 96]\n",
305+
"INFO:nixtlats.timegpt:Validating inputs...\n",
316306
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
317307
"INFO:nixtlats.timegpt:Inferred freq: H\n",
318-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (92 + 4) / 96]\n",
319308
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
320309
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
321310
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
322-
"INFO:nixtlats.timegpt:Validating inputs... \n",
311+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
312+
"INFO:nixtlats.timegpt:Validating inputs...\n",
323313
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
324314
"INFO:nixtlats.timegpt:Inferred freq: H\n",
325315
"INFO:nixtlats.timegpt:Validating inputs...\n",
326316
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
327317
"INFO:nixtlats.timegpt:Inferred freq: H\n",
328-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...===> (76 + 20) / 96]\n",
329-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
330318
"INFO:nixtlats.timegpt:Validating inputs...\n",
331319
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
332-
"INFO:nixtlats.timegpt:Inferred freq: H\n",
333320
"INFO:nixtlats.timegpt:Validating inputs...\n",
334321
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
335322
"INFO:nixtlats.timegpt:Inferred freq: H\n",
323+
"INFO:nixtlats.timegpt:Inferred freq: H\n",
336324
"INFO:nixtlats.timegpt:Validating inputs...\n",
337325
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
338326
"INFO:nixtlats.timegpt:Inferred freq: H\n",
339-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (92 + 4) / 96]\n",
340-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...==============> (93 + 3) / 96]\n",
327+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
328+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
329+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
330+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
341331
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
342332
" \r"
343333
]
@@ -464,24 +454,31 @@
464454
"output_type": "stream",
465455
"text": [
466456
"INFO:nixtlats.timegpt:Validating inputs...\n",
457+
"INFO:nixtlats.timegpt:Validating inputs...\n",
458+
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
467459
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
468460
"INFO:nixtlats.timegpt:Inferred freq: H\n",
469-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n"
461+
"INFO:nixtlats.timegpt:Inferred freq: H\n",
462+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
463+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
464+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
465+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
466+
"[Stage 33:=====================================================> (19 + 1) / 20]\r"
470467
]
471468
},
472469
{
473470
"name": "stdout",
474471
"output_type": "stream",
475472
"text": [
476-
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
477-
"|unique_id| ds| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
478-
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
479-
"| FR|2016-12-31 00:00:00| 64.97691027939692|60.056473801735784|61.71575274765864|68.23806781113521| 69.89734675705805|\n",
480-
"| FR|2016-12-31 01:00:00| 60.14365519077404| 56.12626745731457|56.73784790927991|63.54946247226818| 64.16104292423351|\n",
481-
"| FR|2016-12-31 02:00:00| 59.42375860682185| 54.84932824030574|56.52975776758845|62.31775944605525| 63.99818897333796|\n",
482-
"| FR|2016-12-31 03:00:00| 55.11264928302748| 47.59671153125746|51.95117842731459|58.27412013874037| 62.6285870347975|\n",
483-
"| FR|2016-12-31 04:00:00|54.400922806813526|44.925772896840385|49.65213255412798|59.14971305949907|63.876072716786666|\n",
484-
"+---------+-------------------+------------------+------------------+-----------------+-----------------+------------------+\n",
473+
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
474+
"|unique_id| ds| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
475+
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
476+
"| FR|2016-12-31 00:00:00| 59.39155162090687| 54.47111514324573| 56.13039408916859|62.65270915264515| 64.311988098568|\n",
477+
"| FR|2016-12-31 01:00:00| 60.1843929541434|56.167005220683926|56.778585672649264|63.59020023563754|64.20178068760288|\n",
478+
"| FR|2016-12-31 02:00:00| 58.12912691907976| 53.55469655256365| 55.23512607984636|61.02312775831316|62.70355728559587|\n",
479+
"| FR|2016-12-31 03:00:00|53.825965179940155| 46.31002742817014| 50.66449432422726|56.98743603565305|61.34190293171017|\n",
480+
"| FR|2016-12-31 04:00:00| 47.6941769331486| 38.21902702317546| 42.94538668046305|52.44296718583414|57.16932684312174|\n",
481+
"+---------+-------------------+------------------+------------------+------------------+-----------------+-----------------+\n",
485482
"only showing top 5 rows\n",
486483
"\n"
487484
]

nbs/docs/how-to-guides/1_distributed_cv_spark.ipynb

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,7 @@
118118
"execution_count": null,
119119
"id": "21bbe459-ed98-4ac1-8da7-2287305b3680",
120120
"metadata": {},
121-
"outputs": [
122-
{
123-
"name": "stderr",
124-
"output_type": "stream",
125-
"text": [
126-
"/home/ubuntu/miniconda/envs/nixtlats/lib/python3.11/site-packages/statsforecast/core.py:25: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
127-
" from tqdm.autonotebook import tqdm\n"
128-
]
129-
}
130-
],
121+
"outputs": [],
131122
"source": [
132123
"from nixtlats import TimeGPT"
133124
]
@@ -176,10 +167,7 @@
176167
"text": [
177168
"Setting default log level to \"WARN\".\n",
178169
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
179-
"23/11/09 17:49:20 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
180-
"23/11/09 17:49:21 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n",
181-
"23/11/09 17:49:21 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n",
182-
"23/11/09 17:49:21 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.\n"
170+
"24/04/01 03:35:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
183171
]
184172
}
185173
],
@@ -244,12 +232,12 @@
244232
"name": "stderr",
245233
"output_type": "stream",
246234
"text": [
247-
"INFO:nixtlats.timegpt:Validating inputs... (5 + 15) / 20]\n",
235+
"INFO:nixtlats.timegpt:Validating inputs...\n",
248236
"INFO:nixtlats.timegpt:Inferred freq: H\n",
249237
"INFO:nixtlats.timegpt:Validating inputs...\n",
250238
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
251239
"INFO:nixtlats.timegpt:Inferred freq: H\n",
252-
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...=============> (19 + 1) / 20]\n",
240+
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
253241
"INFO:nixtlats.timegpt:Validating inputs...\n",
254242
"INFO:nixtlats.timegpt:Validating inputs...\n",
255243
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
@@ -379,51 +367,51 @@
379367
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
380368
"INFO:nixtlats.timegpt:Inferred freq: H\n",
381369
"WARNING:nixtlats.timegpt:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
382-
"INFO:nixtlats.timegpt:Restricting input...\n",
370+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
383371
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
384-
"INFO:nixtlats.timegpt:Validating inputs...=====================> (19 + 1) / 20]\n",
372+
"INFO:nixtlats.timegpt:Validating inputs...\n",
385373
"INFO:nixtlats.timegpt:Validating inputs...\n",
386374
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
387375
"INFO:nixtlats.timegpt:Inferred freq: H\n",
388376
"WARNING:nixtlats.timegpt:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
389-
"INFO:nixtlats.timegpt:Restricting input...\n",
377+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
390378
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
391379
"INFO:nixtlats.timegpt:Validating inputs...\n",
392380
"INFO:nixtlats.timegpt:Validating inputs...\n",
393381
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
394382
"INFO:nixtlats.timegpt:Inferred freq: H\n",
395383
"WARNING:nixtlats.timegpt:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
396-
"INFO:nixtlats.timegpt:Restricting input...\n",
384+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
397385
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
398386
"INFO:nixtlats.timegpt:Validating inputs...\n",
399387
"INFO:nixtlats.timegpt:Validating inputs...\n",
400388
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
401389
"INFO:nixtlats.timegpt:Inferred freq: H\n",
402390
"WARNING:nixtlats.timegpt:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
403-
"INFO:nixtlats.timegpt:Restricting input...\n",
391+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
404392
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n",
405393
"INFO:nixtlats.timegpt:Validating inputs...\n",
406394
"INFO:nixtlats.timegpt:Validating inputs...\n",
407395
"INFO:nixtlats.timegpt:Preprocessing dataframes...\n",
408396
"INFO:nixtlats.timegpt:Inferred freq: H\n",
409397
"WARNING:nixtlats.timegpt:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
410-
"INFO:nixtlats.timegpt:Restricting input...\n",
398+
"INFO:nixtlats.timegpt:Using the following exogenous variables: Exogenous1, Exogenous2, day_0, day_1, day_2, day_3, day_4, day_5, day_6\n",
411399
"INFO:nixtlats.timegpt:Calling Forecast Endpoint...\n"
412400
]
413401
},
414402
{
415403
"name": "stdout",
416404
"output_type": "stream",
417405
"text": [
418-
"+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+\n",
419-
"|unique_id| ds| cutoff| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
420-
"+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+\n",
421-
"| FR|2016-12-21 00:00:00|2016-12-20 23:00:00| 57.46266174316406| 54.32243190002441|54.725050598144534| 60.20027288818359|60.602891586303706|\n",
422-
"| FR|2016-12-21 01:00:00|2016-12-20 23:00:00|52.549095153808594|50.111817771911625| 50.20576373291016| 54.89242657470703| 54.98637253570556|\n",
423-
"| FR|2016-12-21 02:00:00|2016-12-20 23:00:00| 49.98523712158203|47.396572181701664| 48.40804647827149|51.562427764892576| 52.5739020614624|\n",
424-
"| FR|2016-12-21 03:00:00|2016-12-20 23:00:00| 49.146240234375| 46.38533438110352| 46.51724838256836| 51.77523208618164| 51.90714608764648|\n",
425-
"| FR|2016-12-21 04:00:00|2016-12-20 23:00:00| 47.01085662841797| 42.29354175567627|42.783941421508786|51.237771835327145|51.728171501159665|\n",
426-
"+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+\n",
406+
"+---------+-------------------+-------------------+------------------+------------------+-----------------+------------------+------------------+\n",
407+
"|unique_id| ds| cutoff| TimeGPT| TimeGPT-lo-90| TimeGPT-lo-80| TimeGPT-hi-80| TimeGPT-hi-90|\n",
408+
"+---------+-------------------+-------------------+------------------+------------------+-----------------+------------------+------------------+\n",
409+
"| FR|2016-12-21 00:00:00|2016-12-20 23:00:00| 66.39748296460945| 62.03776876172859|63.28946471509773| 69.50550121412117| 70.7571971674903|\n",
410+
"| FR|2016-12-21 01:00:00|2016-12-20 23:00:00| 63.71841894125738|59.770956050632385|61.16832944845953| 66.26850843405524| 67.66588183188237|\n",
411+
"| FR|2016-12-21 02:00:00|2016-12-20 23:00:00| 61.13784444132001| 58.88184931650312| 59.5156742600456|62.760014622594426|63.393839566136904|\n",
412+
"| FR|2016-12-21 03:00:00|2016-12-20 23:00:00| 55.77490648975175|53.047358607671676|53.22071413745683|58.329098842046676|58.502454371831824|\n",
413+
"| FR|2016-12-21 04:00:00|2016-12-20 23:00:00|48.803786601770284| 44.10176355336941|44.58027990316188| 53.02729330037869| 53.50580965017116|\n",
414+
"+---------+-------------------+-------------------+------------------+------------------+-----------------+------------------+------------------+\n",
427415
"only showing top 5 rows\n",
428416
"\n"
429417
]

0 commit comments

Comments
 (0)