|
6 | 6 | "metadata": {}, |
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
9 | | - "#| default_exp distributed.timegpt" |
| 9 | + "#| default_exp distributed.nixtla_client" |
10 | 10 | ] |
11 | 11 | }, |
12 | 12 | { |
|
83 | 83 | "outputs": [], |
84 | 84 | "source": [ |
85 | 85 | "#| export\n", |
86 | | - "class _DistributedTimeGPT:\n", |
| 86 | + "class _DistributedNixtlaClient:\n", |
87 | 87 | "\n", |
88 | 88 | " def __init__(\n", |
89 | 89 | " self, \n", |
|
300 | 300 | " )\n", |
301 | 301 | " return fcst_df\n", |
302 | 302 | " \n", |
303 | | - " def _instantiate_timegpt(self):\n", |
304 | | - " from nixtlats.timegpt import _TimeGPT\n", |
305 | | - " timegpt = _TimeGPT(\n", |
| 303 | + " def _instantiate_nixtla_client(self):\n", |
| 304 | + " from nixtlats.nixtla_client import _NixtlaClient\n", |
| 305 | + " nixtla_client = _NixtlaClient(\n", |
306 | 306 | " api_key=self.api_key, \n", |
307 | 307 | " base_url=self.base_url,\n", |
308 | 308 | " max_retries=self.max_retries,\n", |
309 | 309 | " retry_interval=self.retry_interval,\n", |
310 | 310 | " max_wait_time=self.max_wait_time,\n", |
311 | 311 | " )\n", |
312 | | - " return timegpt\n", |
| 312 | + " return nixtla_client\n", |
313 | 313 | "\n", |
314 | 314 | " def _forecast(\n", |
315 | 315 | " self, \n", |
316 | 316 | " df: pd.DataFrame, \n", |
317 | 317 | " kwargs,\n", |
318 | 318 | " ) -> pd.DataFrame:\n", |
319 | | - " timegpt = self._instantiate_timegpt()\n", |
320 | | - " return timegpt._forecast(df=df, **kwargs)\n", |
| 319 | + " nixtla_client = self._instantiate_nixtla_client()\n", |
| 320 | + " return nixtla_client._forecast(df=df, **kwargs)\n", |
321 | 321 | "\n", |
322 | 322 | " def _forecast_x(\n", |
323 | 323 | " self, \n", |
324 | 324 | " df: pd.DataFrame, \n", |
325 | 325 | " X_df: pd.DataFrame,\n", |
326 | 326 | " kwargs,\n", |
327 | 327 | " ) -> pd.DataFrame:\n", |
328 | | - " timegpt = self._instantiate_timegpt()\n", |
329 | | - " return timegpt._forecast(df=df, X_df=X_df, **kwargs)\n", |
| 328 | + " nixtla_client = self._instantiate_nixtla_client()\n", |
| 329 | + " return nixtla_client._forecast(df=df, X_df=X_df, **kwargs)\n", |
330 | 330 | "\n", |
331 | 331 | " def _detect_anomalies(\n", |
332 | 332 | " self, \n", |
333 | 333 | " df: pd.DataFrame, \n", |
334 | 334 | " kwargs,\n", |
335 | 335 | " ) -> pd.DataFrame:\n", |
336 | | - " timegpt = self._instantiate_timegpt()\n", |
337 | | - " return timegpt._detect_anomalies(df=df, **kwargs)\n", |
| 336 | + " nixtla_client = self._instantiate_nixtla_client()\n", |
| 337 | + " return nixtla_client._detect_anomalies(df=df, **kwargs)\n", |
338 | 338 | "\n", |
339 | 339 | " def _cross_validation(\n", |
340 | 340 | " self, \n", |
341 | 341 | " df: pd.DataFrame, \n", |
342 | 342 | " kwargs,\n", |
343 | 343 | " ) -> pd.DataFrame:\n", |
344 | | - " timegpt = self._instantiate_timegpt()\n", |
345 | | - " return timegpt._cross_validation(df=df, **kwargs)\n", |
| 344 | + " nixtla_client = self._instantiate_nixtla_client()\n", |
| 345 | + " return nixtla_client._cross_validation(df=df, **kwargs)\n", |
346 | 346 | " \n", |
347 | 347 | " @staticmethod\n", |
348 | 348 | " def _get_forecast_schema(id_col, time_col, level, quantiles, cv=False):\n", |
|
400 | 400 | " time_col: str = 'ds',\n", |
401 | 401 | " **fcst_kwargs,\n", |
402 | 402 | " ):\n", |
403 | | - " fcst_df = distributed_timegpt.forecast(\n", |
| 403 | + " fcst_df = distributed_nixtla_client.forecast(\n", |
404 | 404 | " df=df, \n", |
405 | 405 | " h=horizon,\n", |
406 | 406 | " id_col=id_col,\n", |
|
442 | 442 | " time_col: str = 'ds',\n", |
443 | 443 | " **fcst_kwargs,\n", |
444 | 444 | " ):\n", |
445 | | - " fcst_df = distributed_timegpt.forecast(\n", |
| 445 | + " fcst_df = distributed_nixtla_client.forecast(\n", |
446 | 446 | " df=df, \n", |
447 | 447 | " h=horizon, \n", |
448 | 448 | " num_partitions=1,\n", |
|
452 | 452 | " **fcst_kwargs\n", |
453 | 453 | " )\n", |
454 | 454 | " fcst_df = fa.as_pandas(fcst_df)\n", |
455 | | - " fcst_df_2 = distributed_timegpt.forecast(\n", |
| 455 | + " fcst_df_2 = distributed_nixtla_client.forecast(\n", |
456 | 456 | " df=df, \n", |
457 | 457 | " h=horizon, \n", |
458 | 458 | " num_partitions=1,\n", |
|
485 | 485 | " time_col: str = 'ds',\n", |
486 | 486 | " **fcst_kwargs,\n", |
487 | 487 | " ):\n", |
488 | | - " fcst_df = distributed_timegpt.forecast(\n", |
| 488 | + " fcst_df = distributed_nixtla_client.forecast(\n", |
489 | 489 | " df=df, \n", |
490 | 490 | " h=horizon, \n", |
491 | 491 | " num_partitions=1,\n", |
|
494 | 494 | " **fcst_kwargs\n", |
495 | 495 | " )\n", |
496 | 496 | " fcst_df = fa.as_pandas(fcst_df)\n", |
497 | | - " fcst_df_2 = distributed_timegpt.forecast(\n", |
| 497 | + " fcst_df_2 = distributed_nixtla_client.forecast(\n", |
498 | 498 | " df=df, \n", |
499 | 499 | " h=horizon, \n", |
500 | 500 | " num_partitions=2,\n", |
|
523 | 523 | " time_col: str = 'ds',\n", |
524 | 524 | " **fcst_kwargs,\n", |
525 | 525 | " ):\n", |
526 | | - " fcst_df = distributed_timegpt.cross_validation(\n", |
| 526 | + " fcst_df = distributed_nixtla_client.cross_validation(\n", |
527 | 527 | " df=df, \n", |
528 | 528 | " h=horizon, \n", |
529 | 529 | " num_partitions=1,\n", |
|
532 | 532 | " **fcst_kwargs\n", |
533 | 533 | " )\n", |
534 | 534 | " fcst_df = fa.as_pandas(fcst_df)\n", |
535 | | - " fcst_df_2 = distributed_timegpt.cross_validation(\n", |
| 535 | + " fcst_df_2 = distributed_nixtla_client.cross_validation(\n", |
536 | 536 | " df=df, \n", |
537 | 537 | " h=horizon, \n", |
538 | 538 | " num_partitions=2,\n", |
|
592 | 592 | " time_col: str = 'ds',\n", |
593 | 593 | " **fcst_kwargs,\n", |
594 | 594 | " ):\n", |
595 | | - " fcst_df = distributed_timegpt.forecast(\n", |
| 595 | + " fcst_df = distributed_nixtla_client.forecast(\n", |
596 | 596 | " df=df, \n", |
597 | 597 | " X_df=X_df,\n", |
598 | 598 | " h=horizon,\n", |
|
610 | 610 | " exp_cols.extend([f'TimeGPT-lo-{lv}' for lv in reversed(level)])\n", |
611 | 611 | " exp_cols.extend([f'TimeGPT-hi-{lv}' for lv in level])\n", |
612 | 612 | " test_eq(cols, exp_cols)\n", |
613 | | - " fcst_df_2 = distributed_timegpt.forecast(\n", |
| 613 | + " fcst_df_2 = distributed_nixtla_client.forecast(\n", |
614 | 614 | " df=df, \n", |
615 | 615 | " h=horizon,\n", |
616 | 616 | " id_col=id_col,\n", |
|
640 | 640 | " time_col: str = 'ds',\n", |
641 | 641 | " **fcst_kwargs,\n", |
642 | 642 | " ):\n", |
643 | | - " fcst_df = distributed_timegpt.forecast(\n", |
| 643 | + " fcst_df = distributed_nixtla_client.forecast(\n", |
644 | 644 | " df=df, \n", |
645 | 645 | " X_df=X_df,\n", |
646 | 646 | " h=horizon, \n", |
|
650 | 650 | " **fcst_kwargs\n", |
651 | 651 | " )\n", |
652 | 652 | " fcst_df = fa.as_pandas(fcst_df)\n", |
653 | | - " fcst_df_2 = distributed_timegpt.forecast(\n", |
| 653 | + " fcst_df_2 = distributed_nixtla_client.forecast(\n", |
654 | 654 | " df=df, \n", |
655 | 655 | " h=horizon, \n", |
656 | 656 | " num_partitions=2,\n", |
|
705 | 705 | " time_col: str = 'ds',\n", |
706 | 706 | " **anomalies_kwargs,\n", |
707 | 707 | " ):\n", |
708 | | - " anomalies_df = distributed_timegpt.detect_anomalies(\n", |
| 708 | + " anomalies_df = distributed_nixtla_client.detect_anomalies(\n", |
709 | 709 | " df=df, \n", |
710 | 710 | " id_col=id_col,\n", |
711 | 711 | " time_col=time_col,\n", |
|
731 | 731 | " time_col: str = 'ds',\n", |
732 | 732 | " **anomalies_kwargs,\n", |
733 | 733 | " ):\n", |
734 | | - " anomalies_df = distributed_timegpt.detect_anomalies(\n", |
| 734 | + " anomalies_df = distributed_nixtla_client.detect_anomalies(\n", |
735 | 735 | " df=df, \n", |
736 | 736 | " num_partitions=1,\n", |
737 | 737 | " id_col=id_col,\n", |
738 | 738 | " time_col=time_col,\n", |
739 | 739 | " **anomalies_kwargs\n", |
740 | 740 | " )\n", |
741 | 741 | " anomalies_df = fa.as_pandas(anomalies_df)\n", |
742 | | - " anomalies_df_2 = distributed_timegpt.detect_anomalies(\n", |
| 742 | + " anomalies_df_2 = distributed_nixtla_client.detect_anomalies(\n", |
743 | 743 | " df=df, \n", |
744 | 744 | " num_partitions=2,\n", |
745 | 745 | " id_col=id_col,\n", |
|
766 | 766 | " time_col: str = 'ds',\n", |
767 | 767 | " **anomalies_kwargs,\n", |
768 | 768 | " ):\n", |
769 | | - " anomalies_df = distributed_timegpt.detect_anomalies(\n", |
| 769 | + " anomalies_df = distributed_nixtla_client.detect_anomalies(\n", |
770 | 770 | " df=df, \n", |
771 | 771 | " num_partitions=1,\n", |
772 | 772 | " id_col=id_col,\n", |
|
775 | 775 | " **anomalies_kwargs\n", |
776 | 776 | " )\n", |
777 | 777 | " anomalies_df = fa.as_pandas(anomalies_df)\n", |
778 | | - " anomalies_df_2 = distributed_timegpt.detect_anomalies(\n", |
| 778 | + " anomalies_df_2 = distributed_nixtla_client.detect_anomalies(\n", |
779 | 779 | " df=df, \n", |
780 | 780 | " num_partitions=1,\n", |
781 | 781 | " id_col=id_col,\n", |
|
844 | 844 | " assert all(col in df_qls.columns for col in exp_q_cols)\n", |
845 | 845 | " # test monotonicity of quantiles\n", |
846 | 846 | " df_qls.apply(lambda x: x.is_monotonic_increasing, axis=1).sum() == len(exp_q_cols)\n", |
847 | | - " test_method_qls(distributed_timegpt.forecast)\n", |
848 | | - " test_method_qls(distributed_timegpt.forecast, add_history=True)\n", |
849 | | - " test_method_qls(distributed_timegpt.cross_validation)" |
| 847 | + " test_method_qls(distributed_nixtla_client.forecast)\n", |
| 848 | + " test_method_qls(distributed_nixtla_client.forecast, add_history=True)\n", |
| 849 | + " test_method_qls(distributed_nixtla_client.cross_validation)" |
850 | 850 | ] |
851 | 851 | }, |
852 | 852 | { |
|
856 | 856 | "outputs": [], |
857 | 857 | "source": [ |
858 | 858 | "#| hide\n", |
859 | | - "distributed_timegpt = _DistributedTimeGPT()" |
| 859 | + "distributed_nixtla_client = _DistributedNixtlaClient()" |
860 | 860 | ] |
861 | 861 | }, |
862 | 862 | { |
|
0 commit comments