diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ed781161..99c4419ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +### 41.4.1 [#1202](https://github.com/openfisca/openfisca-core/pull/1202) + +#### Technical changes + +- Check that entities are fully specified when expanding over axes. + ## 41.4.0 [#1197](https://github.com/openfisca/openfisca-core/pull/1197) #### New features diff --git a/openfisca_core/errors/situation_parsing_error.py b/openfisca_core/errors/situation_parsing_error.py index 7b68430db..ff3839d5f 100644 --- a/openfisca_core/errors/situation_parsing_error.py +++ b/openfisca_core/errors/situation_parsing_error.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from collections.abc import Iterable + import os import dpath.util @@ -8,7 +12,9 @@ class SituationParsingError(Exception): Exception raised when the situation provided as an input for a simulation cannot be parsed """ - def __init__(self, path, message, code=None): + def __init__( + self, path: Iterable[str], message: str, code: int | None = None + ) -> None: self.error = {} dpath_path = "/".join([str(item) for item in path]) message = str(message).strip(os.linesep).replace(os.linesep, " ") @@ -16,5 +22,5 @@ def __init__(self, path, message, code=None): self.code = code Exception.__init__(self, str(self.error)) - def __str__(self): + def __str__(self) -> str: return str(self.error) diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py new file mode 100644 index 000000000..f99c1d210 --- /dev/null +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -0,0 +1,162 @@ +"""This module contains the _BuildDefaultSimulation class.""" + +from typing import Union +from typing_extensions import Self + +import numpy + +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem + + +class _BuildDefaultSimulation: + """Build a default simulation. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + count(int): The number of periods. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 1 + >>> builder = ( + ... _BuildDefaultSimulation(tax_benefit_system, count) + ... .add_count() + ... .add_ids() + ... .add_members_entity_id() + ... ) + + >>> builder.count + 1 + + >>> sorted(builder.populations.keys()) + ['dog', 'pack'] + + >>> sorted(builder.simulation.populations.keys()) + ['dog', 'pack'] + + """ + + #: The number of Population. + count: int + + #: The built populations. + populations: dict[str, Union[Population[Entity]]] + + #: The built simulation. + simulation: Simulation + + def __init__(self, tax_benefit_system: TaxBenefitSystem, count: int) -> None: + self.count = count + self.populations = tax_benefit_system.instantiate_entities() + self.simulation = Simulation(tax_benefit_system, self.populations) + + def add_count(self) -> Self: + """Add the number of Population to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_count() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].count + 2 + + >>> builder.populations["pack"].count + 2 + + """ + + for population in self.populations.values(): + population.count = self.count + + return self + + def add_ids(self) -> Self: + """Add the populations ids to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_ids() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].ids + array([0, 1]) + + >>> builder.populations["pack"].ids + array([0, 1]) + + """ + + for population in self.populations.values(): + population.ids = numpy.array(range(self.count)) + + return self + + def add_members_entity_id(self) -> Self: + """Add ??? + + Each SingleEntity has its own GroupEntity. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_members_entity_id() + <..._BuildDefaultSimulation object at ...> + + >>> population = builder.populations["pack"] + + >>> hasattr(population, "members_entity_id") + True + + >>> population.members_entity_id + array([0, 1]) + + """ + + for population in self.populations.values(): + if hasattr(population, "members_entity_id"): + population.members_entity_id = numpy.array(range(self.count)) + + return self diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py new file mode 100644 index 000000000..60ff6148e --- /dev/null +++ b/openfisca_core/simulations/_build_from_variables.py @@ -0,0 +1,232 @@ +"""This module contains the _BuildFromVariables class.""" + +from __future__ import annotations + +from typing_extensions import Self + +from openfisca_core import errors + +from ._build_default_simulation import _BuildDefaultSimulation +from ._type_guards import is_variable_dated +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem, Variables + + +class _BuildFromVariables: + """Build a simulation from variables. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + params(Variables): The simulation parameters. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = ( + ... _BuildFromVariables(tax_benefit_system, variables, period) + ... .add_dated_values() + ... .add_undated_values() + ... ) + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + + #: The number of Population. + count: int + + #: The Simulation's default period. + default_period: str | None + + #: The built populations. + populations: dict[str, Population[Entity]] + + #: The built simulation. + simulation: Simulation + + #: The simulation parameters. + variables: Variables + + def __init__( + self, + tax_benefit_system: TaxBenefitSystem, + params: Variables, + default_period: str | None = None, + ) -> None: + self.count = _person_count(params) + + default_builder = ( + _BuildDefaultSimulation(tax_benefit_system, self.count) + .add_count() + .add_ids() + .add_members_entity_id() + ) + + self.variables = params + self.simulation = default_builder.simulation + self.populations = default_builder.populations + self.default_period = default_period + + def add_dated_values(self) -> Self: + """Add the dated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_dated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + + """ + + for variable, value in self.variables.items(): + if is_variable_dated(dated_variable := value): + for period, dated_value in dated_variable.items(): + self.simulation.set_input(variable, period, dated_value) + + return self + + def add_undated_values(self) -> Self: + """Add the undated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Raises: + SituationParsingError: If there is not a default period set. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_undated_values() + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + >>> builder.default_period = period + >>> builder.add_undated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + + for variable, value in self.variables.items(): + if not is_variable_dated(undated_value := value): + if (period := self.default_period) is None: + message = ( + "Can't deal with type: expected object. Input " + "variables should be set for specific periods. For " + "instance: " + " {'salary': {'2017-01': 2000, '2017-02': 2500}}" + " {'birth_date': {'ETERNITY': '1980-01-01'}}" + ) + + raise errors.SituationParsingError([variable], message) + + self.simulation.set_input(variable, period, undated_value) + + return self + + +def _person_count(params: Variables) -> int: + try: + first_value = next(iter(params.values())) + + if isinstance(first_value, dict): + first_value = next(iter(first_value.values())) + + if isinstance(first_value, str): + return 1 + + return len(first_value) + + except Exception: + return 1 diff --git a/openfisca_core/simulations/_type_guards.py b/openfisca_core/simulations/_type_guards.py new file mode 100644 index 000000000..c34361041 --- /dev/null +++ b/openfisca_core/simulations/_type_guards.py @@ -0,0 +1,304 @@ +"""Type guards to help type narrowing simulation parameters.""" + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import TypeGuard + +from .typing import ( + Axes, + DatedVariable, + FullySpecifiedEntities, + ImplicitGroupEntities, + Params, + UndatedVariable, + Variables, +) + + +def are_entities_fully_specified( + params: Params, items: Iterable[str] +) -> TypeGuard[FullySpecifiedEntities]: + """Check if the params contain fully specified entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the params contain fully specified entities. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [{"count": 2, "max": 3000, "min": 0, "name": "rent", "period": "2018-11"}] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {} + + >>> are_entities_fully_specified(params, entities) + False + + """ + + if not params: + return False + + return all(key in items for key in params.keys() if key != "axes") + + +def are_entities_short_form( + params: Params, items: Iterable[str] +) -> TypeGuard[ImplicitGroupEntities]: + """Check if the params contain short form entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in singular form. + + Returns: + bool: True if the params contain short form entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = { + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": "Javier"}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {} + + >>> are_entities_short_form(params, entities) + False + + """ + + return not not set(params).intersection(items) + + +def are_entities_specified( + params: Params, items: Iterable[str] +) -> TypeGuard[Variables]: + """Check if the params contains entities at all. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of variables. + + Returns: + bool: True if the params does not contain variables at the root level. + + Examples: + >>> variables = {"salary"} + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}} + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"salary": {"2016-10": [12000, 13000]}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": [12000, 13000]} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, variables) + False + + >>> params = {} + + >>> are_entities_specified(params, variables) + False + + """ + + if not params: + return False + + return not any(key in items for key in params.keys()) + + +def has_axes(params: Params) -> TypeGuard[Axes]: + """Check if the params contains axes. + + Args: + params(Params): Simulation parameters. + + Returns: + bool: True if the params contain axes. + + Examples: + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> has_axes(params) + True + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}} + ... } + + >>> has_axes(params) + False + + """ + + return params.get("axes", None) is not None + + +def is_variable_dated( + variable: DatedVariable | UndatedVariable, +) -> TypeGuard[DatedVariable]: + """Check if the variable is dated. + + Args: + variable(DatedVariable | UndatedVariable): A variable. + + Returns: + bool: True if the variable is dated. + + Examples: + >>> variable = {"2018-11": [2000, 3000]} + + >>> is_variable_dated(variable) + True + + >>> variable = {"2018-11": 2000} + + >>> is_variable_dated(variable) + True + + >>> variable = 2000 + + >>> is_variable_dated(variable) + False + + """ + + return isinstance(variable, dict) diff --git a/openfisca_core/simulations/helpers.py b/openfisca_core/simulations/helpers.py index b559f7d07..d5984d88b 100644 --- a/openfisca_core/simulations/helpers.py +++ b/openfisca_core/simulations/helpers.py @@ -1,4 +1,8 @@ -from openfisca_core.errors import SituationParsingError +from collections.abc import Iterable + +from openfisca_core import errors + +from .typing import ParamsWithoutAxes def calculate_output_add(simulation, variable_name: str, period): @@ -20,28 +24,93 @@ def check_type(input, input_type, path=None): path = [] if not isinstance(input, input_type): - raise SituationParsingError( + raise errors.SituationParsingError( path, "Invalid type: must be of type '{}'.".format(json_type_map[input_type]), ) +def check_unexpected_entities( + params: ParamsWithoutAxes, entities: Iterable[str] +) -> None: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Raises: + SituationParsingError: If there are entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}} + ... } + + >>> check_unexpected_entities(params, entities) + + >>> params = { + ... "dogs": {"Bart": {"damages": {"2018-11": 2000}}} + ... } + + >>> check_unexpected_entities(params, entities) + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + + """ + + if has_unexpected_entities(params, entities): + unexpected_entities = [entity for entity in params if entity not in entities] + + message = ( + "Some entities in the situation are not defined in the loaded tax " + "and benefit system. " + f"These entities are not found: {', '.join(unexpected_entities)}. " + f"The defined entities are: {', '.join(entities)}." + ) + + raise errors.SituationParsingError([unexpected_entities[0]], message) + + +def has_unexpected_entities(params: ParamsWithoutAxes, entities: Iterable[str]) -> bool: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the input contains entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}} + ... } + + >>> has_unexpected_entities(params, entities) + False + + >>> params = { + ... "dogs": {"Bart": {"damages": {"2018-11": 2000}}} + ... } + + >>> has_unexpected_entities(params, entities) + True + + """ + + return any(entity for entity in params if entity not in entities) + + def transform_to_strict_syntax(data): if isinstance(data, (str, int)): data = [data] if isinstance(data, list): return [str(item) if isinstance(item, int) else item for item in data] return data - - -def _get_person_count(input_dict): - try: - first_value = next(iter(input_dict.values())) - if isinstance(first_value, dict): - first_value = next(iter(first_value.values())) - if isinstance(first_value, str): - return 1 - - return len(first_value) - except Exception: - return 1 diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 304d7338a..9e7e0034e 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,7 +1,7 @@ from __future__ import annotations from openfisca_core.types import Population, TaxBenefitSystem, Variable -from typing import Dict, NamedTuple, Optional, Set +from typing import Dict, Mapping, NamedTuple, Optional, Set import tempfile import warnings @@ -24,7 +24,7 @@ class Simulation: def __init__( self, tax_benefit_system: TaxBenefitSystem, - populations: Dict[str, Population], + populations: Mapping[str, Population], ): """ This constructor is reserved for internal use; see :any:`SimulationBuilder`, diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index d48ac0152..c42d0e4f2 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,4 +1,8 @@ -from typing import Dict, Iterable, List +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import Dict, List import copy @@ -8,7 +12,30 @@ from openfisca_core import entities, errors, periods, populations, variables from . import helpers +from ._build_default_simulation import _BuildDefaultSimulation +from ._build_from_variables import _BuildFromVariables +from ._type_guards import ( + are_entities_fully_specified, + are_entities_short_form, + are_entities_specified, + has_axes, +) from .simulation import Simulation +from .typing import ( + Axis, + Entity, + FullySpecifiedEntities, + GroupEntities, + GroupEntity, + ImplicitGroupEntities, + Params, + ParamsWithoutAxes, + Population, + Role, + SingleEntity, + TaxBenefitSystem, + Variables, +) class SimulationBuilder: @@ -42,93 +69,177 @@ def __init__(self): self.axes_memberships: Dict[entities.Entity.plural, List[int]] = {} self.axes_roles: Dict[entities.Entity.plural, List[int]] = {} - def build_from_dict(self, tax_benefit_system, input_dict): - """ - Build a simulation from ``input_dict`` + def build_from_dict( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Params, + ) -> Simulation: + """Build a simulation from an input dictionary. - This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not. + This method uses :meth:`.SimulationBuilder.build_from_entities` if + entities are fully specified, or + :meth:`.SimulationBuilder.build_from_variables` if they are not. - :param dict input_dict: A dict represeting the input of the simulation - :return: A :any:`Simulation` - """ + Args: + tax_benefit_system(TaxBenefitSystem): The system to use. + input_dict(Params): The input of the simulation. - input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict) - if any( - key in tax_benefit_system.entities_plural() for key in input_dict.keys() - ): - return self.build_from_entities(tax_benefit_system, input_dict) - else: - return self.build_from_variables(tax_benefit_system, input_dict) + Returns: + Simulation: The built simulation. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [{"count": 2, "max": 3000, "min": 0, "name": "rent", "period": "2018-11"}] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": [12000, 13000]} + + >>> not are_entities_specified(params, {"salary"}) + True - def build_from_entities(self, tax_benefit_system, input_dict): """ - Build a simulation from a Python dict ``input_dict`` fully specifying entities. + + #: The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() + + #: The singular names of the entities in the tax and benefits system. + singular: Iterable[str] = tax_benefit_system.entities_by_singular() + + #: The names of the variables in the tax and benefits system. + variables: Iterable[str] = tax_benefit_system.variables.keys() + + if are_entities_short_form(input_dict, singular): + params = self.explicit_singular_entities(tax_benefit_system, input_dict) + return self.build_from_entities(tax_benefit_system, params) + + if are_entities_fully_specified(params := input_dict, plural): + return self.build_from_entities(tax_benefit_system, params) + + if not are_entities_specified(params := input_dict, variables): + return self.build_from_variables(tax_benefit_system, params) + + def build_from_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: FullySpecifiedEntities, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` fully specifying + entities. Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": { "salary": { "2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]] + ... } + + >>> are_entities_short_form(params, entities) + True - >>> simulation_builder.build_from_entities({ - 'persons': {'Javier': { 'salary': {'2018-11': 2000}}}, - 'households': {'household': {'parents': ['Javier']}} - }) """ + + # Create the populations + populations = tax_benefit_system.instantiate_entities() + + # Create the simulation + simulation = Simulation(tax_benefit_system, populations) + + # Why? input_dict = copy.deepcopy(input_dict) - simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() - ) + # The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() # Register variables so get_variable_entity can find them - for variable_name, _variable in tax_benefit_system.variables.items(): - self.register_variable( - variable_name, simulation.get_variable_population(variable_name).entity - ) + self.register_variables(simulation) + + # Declare axes + axes: list[list[Axis]] | None = None + # ? helpers.check_type(input_dict, dict, ["error"]) - axes = input_dict.pop("axes", None) - unexpected_entities = [ - entity - for entity in input_dict - if entity not in tax_benefit_system.entities_plural() - ] - if unexpected_entities: - unexpected_entity = unexpected_entities[0] - raise errors.SituationParsingError( - [unexpected_entity], - "".join( - [ - "Some entities in the situation are not defined in the loaded tax and benefit system.", - "These entities are not found: {0}.", - "The defined entities are: {1}.", - ] - ).format( - ", ".join(unexpected_entities), - ", ".join(tax_benefit_system.entities_plural()), - ), - ) - persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None) + # Remove axes from input_dict + params: ParamsWithoutAxes = { + key: value for key, value in input_dict.items() if key != "axes" + } + + # Save axes for later + if has_axes(axes_params := input_dict): + axes = copy.deepcopy(axes_params.get("axes", None)) + + # Check for unexpected entities + helpers.check_unexpected_entities(params, plural) + + person_entity: SingleEntity = tax_benefit_system.person_entity + + persons_json = params.get(person_entity.plural, None) if not persons_json: raise errors.SituationParsingError( - [tax_benefit_system.person_entity.plural], + [person_entity.plural], "No {0} found. At least one {0} must be defined to run a simulation.".format( - tax_benefit_system.person_entity.key + person_entity.key ), ) persons_ids = self.add_person_entity(simulation.persons.entity, persons_json) for entity_class in tax_benefit_system.group_entities: - instances_json = input_dict.get(entity_class.plural) + instances_json = params.get(entity_class.plural) + if instances_json is not None: self.add_group_entity( self.persons_plural, persons_ids, entity_class, instances_json ) + + elif axes is not None: + message = ( + f"We could not find any specified {entity_class.plural}. " + "In order to expand over axes, all group entities and roles " + "must be fully specified. For further support, please do " + "not hesitate to take a look at the official documentation: " + "https://openfisca.org/doc/simulate/replicate-simulation-inputs.html." + ) + + raise errors.SituationParsingError([entity_class.plural], message) + else: self.add_default_group_entity(persons_ids, entity_class) - if axes: - self.axes = axes + if axes is not None: + for axis in axes[0]: + self.add_parallel_axis(axis) + + if len(axes) >= 1: + for axis in axes[1:]: + self.add_perpendicular_axis(axis[0]) + self.expand_axes() try: @@ -145,52 +256,65 @@ def build_from_entities(self, tax_benefit_system, input_dict): return simulation - def build_from_variables(self, tax_benefit_system, input_dict): - """ - Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities. + def build_from_variables( + self, tax_benefit_system: TaxBenefitSystem, input_dict: Variables + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` describing + variables values without expliciting entities. - This method uses :any:`build_default_simulation` to infer an entity structure + This method uses :meth:`.SimulationBuilder.build_default_simulation` to + infer an entity structure. - Example: + Args: + tax_benefit_system(TaxBenefitSystem): The system to use. + input_dict(Variables): The input of the simulation. - >>> simulation_builder.build_from_variables( - {'salary': {'2016-10': 12000}} - ) - """ - count = helpers._get_person_count(input_dict) - simulation = self.build_default_simulation(tax_benefit_system, count) - for variable, value in input_dict.items(): - if not isinstance(value, dict): - if self.default_period is None: - raise errors.SituationParsingError( - [variable], - "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", - ) - simulation.set_input(variable, self.default_period, value) - else: - for period_str, dated_value in value.items(): - simulation.set_input(variable, period_str, dated_value) - return simulation + Returns: + Simulation: The built simulation. + + Raises: + SituationParsingError: If the input is not valid. + + Examples: + >>> params = {'salary': {'2016-10': 12000}} + + >>> are_entities_specified(params, {"salary"}) + False + + >>> params = {'salary': 12000} + + >>> are_entities_specified(params, {"salary"}) + False - def build_default_simulation(self, tax_benefit_system, count=1): """ - Build a simulation where: + + return ( + _BuildFromVariables(tax_benefit_system, input_dict, self.default_period) + .add_dated_values() + .add_undated_values() + .simulation + ) + + @staticmethod + def build_default_simulation( + tax_benefit_system: TaxBenefitSystem, count: int = 1 + ) -> Simulation: + """Build a default simulation. + + Where: - There are ``count`` persons - - There are ``count`` instances of each group entity, containing one person + - There are ``count`` of each group entity, containing one person - Every person has, in each entity, the first role + """ - simulation = Simulation( - tax_benefit_system, tax_benefit_system.instantiate_entities() + return ( + _BuildDefaultSimulation(tax_benefit_system, count) + .add_count() + .add_ids() + .add_members_entity_id() + .simulation ) - for population in simulation.populations.values(): - population.count = count - population.ids = numpy.array(range(count)) - if not population.entity.is_person: - population.members_entity_id = ( - population.ids - ) # Each person is its own group entity - return simulation def create_entities(self, tax_benefit_system): self.populations = tax_benefit_system.instantiate_entities() @@ -238,23 +362,35 @@ def join_with_persons( def build(self, tax_benefit_system): return Simulation(tax_benefit_system, self.populations) - def explicit_singular_entities(self, tax_benefit_system, input_dict): - """ - Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut + def explicit_singular_entities( + self, tax_benefit_system: TaxBenefitSystem, input_dict: ImplicitGroupEntities + ) -> GroupEntities: + """Preprocess ``input_dict`` to explicit entities defined using the + single-entity shortcut - Example: + Examples: + + >>> params = {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} + + >>> are_entities_fully_specified(params, {"persons", "households"}) + False + + >>> are_entities_short_form(params, {"person", "household"}) + True + + >>> params = {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}}} + + >>> are_entities_fully_specified(params, {"persons", "households"}) + True + + >>> are_entities_short_form(params, {"person", "household"}) + False - >>> simulation_builder.explicit_singular_entities( - {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} - ) - >>> {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}} """ singular_keys = set(input_dict).intersection( tax_benefit_system.entities_by_singular() ) - if not singular_keys: - return input_dict result = { entity_id: entity_description @@ -284,18 +420,25 @@ def add_person_entity(self, entity, instances_json): return self.get_ids(entity.plural) - def add_default_group_entity(self, persons_ids, entity): + def add_default_group_entity( + self, persons_ids: list[str], entity: GroupEntity + ) -> None: persons_count = len(persons_ids) + roles = list(entity.flattened_roles) self.entity_ids[entity.plural] = persons_ids self.entity_counts[entity.plural] = persons_count - self.memberships[entity.plural] = numpy.arange( - 0, persons_count, dtype=numpy.int32 - ) - self.roles[entity.plural] = numpy.repeat( - entity.flattened_roles[0], persons_count + self.memberships[entity.plural] = list( + numpy.arange(0, persons_count, dtype=numpy.int32) ) + self.roles[entity.plural] = [roles[0]] * persons_count - def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): + def add_group_entity( + self, + persons_plural: str, + persons_ids: list[str], + entity: GroupEntity, + instances_json, + ) -> None: """ Add all instances of one of the model's entities as described in ``instances_json``. """ @@ -387,9 +530,10 @@ def set_default_period(self, period_str): if period_str: self.default_period = str(periods.period(period_str)) - def get_input(self, variable, period_str): + def get_input(self, variable: str, period_str: str) -> Array | None: if variable not in self.input_buffer: self.input_buffer[variable] = {} + return self.input_buffer[variable].get(period_str) def check_persons_to_allocate( @@ -513,7 +657,7 @@ def raise_period_mismatch(self, entity, json, e): # We do a basic research to find the culprit path culprit_path = next( dpath.util.search( - json, "*/{}/{}".format(e.variable_name, str(e.period)), yielded=True + json, f"*/{e.variable_name}/{str(e.period)}", yielded=True ), None, ) @@ -527,11 +671,11 @@ def raise_period_mismatch(self, entity, json, e): raise errors.SituationParsingError(path, e.message) # Returns the total number of instances of this entity, including when there is replication along axes - def get_count(self, entity_name): + def get_count(self, entity_name: str) -> int: return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name]) # Returns the ids of instances of this entity, including when there is replication along axes - def get_ids(self, entity_name): + def get_ids(self, entity_name: str) -> list[str]: return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name]) # Returns the memberships of individuals in this entity, including when there is replication along axes @@ -542,27 +686,27 @@ def get_memberships(self, entity_name): ) # Returns the roles of individuals in this entity, including when there is replication along axes - def get_roles(self, entity_name): + def get_roles(self, entity_name: str) -> Sequence[Role]: # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) - def add_parallel_axis(self, axis): + def add_parallel_axis(self, axis: Axis) -> None: # All parallel axes have the same count and entity. # Search for a compatible axis, if none exists, error out self.axes[0].append(axis) - def add_perpendicular_axis(self, axis): + def add_perpendicular_axis(self, axis: Axis) -> None: # This adds an axis perpendicular to all previous dimensions self.axes.append([axis]) - def expand_axes(self): + def expand_axes(self) -> None: # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + perpendicular_dimensions: list[list[Axis]] = self.axes + cell_count: int = 1 - cell_count = 1 for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] - axis_count = first_axis["count"] + first_axis: Axis = parallel_axes[0] + axis_count: int = first_axis["count"] cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times @@ -572,10 +716,16 @@ def expand_axes(self): self.get_count(entity_name) * cell_count ) # Adjust ids - original_ids = self.get_ids(entity_name) * cell_count - indices = numpy.arange(0, cell_count * self.entity_counts[entity_name]) - adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)] + original_ids: list[str] = self.get_ids(entity_name) * cell_count + indices: Array[numpy.int_] = numpy.arange( + 0, cell_count * self.entity_counts[entity_name] + ) + adjusted_ids: list[str] = [ + original_id + str(index) + for original_id, index in zip(original_ids, indices) + ] self.axes_entity_ids[entity_name] = adjusted_ids + # Adjust roles original_roles = self.get_roles(entity_name) adjusted_roles = original_roles * cell_count @@ -620,7 +770,7 @@ def expand_axes(self): # Set input self.input_buffer[axis_name][str(axis_period)] = array else: - first_axes_count: List[int] = ( + first_axes_count: list[int] = ( parallel_axes[0]["count"] for parallel_axes in self.axes ) axes_linspaces = [ @@ -636,9 +786,9 @@ def expand_axes(self): # Distribute values along the grid for axis in parallel_axes: axis_index = axis.get("index", 0) - axis_period = axis["period"] or self.default_period + axis_period = axis.get("period", self.default_period) axis_name = axis["name"] - variable = axis_entity.get_variable(axis_name) + variable = axis_entity.get_variable(axis_name, check_existence=True) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array( @@ -653,8 +803,17 @@ def expand_axes(self): ) self.input_buffer[axis_name][str(axis_period)] = array - def get_variable_entity(self, variable_name: str): + def get_variable_entity(self, variable_name: str) -> Entity: return self.variable_entities[variable_name] - def register_variable(self, variable_name: str, entity): + def register_variable(self, variable_name: str, entity: Entity) -> None: self.variable_entities[variable_name] = entity + + def register_variables(self, simulation: Simulation) -> None: + tax_benefit_system: TaxBenefitSystem = simulation.tax_benefit_system + variables: Iterable[str] = tax_benefit_system.variables.keys() + + for name in variables: + population: Population = simulation.get_variable_population(name) + entity: Entity = population.entity + self.register_variable(name, entity) diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py new file mode 100644 index 000000000..18fa797c4 --- /dev/null +++ b/openfisca_core/simulations/typing.py @@ -0,0 +1,199 @@ +"""Type aliases of OpenFisca models to use in the context of simulations.""" + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import Protocol, TypeVar, TypedDict, Union +from typing_extensions import NotRequired, Required, TypeAlias + +import datetime + +from numpy import bool_ as Bool +from numpy import datetime64 as Date +from numpy import float32 as Float +from numpy import int16 as Enum +from numpy import int32 as Int +from numpy import str_ as String + +#: Generic type variables. +E = TypeVar("E") +G = TypeVar("G", covariant=True) +T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) +U = TypeVar("U", bool, datetime.date, float, str) +V = TypeVar("V", covariant=True) + + +#: Type alias for a simulation dictionary defining the roles. +Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] + +#: Type alias for a simulation dictionary with undated variables. +UndatedVariable: TypeAlias = dict[str, object] + +#: Type alias for a simulation dictionary with dated variables. +DatedVariable: TypeAlias = dict[str, UndatedVariable] + +#: Type alias for a simulation dictionary with abbreviated entities. +Variables: TypeAlias = dict[str, Union[UndatedVariable, DatedVariable]] + +#: Type alias for a simulation with fully specified single entities. +SingleEntities: TypeAlias = dict[str, dict[str, Variables]] + +#: Type alias for a simulation dictionary with implicit group entities. +ImplicitGroupEntities: TypeAlias = dict[str, Union[Roles, Variables]] + +#: Type alias for a simulation dictionary with explicit group entities. +GroupEntities: TypeAlias = dict[str, ImplicitGroupEntities] + +#: Type alias for a simulation dictionary with fully specified entities. +FullySpecifiedEntities: TypeAlias = Union[SingleEntities, GroupEntities] + +#: Type alias for a simulation dictionary with axes parameters. +Axes: TypeAlias = dict[str, Iterable[Iterable["Axis"]]] + +#: Type alias for a simulation dictionary without axes parameters. +ParamsWithoutAxes: TypeAlias = Union[ + Variables, ImplicitGroupEntities, FullySpecifiedEntities +] + +#: Type alias for a simulation dictionary with axes parameters. +ParamsWithAxes: TypeAlias = Union[Axes, ParamsWithoutAxes] + +#: Type alias for a simulation dictionary with all the possible scenarios. +Params: TypeAlias = ParamsWithAxes + + +class Axis(TypedDict, total=False): + """Interface representing an axis of a simulation.""" + + count: Required[int] + index: NotRequired[int] + max: Required[float] + min: Required[float] + name: Required[str] + period: NotRequired[str | int] + + +class Entity(Protocol): + """Interface representing an entity of a simulation.""" + + key: str + plural: str | None + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> Variable[T] | None: + """Get a variable.""" + + +class SingleEntity(Entity, Protocol): + """Interface representing a single entity of a simulation.""" + + +class GroupEntity(Entity, Protocol): + """Interface representing a group entity of a simulation.""" + + @property + @abstractmethod + def flattened_roles(self) -> Iterable[Role[G]]: + """Get the flattened roles of the GroupEntity.""" + + +class Holder(Protocol[V]): + """Interface representing a holder of a simulation's computed values.""" + + @property + @abstractmethod + def variable(self) -> Variable[T]: + """Get the Variable of the Holder.""" + + def get_array(self, __period: str) -> Array[T] | None: + """Get the values of the Variable for a given Period.""" + + def set_input( + self, + __period: Period, + __array: Array[T] | Sequence[U], + ) -> Array[T] | None: + """Set values for a Variable for a given Period.""" + + +class Period(Protocol): + """Interface representing a period of a simulation.""" + + +class Population(Protocol[E]): + """Interface representing a data vector of an Entity.""" + + count: int + entity: E + ids: Array[String] + + def get_holder(self, __variable_name: str) -> Holder[V]: + """Get the holder of a Variable.""" + + +class SinglePopulation(Population[E], Protocol): + """Interface representing a data vector of a SingleEntity.""" + + +class GroupPopulation(Population[E], Protocol): + """Interface representing a data vector of a GroupEntity.""" + + members_entity_id: Array[String] + + def nb_persons(self, __role: Role[G] | None = ...) -> int: + """Get the number of persons for a given Role.""" + + +class Role(Protocol[G]): + """Interface representing a role of the group entities of a simulation.""" + + +class TaxBenefitSystem(Protocol): + """Interface representing a tax-benefit system.""" + + @property + @abstractmethod + def person_entity(self) -> SingleEntity: + """Get the person entity of the tax-benefit system.""" + + @person_entity.setter + @abstractmethod + def person_entity(self, person_entity: SingleEntity) -> None: + """Set the person entity of the tax-benefit system.""" + + @property + @abstractmethod + def variables(self) -> dict[str, V]: + """Get the variables of the tax-benefit system.""" + + def entities_by_singular(self) -> dict[str, E]: + """Get the singular form of the entities' keys.""" + + def entities_plural(self) -> Iterable[str]: + """Get the plural form of the entities' keys.""" + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> V | None: + """Get a variable.""" + + def instantiate_entities( + self, + ) -> dict[str, Population[E]]: + """Instantiate the populations of each Entity.""" + + +class Variable(Protocol[T]): + """Interface representing a variable of a tax-benefit system.""" + + end: str + + def default_array(self, __array_size: int) -> Array[T]: + """Fill an array with the default value of the Variable.""" diff --git a/setup.py b/setup.py index 2b25a81e1..fc0bbfef7 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ setup( name="OpenFisca-Core", - version="41.4.0", + version="41.4.1", author="OpenFisca Team", author_email="contact@openfisca.org", classifiers=[ diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index eb0f58caa..799439e9c 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,5 +1,6 @@ import pytest +from openfisca_core import errors from openfisca_core.simulations import SimulationBuilder from openfisca_core.tools import test_runner @@ -322,7 +323,7 @@ def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): ) -# Integration test +# Integration tests def test_simulation_with_axes(tax_benefit_system): @@ -350,3 +351,27 @@ def test_simulation_with_axes(tax_benefit_system): [0, 0, 0, 0, 0, 0] ) assert simulation.get_array("rent", "2018-11") == pytest.approx([0, 0, 3000, 0]) + + +# Test for missing group entities with build_from_entities() + + +def test_simulation_with_axes_missing_entities(tax_benefit_system): + input_yaml = """ + persons: + Alicia: {salary: {2018-11: 0}} + Javier: {} + Tom: {} + axes: + - + - count: 2 + name: rent + min: 0 + max: 3000 + period: 2018-11 + """ + data = test_runner.yaml.safe_load(input_yaml) + with pytest.raises(errors.SituationParsingError) as error: + SimulationBuilder().build_from_dict(tax_benefit_system, data) + assert "In order to expand over axes" in error.value() + assert "all group entities and roles must be fully specified" in error.value() diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py index 1ecfd09aa..d0e25f866 100644 --- a/tests/core/test_tracers.py +++ b/tests/core/test_tracers.py @@ -20,6 +20,10 @@ from .parameters_fancy_indexing.test_fancy_indexing import parameters +class TestException(Exception): + ... + + class StubSimulation(Simulation): def __init__(self): self.exception = None @@ -91,9 +95,9 @@ def test_tracer_contract(tracer): def test_exception_robustness(): simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.exception = Exception(":-o") + simulation.exception = TestException(":-o") - with raises(Exception): + with raises(TestException): simulation.calculate("a", 2017) assert simulation.tracer.calculation_start_recorded diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 6cab008f4..4d103f10d 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -6,7 +6,9 @@ class TestEntity(Entity): - def get_variable(self, variable_name: str): + def get_variable( + self, variable_name: str, check_existence: bool = False + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result @@ -16,7 +18,9 @@ def check_variable_defined_for_entity(self, variable_name: str): class TestGroupEntity(GroupEntity): - def get_variable(self, variable_name: str): + def get_variable( + self, variable_name: str, check_existence: bool = False + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result