diff --git a/iexfinance/stocks/__init__.py b/iexfinance/stocks/__init__.py
index 49d6327..982c4d2 100644
--- a/iexfinance/stocks/__init__.py
+++ b/iexfinance/stocks/__init__.py
@@ -1,6 +1,13 @@
+from datetime import date, datetime
+from typing import Optional, Union
+
+import pandas as pd
+
+import iexfinance.stocks.cache as cache
 from iexfinance.stocks.base import Stock  # noqa
 from iexfinance.stocks.collections import CollectionsReader
 from iexfinance.stocks.historical import HistoricalReader, IntradayReader
+from iexfinance.stocks.historical_cache import HistoricalReaderCache
 from iexfinance.stocks.ipocalendar import IPOReader
 from iexfinance.stocks.marketvolume import MarketVolumeReader
 from iexfinance.stocks.movers import MoversReader
@@ -10,7 +17,12 @@ from iexfinance.utils.exceptions import ImmediateDeprecationError
 
-def get_historical_data(symbols, start=None, end=None, close_only=False, **kwargs):
+def get_historical_data(
+    symbols: Union[str, list],
+    start: Optional[Union[str, int, date, datetime, pd.Timestamp]] = None,
+    end: Optional[Union[str, int, date, datetime, pd.Timestamp]] = None,
+    close_only: bool = False,
+    **kwargs):
     """
     Function to obtain historical date for a symbol or list of symbols.
 
 Return an instance of HistoricalReader
@@ -38,6 +50,11 @@ def get_historical_data(symbols, start=None, end=None, close_only=False, **kwarg
     list or DataFrame
         Historical stock prices over date range, start to end
     """
+    if cache._IEXFINANCE_CACHE_ is not None:
+        return HistoricalReaderCache(
+            symbols, start=start, end=end, close_only=close_only, **kwargs
+        ).fetch()
+
     return HistoricalReader(
         symbols, start=start, end=end, close_only=close_only, **kwargs
     ).fetch()
