Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion iexfinance/stocks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from datetime import date, datetime
from typing import Union

import pandas as pd

import iexfinance.stocks.cache as cache
from iexfinance.stocks.base import Stock # noqa
from iexfinance.stocks.collections import CollectionsReader
from iexfinance.stocks.historical import HistoricalReader, IntradayReader
from iexfinance.stocks.historical_cache import HistoricalReaderCache
from iexfinance.stocks.ipocalendar import IPOReader
from iexfinance.stocks.marketvolume import MarketVolumeReader
from iexfinance.stocks.movers import MoversReader
Expand All @@ -10,7 +17,12 @@
from iexfinance.utils.exceptions import ImmediateDeprecationError


def get_historical_data(symbols, start=None, end=None, close_only=False, **kwargs):
def get_historical_data(
symbols: Union[str, list],
start: Union[str, int, date, datetime, pd.Timestamp] = None,
end: Union[str, int, date, datetime, pd.Timestamp] = None,
close_only: bool = False,
**kwargs):
"""
Function to obtain historical date for a symbol or list of
symbols. Return an instance of HistoricalReader
Expand Down Expand Up @@ -38,6 +50,11 @@ def get_historical_data(symbols, start=None, end=None, close_only=False, **kwarg
list or DataFrame
Historical stock prices over date range, start to end
"""
if cache._IEXFINANCE_CACHE_ is not None:
return HistoricalReaderCache(
symbols, start=start, end=end, close_only=close_only, **kwargs
).fetch()

return HistoricalReader(
symbols, start=start, end=end, close_only=close_only, **kwargs
).fetch()
Expand Down
42 changes: 42 additions & 0 deletions iexfinance/stocks/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from enum import Enum
from typing import NamedTuple, Union

import pandas as pd

# Module-level cache handle (an open pd.HDFStore once `prepare_cache` has
# been called); None means caching is disabled.
_IEXFINANCE_CACHE_ = None

class CacheType(Enum):
    """Kinds of data store usable for caching previously requested data."""

    # Caching disabled; every request goes to the API.
    NO_CACHE = 1
    # Data cached on disk in a pandas HDFStore file.
    HDF_STORE = 2

class CacheMetadata(NamedTuple):
    """Configuration describing the cache to open in ``prepare_cache``.

    Attributes
    ----------
    cache_path : str
        Path to the file backing the cache. Must not be None when a
        concrete cache type is selected.
    cache_type : CacheType, default CacheType.NO_CACHE
        The type of cache (i.e. data store) used to hold previously
        requested data.
    """
    cache_path: str
    cache_type: CacheType = CacheType.NO_CACHE

def prepare_cache(cache: Union[CacheMetadata, pd.HDFStore]):
    """Initialize the module-level cache used by the cached readers.

    Parameters
    ----------
    cache : CacheMetadata or pd.HDFStore
        Either an already-open pandas HDFStore (used as-is), or a
        CacheMetadata describing the store to open.

    Raises
    ------
    TypeError
        If ``cache.cache_type`` is not a ``CacheType`` member.
    ValueError
        If ``cache.cache_path`` is None, or the cache type is not
        supported.
    """
    global _IEXFINANCE_CACHE_

    # An already-open store needs no further setup.
    if isinstance(cache, pd.HDFStore):
        _IEXFINANCE_CACHE_ = cache
        return

    cache_type = cache.cache_type
    cache_path = cache.cache_path

    if not isinstance(cache_type, CacheType):
        raise TypeError('`cache_type` must be an instance of CacheType Enum')
    if cache_path is None:
        # BUG FIX: original raised `ArgumentError`, an undefined name that
        # would surface as NameError; ValueError is the intended semantics.
        raise ValueError('`cache_path` must not be none.')
    if cache_type == CacheType.HDF_STORE:
        _IEXFINANCE_CACHE_ = pd.HDFStore(cache_path)
    else:
        # BUG FIX: original raised `InternalError`, also undefined.
        raise ValueError('Cannot initialize cache.')
73 changes: 73 additions & 0 deletions iexfinance/stocks/historical_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
import datetime

import iexfinance.stocks.cache as cache
from iexfinance.stocks.historical import HistoricalReader

logger = logging.getLogger(__name__)

