From 22d66cb91257c5c2cb106a6004903e94bd3c46d6 Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:15:39 +0100 Subject: [PATCH 1/7] Add typings to the variables module --- openfisca_core/variables.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/openfisca_core/variables.py b/openfisca_core/variables.py index c1ecd8ba8c..d2408fbef7 100644 --- a/openfisca_core/variables.py +++ b/openfisca_core/variables.py @@ -4,7 +4,7 @@ import inspect import re import textwrap -from typing import Optional +from typing import Optional, List, Callable, Type, Dict, KeysView, Tuple, Union import numpy as np from sortedcontainers.sorteddict import SortedDict @@ -147,7 +147,7 @@ class Variable(object): Free multilines text field describing the variable context and usage. """ - def __init__(self, baseline_variable = None): + def __init__(self, baseline_variable:Optional['Variable'] = None): self.name = self.__class__.__name__ attr = { name: value for name, value in self.__class__.__dict__.items() @@ -190,7 +190,8 @@ def __init__(self, baseline_variable = None): # ----- Setters used to build the variable ----- # - def set(self, attributes, attribute_name, required = False, allowed_values = None, allowed_type = None, setter = None, default = None): + def set(self, attributes, attribute_name:str, required:bool = False, allowed_values:Optional[Union[List, KeysView, Tuple]] = None, + allowed_type:Optional[Union[Type, Tuple[Type, ...]]] = None, setter:Optional[Callable] = None, default = None): value = attributes.pop(attribute_name, None) if value is None and self.baseline_variable: return getattr(self.baseline_variable, attribute_name) @@ -211,12 +212,12 @@ def set(self, attributes, attribute_name, required = False, allowed_values = Non return default return value - def set_entity(self, entity): + def set_entity(self, entity:Entity)->Entity: if not isinstance(entity, Entity): raise ValueError(f"Invalid value '{entity}' for attribute 'entity' in variable '{self.name}'. Must be an instance of Entity.") return entity - def set_possible_values(self, possible_values): + def set_possible_values(self, possible_values:Type[Enum]): if not issubclass(possible_values, Enum): raise ValueError("Invalid value '{}' for attribute 'possible_values' in variable '{}'. Must be a subclass of {}." .format(possible_values, self.name, Enum)) @@ -266,7 +267,7 @@ def set_calculate_output(self, calculate_output): return self.baseline_variable.calculate_output return calculate_output - def set_formulas(self, formulas_attr): + def set_formulas(self, formulas_attr:Dict[str, Callable])->Dict[str, Callable]: formulas = SortedDict() for formula_name, formula in formulas_attr.items(): starting_date = self.parse_formula_name(formula_name) @@ -322,7 +323,7 @@ def raise_error(): # ----- Methods ----- # - def is_input_variable(self): + def is_input_variable(self)->bool: """ Returns True if the variable is an input variable. """ @@ -388,7 +389,7 @@ def get_formula(self, period = None): return None - def clone(self): + def clone(self)->Variable: clone = self.__class__() return clone @@ -447,7 +448,7 @@ def _partition(dict, predicate): return true_dict, false_dict -def get_neutralized_variable(variable): +def get_neutralized_variable(variable:Variable)->Variable: """ Return a new neutralized variable (to be used by reforms). A neutralized variable always returns its default value, and does not cache anything. From f566ebfd63246aaa7c8caffeca2cb3543b7120e8 Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:16:39 +0100 Subject: [PATCH 2/7] Add typings to the periods module --- openfisca_core/periods.py | 46 ++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/openfisca_core/periods.py b/openfisca_core/periods.py index d906794e94..a4a15404e5 100644 --- a/openfisca_core/periods.py +++ b/openfisca_core/periods.py @@ -15,7 +15,7 @@ import datetime import re from os import linesep -from typing import Dict +from typing import Dict, Union, List, Optional DAY = 'day' @@ -34,9 +34,19 @@ def N_(message): str_by_instant_cache: Dict = {} year_or_month_or_day_re = re.compile(r'(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$') +class PeriodValueError(ValueError): + + def __init__(self, value): + message = linesep.join([ + "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: '{}'.".format(value), + "Learn more about legal period formats in OpenFisca:", + "." + ]) + super().__init__(message) + class Instant(tuple): - def __repr__(self): + def __repr__(self)->str: """Transform instant to to its Python representation as a string. >>> repr(instant(2014)) @@ -48,7 +58,7 @@ def __repr__(self): """ return '{}({})'.format(self.__class__.__name__, super(Instant, self).__repr__()) - def __str__(self): + def __str__(self)->str: """Transform instant to a string. >>> str(instant(2014)) @@ -81,7 +91,7 @@ def date(self): return instant_date @property - def day(self): + def day(self)->int: """Extract day from instant. >>> instant(2014).day @@ -94,7 +104,7 @@ def day(self): return self[2] @property - def month(self): + def month(self)->int: """Extract month from instant. >>> instant(2014).month @@ -106,7 +116,7 @@ def month(self): """ return self[1] - def period(self, unit, size = 1): + def period(self, unit:str, size:int = 1)->Period: """Create a new period starting at instant. >>> instant(2014).period('month') @@ -120,7 +130,7 @@ def period(self, unit, size = 1): assert isinstance(size, int) and size >= 1, 'Invalid size: {} of type {}'.format(size, type(size)) return Period((unit, self, size)) - def offset(self, offset, unit): + def offset(self, offset:int, unit:str)->'Instant': """Increment (or decrement) the given instant with offset units. >>> instant(2014).offset(1, 'day') @@ -257,7 +267,7 @@ def offset(self, offset, unit): return self.__class__((year, month, day)) @property - def year(self): + def year(self)->int: """Extract year from instant. >>> instant(2014).year @@ -271,7 +281,7 @@ def year(self): class Period(tuple): - def __repr__(self): + def __repr__(self)->str: """Transform period to to its Python representation as a string. >>> repr(period('year', 2014)) @@ -283,7 +293,7 @@ def __repr__(self): """ return '{}({})'.format(self.__class__.__name__, super(Period, self).__repr__()) - def __str__(self): + def __str__(self)->str: """Transform period to a string. >>> str(period(YEAR, 2014)) @@ -344,7 +354,7 @@ def date(self): return self.start.date @property - def days(self): + def days(self)->int: """Count the number of days in period. >>> period('day', 2014).days @@ -410,7 +420,7 @@ def intersection(self, start, stop): (intersection_stop.date - intersection_start.date).days + 1, )) - def get_subperiods(self, unit): + def get_subperiods(self, unit:str): """ Return the list of all the periods of unit ``unit`` contained in self. @@ -786,7 +796,7 @@ def instant_date(instant): return instant_date -def period(value): +def period(value:Union[str, Period, Instant, int])->Period: """Return a new period, aka a triple (unit, start_instant, size). >>> period('2014') @@ -810,7 +820,7 @@ def period(value): if isinstance(value, Instant): return Period((DAY, value, 1)) - def parse_simple_period(value): + def parse_simple_period(value:str)->Optional[Period]: """ Parses simple periods respecting the ISO format, such as 2012 or 2015-03 """ @@ -867,7 +877,7 @@ def raise_error(value): # middle component must be a valid iso period base_period = parse_simple_period(components[1]) if not base_period: - raise_error(value) + raise PeriodValueError(value) # period like year:2015-03 have a size of 1 if len(components) == 2: @@ -889,7 +899,7 @@ def raise_error(value): return Period((unit, base_period.start, size)) -def key_period_size(period): +def key_period_size(period:Period)->str: """ Defines a key in order to sort periods by length. It uses two aspects : first unit then size @@ -910,7 +920,7 @@ def key_period_size(period): return '{}_{}'.format(unit_weight(unit), size) -def unit_weights(): +def unit_weights()->Dict[str,int]: return { DAY: 100, MONTH: 200, @@ -919,5 +929,5 @@ def unit_weights(): } -def unit_weight(unit): +def unit_weight(unit:str)->int: return unit_weights()[unit] From 7ae9259c264f1d749367e489fb6bc744e6303c2c Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:17:41 +0100 Subject: [PATCH 3/7] Add typings to the holders module --- openfisca_core/holders.py | 56 +++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/openfisca_core/holders.py b/openfisca_core/holders.py index 233a368cbc..f8f24983e5 100644 --- a/openfisca_core/holders.py +++ b/openfisca_core/holders.py @@ -6,33 +6,49 @@ import numpy as np import psutil +from typing import TYPE_CHECKING, List, Union, Any, Optional, Dict from openfisca_core import periods from openfisca_core.commons import empty_clone from openfisca_core.data_storage import InMemoryStorage, OnDiskStorage from openfisca_core.errors import PeriodMismatchError from openfisca_core.indexed_enums import Enum -from openfisca_core.periods import MONTH, YEAR, ETERNITY +from openfisca_core.periods import MONTH, YEAR, ETERNITY, Period from openfisca_core.tools import eval_expression + +if TYPE_CHECKING: + from openfisca_core.populations import Population + from openfisca_core.variables import Variable + from openfisca_core.simulations import Simulation + + log = logging.getLogger(__name__) +#TODO change with a more specific type +ArrayLike = Any + + class Holder(object): """ A holder keeps tracks of a variable values after they have been calculated, or set as an input. """ - def __init__(self, variable, population): - self.population = population - self.variable = variable - self.simulation = population.simulation - self._memory_storage = InMemoryStorage(is_eternal = (self.variable.definition_period == ETERNITY)) + def __init__(self, variable:'Variable', population:'Population'): + self.population:'Population' = population + self.variable:'Variable' = variable + # TODO change once decided if simulation is needed or not + if population.simulation is None: + raise Exception('You need a simulation attached to the population') + else: + self.simulation:'Simulation' = population.simulation + self._memory_storage:InMemoryStorage = InMemoryStorage(is_eternal = (self.variable.definition_period == ETERNITY)) # By default, do not activate on-disk storage, or variable dropping - self._disk_storage = None - self._on_disk_storable = False - self._do_not_store = False + self._disk_storage:Optional[OnDiskStorage] = None + self._on_disk_storable:bool = False + self._do_not_store:bool = False if self.simulation and self.simulation.memory_config: if self.variable.name not in self.simulation.memory_config.priority_variables: self._disk_storage = self.create_disk_storage() @@ -40,7 +56,7 @@ def __init__(self, variable, population): if self.variable.name in self.simulation.memory_config.variables_to_drop: self._do_not_store = True - def clone(self, population): + def clone(self, population:'Population')->'Holder': """ Copy the holder just enough to be able to run a new simulation without modifying the original simulation. """ @@ -56,7 +72,7 @@ def clone(self, population): return new - def create_disk_storage(self, directory = None, preserve = False): + def create_disk_storage(self, directory:Optional[str] = None, preserve:bool = False)->OnDiskStorage: if directory is None: directory = self.simulation.data_storage_dir storage_dir = os.path.join(directory, self.variable.name) @@ -68,7 +84,7 @@ def create_disk_storage(self, directory = None, preserve = False): preserve_storage_dir = preserve ) - def delete_arrays(self, period = None): + def delete_arrays(self, period:Optional[Union[str, Period]] = None)->None: """ If ``period`` is ``None``, remove all known values of the variable. @@ -94,7 +110,7 @@ def get_array(self, period): if self._disk_storage: return self._disk_storage.get(period) - def get_memory_usage(self): + def get_memory_usage(self)->Dict[str, Union[int, np.dtype]]: """ Get data about the virtual memory usage of the holder. @@ -131,7 +147,7 @@ def get_memory_usage(self): return usage - def get_known_periods(self): + def get_known_periods(self)->List[Period]: """ Get the list of periods the variable value is known for. """ @@ -139,7 +155,7 @@ def get_known_periods(self): return list(self._memory_storage.get_known_periods()) + list(( self._disk_storage.get_known_periods() if self._disk_storage else [])) - def set_input(self, period, array): + def set_input(self, period:Union[str, Period], array: ArrayLike): """ Set a variable's value (``array``) for a given period (``period``) @@ -183,7 +199,7 @@ def set_input(self, period, array): return self.variable.set_input(self, period, array) return self._set(period, array) - def _to_array(self, value): + def _to_array(self, value:ArrayLike)->np.array: if not isinstance(value, np.ndarray): value = np.asarray(value) if value.ndim == 0: @@ -204,7 +220,7 @@ def _to_array(self, value): .format(value, self.variable.name, self.variable.dtype, value.dtype)) return value - def _set(self, period, value): + def _set(self, period:Union[Period, None], value:ArrayLike)->None: value = self._to_array(value) if self.variable.definition_period != ETERNITY: if period is None: @@ -236,7 +252,7 @@ def _set(self, period, value): else: self._memory_storage.put(value, period) - def put_in_cache(self, value, period): + def put_in_cache(self, value:ArrayLike, period:Period)->None: if self._do_not_store: return @@ -255,7 +271,7 @@ def default_array(self): return self.variable.default_array(self.population.count) -def set_input_dispatch_by_period(holder, period, array): +def set_input_dispatch_by_period(holder:'Holder', period:Period, array): """ This function can be declared as a ``set_input`` attribute of a variable. @@ -290,7 +306,7 @@ def set_input_dispatch_by_period(holder, period, array): sub_period = sub_period.offset(1) -def set_input_divide_by_period(holder, period, array): +def set_input_divide_by_period(holder:Holder, period:Period, array)->None: """ This function can be declared as a ``set_input`` attribute of a variable. From eee3ac9c7cc1202a7b870a80bd2372e4ce46655a Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:18:16 +0100 Subject: [PATCH 4/7] Add typings to the populations module --- openfisca_core/populations.py | 41 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/openfisca_core/populations.py b/openfisca_core/populations.py index f6084f6726..f74d6f4f48 100644 --- a/openfisca_core/populations.py +++ b/openfisca_core/populations.py @@ -2,37 +2,38 @@ import traceback -from typing import Iterable +from typing import Iterable, Optional, Dict, Callable import numpy as np -from openfisca_core.entities import Role +from openfisca_core.entities import Role, Entity from openfisca_core.indexed_enums import EnumArray from openfisca_core.holders import Holder +from openfisca_core.simulations import Simulation ADD = 'add' DIVIDE = 'divide' -def projectable(function): +def projectable(function:Callable)->Callable: """ Decorator to indicate that when called on a projector, the outcome of the function must be projected. For instance person.household.sum(...) must be projected on person, while it would not make sense for person.household.get_holder. """ - function.projectable = True + function.projectable:bool = True return function class Population(object): - def __init__(self, entity): - self.simulation = None - self.entity = entity - self._holders = {} - self.count = 0 + def __init__(self, entity:Entity)->None: + self.simulation:Optional[Simulation] = None + self.entity:Entity = entity + self._holders:Dict[str, Holder] = {} + self.count:int = 0 self.ids = [] - def clone(self, simulation): + def clone(self, simulation:Simulation): result = Population(self.entity) result.simulation = simulation result._holders = {variable: holder.clone(result) for (variable, holder) in self._holders.items()} @@ -40,10 +41,10 @@ def clone(self, simulation): result.ids = self.ids return result - def empty_array(self): + def empty_array(self)->np.array: return np.zeros(self.count) - def filled_array(self, value, dtype = None): + def filled_array(self, value, dtype = None)->np.array: return np.full(self.count, value, dtype) def __getattr__(self, attribute): @@ -57,12 +58,12 @@ def get_index(self, id): # Calculations - def check_array_compatible_with_entity(self, array): + def check_array_compatible_with_entity(self, array:np.array): if not self.count == array.size: raise ValueError("Input {} is not a valid value for the entity {} (size = {} != {} = count)".format( array, self.key, array.size, self.count)) - def check_period_validity(self, variable_name, period): + def check_period_validity(self, variable_name:str, period): if period is None: stack = traceback.extract_stack() filename, line_number, function_name, line_of_code = stack[-3] @@ -74,7 +75,7 @@ def check_period_validity(self, variable_name, period): See more information at . '''.format(variable_name, filename, line_number, line_of_code)) - def __call__(self, variable_name, period = None, options = None): + def __call__(self, variable_name:str, period = None, options:Optional[List[str]] = None)->np.array: """ Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. @@ -102,7 +103,7 @@ def __call__(self, variable_name, period = None, options = None): # Helpers - def get_holder(self, variable_name): + def get_holder(self, variable_name:str)->Holder: self.entity.check_variable_defined_for_entity(variable_name) holder = self._holders.get(variable_name) if holder: @@ -111,7 +112,7 @@ def get_holder(self, variable_name): self._holders[variable_name] = holder = Holder(variable, self) return holder - def get_memory_usage(self, variables = None): + def get_memory_usage(self, variables:Optional[Dict[str, Variable]] = None): holders_memory_usage = { variable_name: holder.get_memory_usage() for variable_name, holder in self._holders.items() @@ -128,7 +129,7 @@ def get_memory_usage(self, variables = None): ) @projectable - def has_role(self, role): + def has_role(self, role:Role)->np.array: """ Check if a person has a given role within its :any:`GroupEntity` @@ -145,7 +146,7 @@ def has_role(self, role): return group_population.members_role == role @projectable - def value_from_partner(self, array, entity, role): + def value_from_partner(self, array:np.array, entity:Entity, role:Role)->np.array: self.check_array_compatible_with_entity(array) self.entity.check_role_validity(role) @@ -162,7 +163,7 @@ def value_from_partner(self, array, entity, role): ) @projectable - def get_rank(self, entity, criteria, condition = True): + def get_rank(self, entity:Entity, criteria:np.array, condition:bool = True)->np.array: """ Get the rank of a person within an entity according to a criteria. The person with rank 0 has the minimum value of criteria. From 0d90d06917a7e4338eb4ef4af192557db6f61725 Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:19:43 +0100 Subject: [PATCH 5/7] Add typings to the taxbenefitsystems module --- openfisca_core/taxbenefitsystems.py | 41 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/openfisca_core/taxbenefitsystems.py b/openfisca_core/taxbenefitsystems.py index 08ca122f91..d76dc7744f 100644 --- a/openfisca_core/taxbenefitsystems.py +++ b/openfisca_core/taxbenefitsystems.py @@ -12,6 +12,7 @@ import pkg_resources import traceback import copy +from typing import Type from openfisca_core.periods import Period, Instant, instant as make_instant from openfisca_core.entities import Entity @@ -26,6 +27,7 @@ log = logging.getLogger(__name__) + class VariableNameConflict(Exception): """ Exception raised when two variables with the same name are added to a tax and benefit system. @@ -59,7 +61,7 @@ def __init__(self, entities): # TODO: Currently: Don't use a weakref, because they are cleared by Paste (at least) at each call. self.parameters = None self._parameters_at_instant_cache = {} # weakref.WeakValueDictionary() - self.variables = {} + self.variables:Dict[str, Variable] = {} self.open_api_config = {} # Tax benefit systems are mutable, so entities (which need to know about our variables) can't be shared among them if entities is None or len(entities) == 0: @@ -137,7 +139,7 @@ def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = Fa def prefill_cache(self): pass - def load_variable(self, variable_class, update = False): + def load_variable(self, variable_class:Type[Variable], update:bool = False)->Variable: name = variable_class.__name__ # Check if a Variable with the same name is already registered. @@ -151,7 +153,7 @@ def load_variable(self, variable_class, update = False): return variable - def add_variable(self, variable): + def add_variable(self, variable:Type[Variable])->Variable: """ Adds an OpenFisca variable to the tax and benefit system. @@ -161,7 +163,7 @@ def add_variable(self, variable): """ return self.load_variable(variable, update = False) - def replace_variable(self, variable): + def replace_variable(self, variable:Type[Variable])->None: """ Replaces an existing OpenFisca variable in the tax and benefit system by a new one. @@ -176,7 +178,7 @@ def replace_variable(self, variable): del self.variables[name] self.load_variable(variable, update = False) - def update_variable(self, variable): + def update_variable(self, variable:Type[Variable])->Variable: """ Updates an existing OpenFisca variable in the tax and benefit system. @@ -190,7 +192,7 @@ def update_variable(self, variable): """ return self.load_variable(variable, update = True) - def add_variables_from_file(self, file_path): + def add_variables_from_file(self, file_path:str)->None: """ Adds all OpenFisca variables contained in a given file to the tax and benefit system. """ @@ -217,7 +219,7 @@ def add_variables_from_file(self, file_path): log.error('Unable to load OpenFisca variables from file "{}"'.format(file_path)) raise - def add_variables_from_directory(self, directory): + def add_variables_from_directory(self, directory:str)->None: """ Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system. """ @@ -228,7 +230,7 @@ def add_variables_from_directory(self, directory): for subdirectory in subdirectories: self.add_variables_from_directory(subdirectory) - def add_variables(self, *variables): + def add_variables(self, *variables:Type[Variable])->None: """ Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. @@ -237,7 +239,7 @@ def add_variables(self, *variables): for variable in variables: self.add_variable(variable) - def load_extension(self, extension): + def load_extension(self, extension)->None: """ Loads an extension to the tax and benefit system. @@ -261,7 +263,7 @@ def load_extension(self, extension): extension_parameters = ParameterNode(directory_path = param_dir) self.parameters.merge(extension_parameters) - def apply_reform(self, reform_path): + def apply_reform(self, reform_path:str)->Reform: """ Generates a new tax and benefit system applying a reform to the tax and benefit system. @@ -296,7 +298,7 @@ def apply_reform(self, reform_path): return reform(self) - def get_variable(self, variable_name, check_existence = False): + def get_variable(self, variable_name:str, check_existence:bool = False)->Variable: """ Get a variable from the tax and benefit system. @@ -309,7 +311,7 @@ def get_variable(self, variable_name, check_existence = False): raise VariableNotFound(variable_name, self) return found - def neutralize_variable(self, variable_name): + def neutralize_variable(self, variable_name:str)->None: """ Neutralizes an OpenFisca variable existing in the tax and benefit system. @@ -319,10 +321,10 @@ def neutralize_variable(self, variable_name): """ self.variables[variable_name] = get_neutralized_variable(self.get_variable(variable_name)) - def annualize_variable(self, variable_name: str, period: Optional[Period] = None): + def annualize_variable(self, variable_name: str, period: Optional[Period] = None)->None: self.variables[variable_name] = get_annualized_variable(self.get_variable(variable_name, period)) - def load_parameters(self, path_to_yaml_dir): + def load_parameters(self, path_to_yaml_dir:str)->ParameterNode: """ Loads the legislation parameter for a directory containing YAML parameters files. @@ -346,11 +348,10 @@ def _get_baseline_parameters_at_instant(self, instant): return self.get_parameters_at_instant(instant) return baseline._get_baseline_parameters_at_instant(instant) - def get_parameters_at_instant(self, instant): + def get_parameters_at_instant(self, instant:Union[str, Instant]): """ Get the parameters of the legislation at a given instant - :param instant: string of the format 'YYYY-MM-DD' or `openfisca_core.periods.Instant` instance. :returns: The parameters of the legislation at a given instant. :rtype: :any:`ParameterNodeAtInstant` """ @@ -367,7 +368,7 @@ def get_parameters_at_instant(self, instant): self._parameters_at_instant_cache[instant] = parameters_at_instant return parameters_at_instant - def get_package_metadata(self): + def get_package_metadata(self)->Dict[str, str]: """ Gets metatada relative to the country package the tax and benefit system is built from. @@ -418,7 +419,7 @@ def get_package_metadata(self): 'location': location, } - def get_variables(self, entity = None): + def get_variables(self, entity = None)->Dict[str, Variable]: """ Gets all variables contained in a tax and benefit system. @@ -454,8 +455,8 @@ def clone(self): new_dict['open_api_config'] = self.open_api_config.copy() return new - def entities_plural(self): + def entities_plural(self)->Dict[str, Entity]: return {entity.plural for entity in self.entities} - def entities_by_singular(self): + def entities_by_singular(self)->Dict[str, Entity]: return {entity.key: entity for entity in self.entities} From a02fad9525cf6c4b0bd0c5d0a1f017135a33491f Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:20:47 +0100 Subject: [PATCH 6/7] Add typings to the simulations module --- openfisca_core/simulations.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/openfisca_core/simulations.py b/openfisca_core/simulations.py index c27dd677d1..17e732fe7f 100644 --- a/openfisca_core/simulations.py +++ b/openfisca_core/simulations.py @@ -2,6 +2,7 @@ import tempfile import logging +from typing import Set, Tuple, Union import numpy as np @@ -9,6 +10,7 @@ from openfisca_core.commons import empty_clone from openfisca_core.tracers import TracingParameterNodeAtInstant, SimpleTracer, FullTracer from openfisca_core.indexed_enums import Enum, EnumArray +from .taxbenefitstysms import TaxBenefitSystem log = logging.getLogger(__name__) @@ -36,7 +38,7 @@ class Simulation(object): def __init__( self, - tax_benefit_system, + tax_benefit_system:TaxBenefitSystem, populations ): """ @@ -44,7 +46,7 @@ def __init__( which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ - self.tax_benefit_system = tax_benefit_system + self.tax_benefit_system:TaxBenefitSystem = tax_benefit_system assert tax_benefit_system is not None self.populations = populations @@ -52,11 +54,11 @@ def __init__( self.link_to_entities_instances() self.create_shortcuts() - self.invalidated_caches = set() + self.invalidated_caches:Set[Tuple[str, str]] = set() - self.debug = False - self.trace = False - self.tracer = SimpleTracer() + self.debug:bool = False + self.trace:bool = False + self.tracer:Union[SimpleTracer, FullTracer] = SimpleTracer() self.opt_out_cache = False # controls the spirals detection; check for performance impact if > 1 @@ -65,22 +67,22 @@ def __init__( self._data_storage_dir = None @property - def trace(self): + def trace(self)->bool: return self._trace @trace.setter - def trace(self, trace): + def trace(self, trace:bool)->None: self._trace = trace if trace: self.tracer = FullTracer() else: self.tracer = SimpleTracer() - def link_to_entities_instances(self): + def link_to_entities_instances(self)->None: for _key, entity_instance in self.populations.items(): entity_instance.simulation = self - def create_shortcuts(self): + def create_shortcuts(self)->None: for _key, population in self.populations.items(): # create shortcut simulation.person and simulation.household (for instance) setattr(self, population.entity.key, population) @@ -115,7 +117,7 @@ def calculate(self, variable_name, period): self.tracer.record_calculation_end() self.purge_cache_of_invalid_values() - def _calculate(self, variable_name, period: periods.Period): + def _calculate(self, variable_name:str, period: periods.Period)->np.array: """ Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. @@ -158,7 +160,7 @@ def purge_cache_of_invalid_values(self): for (_name, _period) in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) - self.invalidated_caches = set() + self.invalidated_caches:Set[Tuple[str, str]] = set() def calculate_add(self, variable_name, period): variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) @@ -306,7 +308,7 @@ def _check_for_cycle(self, variable: str, period): message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable, period, self.tracer.stack) raise SpiralError(message, variable) - def invalidate_cache_entry(self, variable: str, period): + def invalidate_cache_entry(self, variable: str, period:str): self.invalidated_caches.add((variable, period)) def invalidate_spiral_variables(self, variable: str): From 714bea97d21a5dd97fa7381474a6f8087086564a Mon Sep 17 00:00:00 2001 From: cescofran Date: Tue, 16 Feb 2021 13:21:19 +0100 Subject: [PATCH 7/7] Add typings to the simulation_builder module --- openfisca_core/simulation_builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/openfisca_core/simulation_builder.py b/openfisca_core/simulation_builder.py index efc341bf5d..a5a5527969 100644 --- a/openfisca_core/simulation_builder.py +++ b/openfisca_core/simulation_builder.py @@ -12,6 +12,8 @@ from openfisca_core.errors import VariableNotFound, SituationParsingError, PeriodMismatchError from openfisca_core.periods import period, key_period_size from openfisca_core.simulations import Simulation +from .taxbenefitsystems import TaxBenefitSystem +from .simulations import Simulation class SimulationBuilder(object): @@ -40,7 +42,7 @@ def __init__(self): self.axes_memberships: Dict[Entity.plural, List[int]] = {} self.axes_roles: Dict[Entity.plural, List[int]] = {} - def build_from_dict(self, tax_benefit_system, input_dict): + def build_from_dict(self, tax_benefit_system:TaxBenefitSystem, input_dict)->Simulation: """ Build a simulation from ``input_dict`` @@ -56,7 +58,7 @@ def build_from_dict(self, tax_benefit_system, input_dict): else: return self.build_from_variables(tax_benefit_system, input_dict) - def build_from_entities(self, tax_benefit_system, input_dict): + def build_from_entities(self, tax_benefit_system:TaxBenefitSystem, input_dict)->Simulation: """ Build a simulation from a Python dict ``input_dict`` fully specifying entities.