diff --git a/iexfinance/stocks/cache.py b/iexfinance/stocks/cache.py
new file mode 100644
index 0000000..c26036e
--- /dev/null
+++ b/iexfinance/stocks/cache.py
@@ -0,0 +1,64 @@
+from enum import Enum
+from typing import NamedTuple, Union
+
+import pandas as pd
+
+# Module-level singleton: the active cache backend (a pd.HDFStore), or
+# None when caching is disabled.
+_IEXFINANCE_CACHE_ = None
+
+
+class CacheType(Enum):
+    """Supported cache backends."""
+
+    NO_CACHE = 1
+    HDF_STORE = 2
+
+
+class CacheMetadata(NamedTuple):
+    """
+    cache_path: string
+        A path to the file that stores the cached data. Required when
+        `cache_type` is CacheType.HDF_STORE.
+    cache_type: CacheType, default CacheType.NO_CACHE
+        The type of cache (i.e. data store) to use to store previously
+        requested data.
+    """
+
+    cache_path: str
+    cache_type: CacheType = CacheType.NO_CACHE
+
+
+def prepare_cache(cache: Union[CacheMetadata, pd.HDFStore]):
+    """
+    Initialize the module-level cache used by `get_historical_data`.
+
+    An already-open ``pd.HDFStore`` is used as-is; otherwise a store is
+    opened at ``cache.cache_path``.
+
+    Raises
+    ------
+    TypeError
+        If `cache.cache_type` is not a CacheType member.
+    ValueError
+        If `cache.cache_path` is None or the cache type is unsupported.
+    """
+    global _IEXFINANCE_CACHE_
+
+    if isinstance(cache, pd.HDFStore):
+        _IEXFINANCE_CACHE_ = cache
+        return
+
+    cache_type = cache.cache_type
+    cache_path = cache.cache_path
+
+    if not isinstance(cache_type, CacheType):
+        raise TypeError('`cache_type` must be an instance of CacheType Enum')
+    if cache_path is None:
+        # `ArgumentError` (previously raised here) is undefined in Python;
+        # ValueError is the conventional choice for a bad argument value.
+        raise ValueError('`cache_path` must not be None.')
+    if cache_type == CacheType.HDF_STORE:
+        _IEXFINANCE_CACHE_ = pd.HDFStore(cache_path)
+    else:
+        # `InternalError` is undefined; reject unsupported cache types.
+        raise ValueError(f'Cannot initialize cache of type {cache_type}.')
diff --git a/iexfinance/stocks/historical_cache.py b/iexfinance/stocks/historical_cache.py
new file mode 100644
index 0000000..989b79f
--- /dev/null
+++ b/iexfinance/stocks/historical_cache.py
@@ -0,0 +1,113 @@
+import logging
+
+import pandas as pd
+
+import iexfinance.stocks.cache as cache
+from iexfinance.stocks.historical import HistoricalReader
+
+logger = logging.getLogger(__name__)
+
+
+class HistoricalReaderCache(HistoricalReader):
+    """
+    Historical-data reader (chart endpoint) that serves requests from the
+    module-level cache configured via `prepare_cache`.
+
+    Reference: https://iextrading.com/developer/docs/#chart
+    """
+
+    def __init__(self, symbols, start=None, end=None, close_only=False,
+                 **kwargs):
+        if cache._IEXFINANCE_CACHE_ is None:
+            # `InternalError` (previously raised here) is undefined;
+            # RuntimeError signals API misuse.
+            raise RuntimeError("Must call `prepare_cache` first.")
+        # Saved so cache misses can be fetched with identical options.
+        self.kwargs = kwargs
+        super(HistoricalReaderCache, self).__init__(
+            symbols, start=start, end=end, close_only=close_only, **kwargs
+        )
+
+    def _execute_iex_query(self, url):
+        # Only single-symbol requests are served from the cache for now.
+        if len(self.symbols) > 1:
+            raise NotImplementedError("Multiple symbols are not supported yet")
+        return self._get_historical_data_cached(self.symbols[0])
+
+    def _format_output(self, out, format=None):
+        if self.output_format == "json":
+            raise NotImplementedError("Need to convert dataframe to json")
+        if len(self.symbols) > 1:
+            raise NotImplementedError("Need to concatenate cached data.")
+        # The cache may hold a wider date range than was requested.
+        result = out.loc[self.start : self.end, :]
+        if self.close_only is True:
+            result = result.loc[:, ["close", "volume"]]
+        return result
+
+    def _get_historical_data(self, symbol, start, end):
+        # Uncached fetch, used to fill gaps in the cached date range.
+        return HistoricalReader(
+            symbol, start=start, end=end, close_only=self.close_only,
+            **self.kwargs
+        ).fetch()
+
+    def _get_historical_data_cached(self, symbol):
+        """
+        Return data for `symbol` between self.start and self.end,
+        downloading only the date ranges missing from the cache.
+        """
+        logger.info(
+            f"{symbol}: `get_historical_data_cached` request between "
+            f"{self.start} and {self.end}."
+        )
+
+        if symbol not in cache._IEXFINANCE_CACHE_:
+            logger.info(f"{symbol}: No data is cached.")
+            df = self._get_historical_data(symbol, self.start, self.end)
+            metadata = {'min_date': self.start, 'max_date': self.end}
+        else:
+            metadata = (
+                cache._IEXFINANCE_CACHE_.get_storer(symbol).attrs.metadata
+            )
+            logger.info(
+                f"{symbol}: Data is cached between {metadata['min_date']} "
+                f"and {metadata['max_date']}."
+            )
+
+            df = cache._IEXFINANCE_CACHE_[symbol]
+            if self.start < metadata['min_date']:
+                logger.info(
+                    f"{symbol}: Requesting data between {self.start} and "
+                    f"{metadata['min_date']}."
+                )
+                # Fetch only the missing head of the range (the previous
+                # code re-downloaded everything up to max_date).
+                df = pd.concat([
+                    df,
+                    self._get_historical_data(
+                        symbol, self.start, metadata['min_date']),
+                ])
+                metadata['min_date'] = self.start
+
+            if self.end > metadata['max_date']:
+                logger.info(
+                    f"{symbol}: Requesting data between "
+                    f"{metadata['max_date']} and {self.end}."
+                )
+                df = pd.concat([
+                    df,
+                    self._get_historical_data(
+                        symbol, metadata['max_date'], self.end),
+                ])
+                metadata['max_date'] = self.end
+
+        # Not using HDFStore.append() because of the need to de-duplicate.
+        df = df[~df.index.duplicated(keep='first')]
+        # Keep the index monotonic so `.loc[start:end]` slicing works after
+        # older rows are concatenated behind newer ones.
+        df = df.sort_index()
+
+        cache._IEXFINANCE_CACHE_[symbol] = df
+        cache._IEXFINANCE_CACHE_.get_storer(symbol).attrs.metadata = metadata
+
+        return df
diff --git a/iexfinance/tests/stocks/test_historical_cache.py b/iexfinance/tests/stocks/test_historical_cache.py
new file mode 100644
index 0000000..0697c27
--- /dev/null
+++ b/iexfinance/tests/stocks/test_historical_cache.py
@@ -0,0 +1,98 @@
+import logging
+import os
+import tempfile
+from datetime import date, timedelta
+
+import pytest
+from pandas import to_datetime
+
+from iexfinance.stocks import get_historical_data
+from iexfinance.stocks.cache import CacheMetadata, CacheType, prepare_cache
+
+
+class TestHistoricalCache:
+    # NOTE(review): plain pytest class, not unittest.TestCase -- pytest
+    # fixtures (prepare_test_cache, caplog) do not run inside
+    # unittest.TestCase subclasses.
+
+    @classmethod
+    def setup_class(cls):
+        today = to_datetime(date.today())
+        cls.end = today - timedelta(days=30)
+        cls.start = cls.end - timedelta(days=365) * 5
+
+    @pytest.fixture(autouse=True)
+    def prepare_test_cache(self):
+        # A fresh HDF store per test so tests do not share cached data.
+        with tempfile.TemporaryDirectory() as tempdir:
+            hdf_store_path = os.path.join(tempdir, 'test_store.h5')
+            cache_metadata = CacheMetadata(
+                cache_path=hdf_store_path, cache_type=CacheType.HDF_STORE
+            )
+            prepare_cache(cache_metadata)
+            yield
+
+    def _messages_used_logs(self, caplog):
+        # Rebuild "LEVEL:logger:message" strings so assertions can match
+        # the iexfinance.base message-usage log lines.
+        return [
+            f"{r.levelname}:{r.name}:{r.getMessage()}"
+            for r in caplog.records
+            if 'MESSAGES USED' in r.getMessage()
+        ]
+
+    def _assert_data(self, data):
+        # Spot-check a known historical trading day for AMZN.
+        expected = data.loc["2017-02-09"]
+        assert expected["close"] == pytest.approx(821.36, 3)
+        assert expected["high"] == pytest.approx(825.0, 3)
+
+    def test_get_historical_data_cached_none(self, caplog):
+        with caplog.at_level(logging.INFO):
+            data = get_historical_data(["AMZN"], self.start, self.end)
+
+        messages_used_logs = self._messages_used_logs(caplog)
+        assert len(messages_used_logs) == 1
+        assert 'INFO:iexfinance.base:MESSAGES USED: 35330' in messages_used_logs
+
+        self._assert_data(data)
+
+    def test_get_historical_data_cached_full(self, caplog):
+        with caplog.at_level(logging.INFO):
+            get_historical_data(["AMZN"], self.start, self.end)
+            start = self.start + timedelta(days=1)
+            end = self.end - timedelta(days=1)
+            data = get_historical_data(["AMZN"], start, end)
+
+        messages_used_logs = self._messages_used_logs(caplog)
+        # The second request is fully covered by the cache: one API hit.
+        assert len(messages_used_logs) == 1
+        assert messages_used_logs.count(
+            'INFO:iexfinance.base:MESSAGES USED: 35330') == 1
+
+        self._assert_data(data)
+
+    def test_get_historical_data_cached_missing_start(self, caplog):
+        with caplog.at_level(logging.INFO):
+            get_historical_data(["AMZN"], self.start, self.end)
+            start = self.start - timedelta(days=5)
+            get_historical_data(["AMZN"], start, self.end)
+
+        messages_used_logs = self._messages_used_logs(caplog)
+        # Only the missing 5-day head triggers a second, small request.
+        assert len(messages_used_logs) == 2
+        assert messages_used_logs.count(
+            'INFO:iexfinance.base:MESSAGES USED: 35330') == 1
+
+    def test_get_historical_data_cached_missing_end(self, caplog):
+        with caplog.at_level(logging.INFO):
+            get_historical_data(["AMZN"], self.start, self.end)
+            end = self.end + timedelta(days=5)
+            get_historical_data(["AMZN"], self.start, end)
+
+        messages_used_logs = self._messages_used_logs(caplog)
+        assert len(messages_used_logs) == 2
+        assert messages_used_logs.count(
+            'INFO:iexfinance.base:MESSAGES USED: 35330') == 1
+        assert messages_used_logs.count(
+            'INFO:iexfinance.base:MESSAGES USED: 620') == 1