Skip to content

Commit 00d1f3a

Browse files
AzulGarzaMMenchero
authored andcommitted
feat: replace TimeGPT class by NixtlaClient class (#276)
1 parent e2cd9aa commit 00d1f3a

File tree

15 files changed

+592
-401
lines changed

15 files changed

+592
-401
lines changed

.github/workflows/ci.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ jobs:
5656
run: pip install ./
5757

5858
- name: Check import
59-
run: python -c "from nixtlats import TimeGPT;"
59+
run: |
60+
python -c "from nixtlats import TimeGPT;"
61+
python -c "from nixtlats import NixtlaClient;"
6062
6163
run-tests:
6264
runs-on: ${{ matrix.os }}

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ Get started with TimeGPT now:
4444
```python
4545
df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short.csv')
4646

47-
from nixtlats import TimeGPT
48-
timegpt = TimeGPT(
47+
from nixtlats import NixtlaClient
48+
nixtla_client = NixtlaClient(
4949
# defaults to os.environ.get("NIXTLA_API_KEY")
5050
api_key = 'my_api_key_provided_by_nixtla'
5151
)
52-
fcst_df = timegpt.forecast(df, h=24, level=[80, 90])
52+
fcst_df = nixtla_client.forecast(df, h=24, level=[80, 90])
5353
```
5454

5555
![](./nbs/img/forecast_readme.png)

action_files/models_performance/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from utilsforecast.evaluation import evaluate
1212
from utilsforecast.losses import mae, mape, mse
1313

14-
from nixtlats import TimeGPT
14+
from nixtlats import NixtlaClient
1515

1616

1717
logger = logging.getLogger(__name__)
@@ -141,7 +141,7 @@ def evaluate_timegpt(self, model: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
141141
init_time = time()
142142
# A: this sould be replaced with
143143
# cross validation
144-
timegpt = TimeGPT()
144+
timegpt = NixtlaClient()
145145
fcst_df = timegpt.forecast(
146146
df=self.df_train,
147147
X_df=self.df_test.drop(columns=self.target_col)
@@ -200,7 +200,7 @@ def evaluate_benchmark_performace(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
200200

201201
def plot_and_save_forecasts(self, cv_df: pd.DataFrame, plot_dir: str) -> str:
202202
"""Plot ans saves forecasts, returns the path of the plot"""
203-
timegpt = TimeGPT()
203+
timegpt = NixtlaClient()
204204
df = self.df.copy()
205205
df[self.time_col] = pd.to_datetime(df[self.time_col])
206206
if not self.has_id_col:

nbs/distributed.timegpt.ipynb renamed to nbs/distributed.nixtla_client.ipynb

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"metadata": {},
77
"outputs": [],
88
"source": [
9-
"#| default_exp distributed.timegpt"
9+
"#| default_exp distributed.nixtla_client"
1010
]
1111
},
1212
{
@@ -83,7 +83,7 @@
8383
"outputs": [],
8484
"source": [
8585
"#| export\n",
86-
"class _DistributedTimeGPT:\n",
86+
"class _DistributedNixtlaClient:\n",
8787
"\n",
8888
" def __init__(\n",
8989
" self, \n",
@@ -300,49 +300,49 @@
300300
" )\n",
301301
" return fcst_df\n",
302302
" \n",
303-
" def _instantiate_timegpt(self):\n",
304-
" from nixtlats.timegpt import _TimeGPT\n",
305-
" timegpt = _TimeGPT(\n",
303+
" def _instantiate_nixtla_client(self):\n",
304+
" from nixtlats.nixtla_client import _NixtlaClient\n",
305+
" nixtla_client = _NixtlaClient(\n",
306306
" api_key=self.api_key, \n",
307307
" base_url=self.base_url,\n",
308308
" max_retries=self.max_retries,\n",
309309
" retry_interval=self.retry_interval,\n",
310310
" max_wait_time=self.max_wait_time,\n",
311311
" )\n",
312-
" return timegpt\n",
312+
" return nixtla_client\n",
313313
"\n",
314314
" def _forecast(\n",
315315
" self, \n",
316316
" df: pd.DataFrame, \n",
317317
" kwargs,\n",
318318
" ) -> pd.DataFrame:\n",
319-
" timegpt = self._instantiate_timegpt()\n",
320-
" return timegpt._forecast(df=df, **kwargs)\n",
319+
" nixtla_client = self._instantiate_nixtla_client()\n",
320+
" return nixtla_client._forecast(df=df, **kwargs)\n",
321321
"\n",
322322
" def _forecast_x(\n",
323323
" self, \n",
324324
" df: pd.DataFrame, \n",
325325
" X_df: pd.DataFrame,\n",
326326
" kwargs,\n",
327327
" ) -> pd.DataFrame:\n",
328-
" timegpt = self._instantiate_timegpt()\n",
329-
" return timegpt._forecast(df=df, X_df=X_df, **kwargs)\n",
328+
" nixtla_client = self._instantiate_nixtla_client()\n",
329+
" return nixtla_client._forecast(df=df, X_df=X_df, **kwargs)\n",
330330
"\n",
331331
" def _detect_anomalies(\n",
332332
" self, \n",
333333
" df: pd.DataFrame, \n",
334334
" kwargs,\n",
335335
" ) -> pd.DataFrame:\n",
336-
" timegpt = self._instantiate_timegpt()\n",
337-
" return timegpt._detect_anomalies(df=df, **kwargs)\n",
336+
" nixtla_client = self._instantiate_nixtla_client()\n",
337+
" return nixtla_client._detect_anomalies(df=df, **kwargs)\n",
338338
"\n",
339339
" def _cross_validation(\n",
340340
" self, \n",
341341
" df: pd.DataFrame, \n",
342342
" kwargs,\n",
343343
" ) -> pd.DataFrame:\n",
344-
" timegpt = self._instantiate_timegpt()\n",
345-
" return timegpt._cross_validation(df=df, **kwargs)\n",
344+
" nixtla_client = self._instantiate_nixtla_client()\n",
345+
" return nixtla_client._cross_validation(df=df, **kwargs)\n",
346346
" \n",
347347
" @staticmethod\n",
348348
" def _get_forecast_schema(id_col, time_col, level, quantiles, cv=False):\n",
@@ -400,7 +400,7 @@
400400
" time_col: str = 'ds',\n",
401401
" **fcst_kwargs,\n",
402402
" ):\n",
403-
" fcst_df = distributed_timegpt.forecast(\n",
403+
" fcst_df = distributed_nixtla_client.forecast(\n",
404404
" df=df, \n",
405405
" h=horizon,\n",
406406
" id_col=id_col,\n",
@@ -442,7 +442,7 @@
442442
" time_col: str = 'ds',\n",
443443
" **fcst_kwargs,\n",
444444
" ):\n",
445-
" fcst_df = distributed_timegpt.forecast(\n",
445+
" fcst_df = distributed_nixtla_client.forecast(\n",
446446
" df=df, \n",
447447
" h=horizon, \n",
448448
" num_partitions=1,\n",
@@ -452,7 +452,7 @@
452452
" **fcst_kwargs\n",
453453
" )\n",
454454
" fcst_df = fa.as_pandas(fcst_df)\n",
455-
" fcst_df_2 = distributed_timegpt.forecast(\n",
455+
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
456456
" df=df, \n",
457457
" h=horizon, \n",
458458
" num_partitions=1,\n",
@@ -485,7 +485,7 @@
485485
" time_col: str = 'ds',\n",
486486
" **fcst_kwargs,\n",
487487
" ):\n",
488-
" fcst_df = distributed_timegpt.forecast(\n",
488+
" fcst_df = distributed_nixtla_client.forecast(\n",
489489
" df=df, \n",
490490
" h=horizon, \n",
491491
" num_partitions=1,\n",
@@ -494,7 +494,7 @@
494494
" **fcst_kwargs\n",
495495
" )\n",
496496
" fcst_df = fa.as_pandas(fcst_df)\n",
497-
" fcst_df_2 = distributed_timegpt.forecast(\n",
497+
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
498498
" df=df, \n",
499499
" h=horizon, \n",
500500
" num_partitions=2,\n",
@@ -523,7 +523,7 @@
523523
" time_col: str = 'ds',\n",
524524
" **fcst_kwargs,\n",
525525
" ):\n",
526-
" fcst_df = distributed_timegpt.cross_validation(\n",
526+
" fcst_df = distributed_nixtla_client.cross_validation(\n",
527527
" df=df, \n",
528528
" h=horizon, \n",
529529
" num_partitions=1,\n",
@@ -532,7 +532,7 @@
532532
" **fcst_kwargs\n",
533533
" )\n",
534534
" fcst_df = fa.as_pandas(fcst_df)\n",
535-
" fcst_df_2 = distributed_timegpt.cross_validation(\n",
535+
" fcst_df_2 = distributed_nixtla_client.cross_validation(\n",
536536
" df=df, \n",
537537
" h=horizon, \n",
538538
" num_partitions=2,\n",
@@ -592,7 +592,7 @@
592592
" time_col: str = 'ds',\n",
593593
" **fcst_kwargs,\n",
594594
" ):\n",
595-
" fcst_df = distributed_timegpt.forecast(\n",
595+
" fcst_df = distributed_nixtla_client.forecast(\n",
596596
" df=df, \n",
597597
" X_df=X_df,\n",
598598
" h=horizon,\n",
@@ -610,7 +610,7 @@
610610
" exp_cols.extend([f'TimeGPT-lo-{lv}' for lv in reversed(level)])\n",
611611
" exp_cols.extend([f'TimeGPT-hi-{lv}' for lv in level])\n",
612612
" test_eq(cols, exp_cols)\n",
613-
" fcst_df_2 = distributed_timegpt.forecast(\n",
613+
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
614614
" df=df, \n",
615615
" h=horizon,\n",
616616
" id_col=id_col,\n",
@@ -640,7 +640,7 @@
640640
" time_col: str = 'ds',\n",
641641
" **fcst_kwargs,\n",
642642
" ):\n",
643-
" fcst_df = distributed_timegpt.forecast(\n",
643+
" fcst_df = distributed_nixtla_client.forecast(\n",
644644
" df=df, \n",
645645
" X_df=X_df,\n",
646646
" h=horizon, \n",
@@ -650,7 +650,7 @@
650650
" **fcst_kwargs\n",
651651
" )\n",
652652
" fcst_df = fa.as_pandas(fcst_df)\n",
653-
" fcst_df_2 = distributed_timegpt.forecast(\n",
653+
" fcst_df_2 = distributed_nixtla_client.forecast(\n",
654654
" df=df, \n",
655655
" h=horizon, \n",
656656
" num_partitions=2,\n",
@@ -705,7 +705,7 @@
705705
" time_col: str = 'ds',\n",
706706
" **anomalies_kwargs,\n",
707707
" ):\n",
708-
" anomalies_df = distributed_timegpt.detect_anomalies(\n",
708+
" anomalies_df = distributed_nixtla_client.detect_anomalies(\n",
709709
" df=df, \n",
710710
" id_col=id_col,\n",
711711
" time_col=time_col,\n",
@@ -731,15 +731,15 @@
731731
" time_col: str = 'ds',\n",
732732
" **anomalies_kwargs,\n",
733733
" ):\n",
734-
" anomalies_df = distributed_timegpt.detect_anomalies(\n",
734+
" anomalies_df = distributed_nixtla_client.detect_anomalies(\n",
735735
" df=df, \n",
736736
" num_partitions=1,\n",
737737
" id_col=id_col,\n",
738738
" time_col=time_col,\n",
739739
" **anomalies_kwargs\n",
740740
" )\n",
741741
" anomalies_df = fa.as_pandas(anomalies_df)\n",
742-
" anomalies_df_2 = distributed_timegpt.detect_anomalies(\n",
742+
" anomalies_df_2 = distributed_nixtla_client.detect_anomalies(\n",
743743
" df=df, \n",
744744
" num_partitions=2,\n",
745745
" id_col=id_col,\n",
@@ -766,7 +766,7 @@
766766
" time_col: str = 'ds',\n",
767767
" **anomalies_kwargs,\n",
768768
" ):\n",
769-
" anomalies_df = distributed_timegpt.detect_anomalies(\n",
769+
" anomalies_df = distributed_nixtla_client.detect_anomalies(\n",
770770
" df=df, \n",
771771
" num_partitions=1,\n",
772772
" id_col=id_col,\n",
@@ -775,7 +775,7 @@
775775
" **anomalies_kwargs\n",
776776
" )\n",
777777
" anomalies_df = fa.as_pandas(anomalies_df)\n",
778-
" anomalies_df_2 = distributed_timegpt.detect_anomalies(\n",
778+
" anomalies_df_2 = distributed_nixtla_client.detect_anomalies(\n",
779779
" df=df, \n",
780780
" num_partitions=1,\n",
781781
" id_col=id_col,\n",
@@ -844,9 +844,9 @@
844844
" assert all(col in df_qls.columns for col in exp_q_cols)\n",
845845
" # test monotonicity of quantiles\n",
846846
" df_qls.apply(lambda x: x.is_monotonic_increasing, axis=1).sum() == len(exp_q_cols)\n",
847-
" test_method_qls(distributed_timegpt.forecast)\n",
848-
" test_method_qls(distributed_timegpt.forecast, add_history=True)\n",
849-
" test_method_qls(distributed_timegpt.cross_validation)"
847+
" test_method_qls(distributed_nixtla_client.forecast)\n",
848+
" test_method_qls(distributed_nixtla_client.forecast, add_history=True)\n",
849+
" test_method_qls(distributed_nixtla_client.cross_validation)"
850850
]
851851
},
852852
{
@@ -856,7 +856,7 @@
856856
"outputs": [],
857857
"source": [
858858
"#| hide\n",
859-
"distributed_timegpt = _DistributedTimeGPT()"
859+
"distributed_nixtla_client = _DistributedNixtlaClient()"
860860
]
861861
},
862862
{

0 commit comments

Comments
 (0)