class HistoricalReaderCache(HistoricalReader):
    """
    Historical data reader for the chart endpoint, backed by a cache.

    Previously downloaded date ranges are stored in the module-level
    HDFStore (see ``iexfinance.stocks.cache``); only missing date ranges
    are requested from the API.

    Reference: https://iextrading.com/developer/docs/#chart
    """

    def __init__(self, symbols, start=None, end=None, close_only=False, **kwargs):
        # The cache must be initialized via `prepare_cache` before use.
        if cache._IEXFINANCE_CACHE_ is None:
            # BUG FIX: original raised `InternalError`, an undefined name
            # (NameError at raise time); RuntimeError matches the misuse.
            raise RuntimeError("Must call `prepare_cache` first.")
        # Keep the raw kwargs so uncached ranges can be re-fetched with the
        # same options through a plain HistoricalReader.
        self.kwargs = kwargs
        super(HistoricalReaderCache, self).__init__(
            symbols, start=start, end=end, close_only=close_only, **kwargs
        )

    def _execute_iex_query(self, url):
        """Serve the request from the cache instead of querying `url`."""
        if len(self.symbols) > 1:
            # BUG FIX: `InternalError` is undefined; this path is an
            # unimplemented feature, so NotImplementedError is accurate.
            raise NotImplementedError("Not supported yet")
        return self._get_historical_data_cached(self.symbols[0])

    def _format_output(self, out, format=None):
        """Trim the (possibly wider) cached frame to the requested window."""
        if self.output_format == "json":
            raise NotImplementedError("Need to convert dataframe to json")
        if len(self.symbols) > 1:
            raise NotImplementedError("Need to concatenate cached data.")
        result = out.loc[self.start : self.end, :]
        if self.close_only is True:
            result = result.loc[:, ["close", "volume"]]
        return result

    def _get_historical_data(self, symbol, start, end):
        """Fetch an uncached date range straight from the API."""
        return HistoricalReader(
            symbol, start=start, end=end, close_only=self.close_only, **self.kwargs
        ).fetch()

    def _get_historical_data_cached(self, symbol):
        """Return historical data for `symbol`, fetching only missing ranges.

        The cached frame and its covered [min_date, max_date] range are kept
        in the module-level HDFStore; any head/tail gap relative to the
        requested [self.start, self.end] window is downloaded and merged.
        """
        # Local import: this module has no top-level pandas import, and
        # DataFrame.append was removed in pandas 2.0 — pd.concat is used.
        import pandas as pd

        logger.info(f"{symbol}: `get_historical_data_cached` request between {self.start} and {self.end}.")

        if symbol not in cache._IEXFINANCE_CACHE_:
            logger.info(f"{symbol}: No data is cached.")
            df = self._get_historical_data(symbol, self.start, self.end)
            metadata = {'min_date': self.start, 'max_date': self.end}
        else:
            metadata = cache._IEXFINANCE_CACHE_.get_storer(symbol).attrs.metadata
            # BUG FIX: log message typo "catched" -> "cached".
            logger.info(f"{symbol}: Data is cached between {metadata['min_date']} and {metadata['max_date']}.")

            df = cache._IEXFINANCE_CACHE_[symbol]
            if self.start < metadata['min_date']:
                logger.info(f"{symbol}: Requesting data between {self.start} and {metadata['min_date']}.")
                # BUG FIX: only the missing head range [start, min_date] is
                # fetched; the original fetched up to max_date, re-downloading
                # data that was already cached (and contradicting its own log).
                df = pd.concat([df, self._get_historical_data(symbol, self.start, metadata['min_date'])])
                metadata['min_date'] = self.start

            if self.end > metadata['max_date']:
                logger.info(f"{symbol}: Requesting data between {metadata['max_date']} and {self.end}.")
                df = pd.concat([df, self._get_historical_data(symbol, metadata['max_date'], self.end)])
                metadata['max_date'] = self.end

        # Not using HDFStore.append() because of the need to de-duplicate
        # the overlapping boundary rows; keep whichever row was seen first.
        df = df[~df.index.duplicated(keep='first')]
        # BUG FIX: merging a head extension leaves the index non-monotonic,
        # which breaks the later `.loc[start:end]` slice — sort by date.
        df = df.sort_index()

        cache._IEXFINANCE_CACHE_[symbol] = df
        cache._IEXFINANCE_CACHE_.get_storer(symbol).attrs.metadata = metadata

        return df
