Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions openfisca_core/holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,57 @@

import numpy as np
import psutil
from typing import TYPE_CHECKING, List, Union, Any, Optional, Dict

from openfisca_core import periods
from openfisca_core.commons import empty_clone
from openfisca_core.data_storage import InMemoryStorage, OnDiskStorage
from openfisca_core.errors import PeriodMismatchError
from openfisca_core.indexed_enums import Enum
from openfisca_core.periods import MONTH, YEAR, ETERNITY
from openfisca_core.periods import MONTH, YEAR, ETERNITY, Period
from openfisca_core.tools import eval_expression


if TYPE_CHECKING:
from openfisca_core.populations import Population
from openfisca_core.variables import Variable
from openfisca_core.simulations import Simulation


log = logging.getLogger(__name__)


#TODO change with a more specific type
ArrayLike = Any
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be shipped out of the box in Numpy 1.20/21



class Holder(object):
"""
A holder keeps tracks of a variable values after they have been calculated, or set as an input.
"""

def __init__(self, variable, population):
self.population = population
self.variable = variable
self.simulation = population.simulation
self._memory_storage = InMemoryStorage(is_eternal = (self.variable.definition_period == ETERNITY))
def __init__(self, variable:'Variable', population:'Population'):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flake won't like the lack of spaces.

Also you can forward-reference them no?

self.population:'Population' = population
self.variable:'Variable' = variable
# TODO change once decided if simulation is needed or not
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this requires a bit more of examination. I'll take a look.

if population.simulation is None:
raise Exception('You need a simulation attached to the population')
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error messages in OpenFisca are meant to help the users discover not just what is wrong, but how can they solve their problems or reach for help. Could you develop this further?

else:
self.simulation:'Simulation' = population.simulation
self._memory_storage:InMemoryStorage = InMemoryStorage(is_eternal = (self.variable.definition_period == ETERNITY))

# By default, do not activate on-disk storage, or variable dropping
self._disk_storage = None
self._on_disk_storable = False
self._do_not_store = False
self._disk_storage:Optional[OnDiskStorage] = None
self._on_disk_storable:bool = False
self._do_not_store:bool = False
if self.simulation and self.simulation.memory_config:
if self.variable.name not in self.simulation.memory_config.priority_variables:
self._disk_storage = self.create_disk_storage()
self._on_disk_storable = True
if self.variable.name in self.simulation.memory_config.variables_to_drop:
self._do_not_store = True

def clone(self, population):
def clone(self, population:'Population')->'Holder':
"""
Copy the holder just enough to be able to run a new simulation without modifying the original simulation.
"""
Expand All @@ -56,7 +72,7 @@ def clone(self, population):

return new

def create_disk_storage(self, directory = None, preserve = False):
def create_disk_storage(self, directory:Optional[str] = None, preserve:bool = False)->OnDiskStorage:
if directory is None:
directory = self.simulation.data_storage_dir
storage_dir = os.path.join(directory, self.variable.name)
Expand All @@ -68,7 +84,7 @@ def create_disk_storage(self, directory = None, preserve = False):
preserve_storage_dir = preserve
)

def delete_arrays(self, period = None):
def delete_arrays(self, period:Optional[Union[str, Period]] = None)->None:
"""
If ``period`` is ``None``, remove all known values of the variable.

Expand All @@ -94,7 +110,7 @@ def get_array(self, period):
if self._disk_storage:
return self._disk_storage.get(period)

def get_memory_usage(self):
def get_memory_usage(self)->Dict[str, Union[int, np.dtype]]:
"""
Get data about the virtual memory usage of the holder.

Expand Down Expand Up @@ -131,15 +147,15 @@ def get_memory_usage(self):

return usage

def get_known_periods(self):
def get_known_periods(self)->List[Period]:
"""
Get the list of periods the variable value is known for.
"""

return list(self._memory_storage.get_known_periods()) + list((
self._disk_storage.get_known_periods() if self._disk_storage else []))

def set_input(self, period, array):
def set_input(self, period:Union[str, Period], array: ArrayLike):
"""
Set a variable's value (``array``) for a given period (``period``)

Expand Down Expand Up @@ -183,7 +199,7 @@ def set_input(self, period, array):
return self.variable.set_input(self, period, array)
return self._set(period, array)

def _to_array(self, value):
def _to_array(self, value:ArrayLike)->np.array:
if not isinstance(value, np.ndarray):
value = np.asarray(value)
if value.ndim == 0:
Expand All @@ -204,7 +220,7 @@ def _to_array(self, value):
.format(value, self.variable.name, self.variable.dtype, value.dtype))
return value

def _set(self, period, value):
def _set(self, period:Union[Period, None], value:ArrayLike)->None:
value = self._to_array(value)
if self.variable.definition_period != ETERNITY:
if period is None:
Expand Down Expand Up @@ -236,7 +252,7 @@ def _set(self, period, value):
else:
self._memory_storage.put(value, period)

def put_in_cache(self, value, period):
def put_in_cache(self, value:ArrayLike, period:Period)->None:
if self._do_not_store:
return

Expand All @@ -255,7 +271,7 @@ def default_array(self):
return self.variable.default_array(self.population.count)


def set_input_dispatch_by_period(holder, period, array):
def set_input_dispatch_by_period(holder:'Holder', period:Period, array):
"""
This function can be declared as a ``set_input`` attribute of a variable.

Expand Down Expand Up @@ -290,7 +306,7 @@ def set_input_dispatch_by_period(holder, period, array):
sub_period = sub_period.offset(1)


def set_input_divide_by_period(holder, period, array):
def set_input_divide_by_period(holder:Holder, period:Period, array)->None:
"""
This function can be declared as a ``set_input`` attribute of a variable.

Expand Down
46 changes: 28 additions & 18 deletions openfisca_core/periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import datetime
import re
from os import linesep
from typing import Dict
from typing import Dict, Union, List, Optional


DAY = 'day'
Expand All @@ -34,9 +34,19 @@ def N_(message):
str_by_instant_cache: Dict = {}
year_or_month_or_day_re = re.compile(r'(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$')

class PeriodValueError(ValueError):

def __init__(self, value):
message = linesep.join([
"Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: '{}'.".format(value),
"Learn more about legal period formats in OpenFisca:",
"<https://openfisca.org/doc/coding-the-legislation/35_periods.html#periods-in-simulations>."
])
super().__init__(message)


class Instant(tuple):
def __repr__(self):
def __repr__(self)->str:
"""Transform instant to to its Python representation as a string.

>>> repr(instant(2014))
Expand All @@ -48,7 +58,7 @@ def __repr__(self):
"""
return '{}({})'.format(self.__class__.__name__, super(Instant, self).__repr__())

def __str__(self):
def __str__(self)->str:
"""Transform instant to a string.

>>> str(instant(2014))
Expand Down Expand Up @@ -81,7 +91,7 @@ def date(self):
return instant_date

@property
def day(self):
def day(self)->int:
"""Extract day from instant.

>>> instant(2014).day
Expand All @@ -94,7 +104,7 @@ def day(self):
return self[2]

@property
def month(self):
def month(self)->int:
"""Extract month from instant.

>>> instant(2014).month
Expand All @@ -106,7 +116,7 @@ def month(self):
"""
return self[1]

def period(self, unit, size = 1):
def period(self, unit:str, size:int = 1)->Period:
"""Create a new period starting at instant.

>>> instant(2014).period('month')
Expand All @@ -120,7 +130,7 @@ def period(self, unit, size = 1):
assert isinstance(size, int) and size >= 1, 'Invalid size: {} of type {}'.format(size, type(size))
return Period((unit, self, size))

def offset(self, offset, unit):
def offset(self, offset:int, unit:str)->'Instant':
"""Increment (or decrement) the given instant with offset units.

>>> instant(2014).offset(1, 'day')
Expand Down Expand Up @@ -257,7 +267,7 @@ def offset(self, offset, unit):
return self.__class__((year, month, day))

@property
def year(self):
def year(self)->int:
"""Extract year from instant.

>>> instant(2014).year
Expand All @@ -271,7 +281,7 @@ def year(self):


class Period(tuple):
def __repr__(self):
def __repr__(self)->str:
"""Transform period to to its Python representation as a string.

>>> repr(period('year', 2014))
Expand All @@ -283,7 +293,7 @@ def __repr__(self):
"""
return '{}({})'.format(self.__class__.__name__, super(Period, self).__repr__())

def __str__(self):
def __str__(self)->str:
"""Transform period to a string.

>>> str(period(YEAR, 2014))
Expand Down Expand Up @@ -344,7 +354,7 @@ def date(self):
return self.start.date

@property
def days(self):
def days(self)->int:
"""Count the number of days in period.

>>> period('day', 2014).days
Expand Down Expand Up @@ -410,7 +420,7 @@ def intersection(self, start, stop):
(intersection_stop.date - intersection_start.date).days + 1,
))

def get_subperiods(self, unit):
def get_subperiods(self, unit:str):
"""
Return the list of all the periods of unit ``unit`` contained in self.

Expand Down Expand Up @@ -786,7 +796,7 @@ def instant_date(instant):
return instant_date


def period(value):
def period(value:Union[str, Period, Instant, int])->Period:
"""Return a new period, aka a triple (unit, start_instant, size).

>>> period('2014')
Expand All @@ -810,7 +820,7 @@ def period(value):
if isinstance(value, Instant):
return Period((DAY, value, 1))

def parse_simple_period(value):
def parse_simple_period(value:str)->Optional[Period]:
"""
Parses simple periods respecting the ISO format, such as 2012 or 2015-03
"""
Expand Down Expand Up @@ -867,7 +877,7 @@ def raise_error(value):
# middle component must be a valid iso period
base_period = parse_simple_period(components[1])
if not base_period:
raise_error(value)
raise PeriodValueError(value)

# period like year:2015-03 have a size of 1
if len(components) == 2:
Expand All @@ -889,7 +899,7 @@ def raise_error(value):
return Period((unit, base_period.start, size))


def key_period_size(period):
def key_period_size(period:Period)->str:
"""
Defines a key in order to sort periods by length. It uses two aspects : first unit then size

Expand All @@ -910,7 +920,7 @@ def key_period_size(period):
return '{}_{}'.format(unit_weight(unit), size)


def unit_weights():
def unit_weights()->Dict[str,int]:
return {
DAY: 100,
MONTH: 200,
Expand All @@ -919,5 +929,5 @@ def unit_weights():
}


def unit_weight(unit):
def unit_weight(unit:str)->int:
return unit_weights()[unit]
Loading