80 changes: 80 additions & 0 deletions iexfinance/tests/stocks/test_historical_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging
import os
import tempfile
import unittest
from datetime import date, datetime, timedelta

import pytest
from pandas import to_datetime

import iexfinance.stocks.cache
from iexfinance.stocks import get_historical_data
from iexfinance.stocks.cache import *


class TestHistoricalCache(unittest.TestCase):
    """Integration tests for the cached historical-data path.

    NOTE(review): these tests hit the live IEX API and assert on exact
    "MESSAGES USED" counts, so they are sensitive to account/plan pricing.
    """

    @classmethod
    def setUpClass(cls):
        # BUG FIX: the original defined pytest-style `setup_class(self)`,
        # which is never invoked for a unittest.TestCase subclass, leaving
        # start/end undefined; use unittest's setUpClass hook instead.
        today = to_datetime(date.today())
        cls.end = today - timedelta(days=30)
        cls.start = cls.end - timedelta(days=365) * 5

    @pytest.fixture(autouse=True)
    def prepare_test_cache(self):
        # BUG FIX: the original signature `(scope="function")` dropped
        # `self` and misplaced the scope argument (it belongs in the
        # @pytest.fixture decorator, where "function" is already default).
        with tempfile.TemporaryDirectory() as tempdir:
            hdf_store_path = os.path.join(tempdir, 'test_store.h5')
            cache_metadata = CacheMetadata(
                cache_path=hdf_store_path, cache_type=CacheType.HDF_STORE
            )
            prepare_cache(cache_metadata)
            try:
                yield
            finally:
                # BUG FIX: close the store before the temp dir is deleted and
                # reset the module-level cache so each test gets a fresh one.
                iexfinance.stocks.cache._IEXFINANCE_CACHE_.close()
                iexfinance.stocks.cache._IEXFINANCE_CACHE_ = None

    def _messages_used_logs(self, records):
        # Filter captured log lines down to the API usage-accounting ones.
        return [log for log in records if 'MESSAGES USED' in log]

    def _assert_data(self, data):
        expected = data.loc["2017-02-09"]
        assert expected["close"] == pytest.approx(821.36, 3)
        assert expected["high"] == pytest.approx(825.0, 3)

    def test_get_historical_data_cached_none(self):
        # Cold cache: exactly one full-range API request is made.
        with self.assertLogs(level='INFO') as cm:
            data = get_historical_data(["AMZN"], self.start, self.end)

        messages_used_logs = self._messages_used_logs(cm.output)
        assert len(messages_used_logs) == 1
        assert 'INFO:iexfinance.base:MESSAGES USED: 35330' in messages_used_logs

        self._assert_data(data)

    def test_get_historical_data_cached_full(self):
        # A sub-range of a cached range is served without a second request.
        with self.assertLogs(level='INFO') as cm:
            get_historical_data(["AMZN"], self.start, self.end)
            start = self.start + timedelta(days=1)
            end = self.end - timedelta(days=1)
            data = get_historical_data(["AMZN"], start, end)

        messages_used_logs = self._messages_used_logs(cm.output)
        assert len(messages_used_logs) == 1
        assert messages_used_logs.count('INFO:iexfinance.base:MESSAGES USED: 35330') == 1

        self._assert_data(data)

    def test_get_historical_data_cached_missing_start(self):
        # Extending the range backwards triggers one extra request.
        with self.assertLogs(level='INFO') as cm:
            get_historical_data(["AMZN"], self.start, self.end)
            start = self.start - timedelta(days=5)
            get_historical_data(["AMZN"], start, self.end)

        messages_used_logs = self._messages_used_logs(cm.output)
        assert len(messages_used_logs) == 2
        assert messages_used_logs.count('INFO:iexfinance.base:MESSAGES USED: 35330') == 2

    def test_get_historical_data_cached_missing_end(self):
        # Extending the range forwards triggers one small extra request.
        with self.assertLogs(level='INFO') as cm:
            get_historical_data(["AMZN"], self.start, self.end)
            end = self.end + timedelta(days=5)
            get_historical_data(["AMZN"], self.start, end)

        messages_used_logs = self._messages_used_logs(cm.output)
        assert len(messages_used_logs) == 2
        assert messages_used_logs.count('INFO:iexfinance.base:MESSAGES USED: 35330') == 1
        assert messages_used_logs.count('INFO:iexfinance.base:MESSAGES USED: 620') == 1