Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def run_cli(args: Namespace, runtime_plugins: ErtRuntimePlugins | None = None) -
f"Config contains forward model step {fm_step_name} {count} time(s)",
)

if not ert_config.observations and args.mode not in {
if not ert_config.observation_declarations and args.mode not in {
ENSEMBLE_EXPERIMENT_MODE,
TEST_RUN_MODE,
WORKFLOW_MODE,
Expand Down
61 changes: 5 additions & 56 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Self, cast, overload

import polars as pl
from numpy.random import SeedSequence
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from pydantic import BaseModel, Field, model_validator
from pydantic import ValidationError as PydanticValidationError

from ert.config._create_observation_dataframes import create_observation_dataframes
from ert.substitutions import Substitutions

from ._create_observation_dataframes import create_observation_dataframes
from ._design_matrix_validator import DesignMatrixValidator
from ._observations import (
GeneralObservation,
Expand Down Expand Up @@ -714,32 +713,6 @@ class ErtConfig(BaseModel):
config_path: str = Field(init=False, default="")
observation_declarations: list[Observation] = Field(default_factory=list)
zonemap: dict[int, list[str]] = Field(default_factory=dict)
_observations: dict[str, pl.DataFrame] | None = PrivateAttr(None)

@property
def observations(self) -> dict[str, pl.DataFrame]:
if self._observations is None:
has_rft_observations = any(
isinstance(o, RFTObservation) for o in self.observation_declarations
)
if (
has_rft_observations
and "rft" not in self.ensemble_config.response_configs
):
self.ensemble_config.response_configs["rft"] = RFTConfig(
input_files=[self.runpath_config.eclbase_format_string],
data_to_read={},
locations=[],
zonemap=self.zonemap,
)
self._observations = create_observation_dataframes(
self.observation_declarations,
cast(
RFTConfig | None,
self.ensemble_config.response_configs.get("rft", None),
),
)
return self._observations

@model_validator(mode="after")
def set_fields(self) -> Self:
Expand Down Expand Up @@ -834,28 +807,6 @@ def validate_observations_against_responses(self) -> Self:

return self

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
return False

for attr in vars(self):
if attr == "observations":
if self.observations.keys() != other.observations.keys():
return False

if not all(
self.observations[k].equals(other.observations[k])
for k in self.observations
):
return False

continue

if getattr(self, attr) != getattr(other, attr):
return False

return True

@staticmethod
def with_plugins(runtime_plugins: ErtRuntimePlugins) -> type[ErtConfig]:
class ErtConfigWithPlugins(ErtConfig):
Expand Down Expand Up @@ -1157,12 +1108,10 @@ def from_dict(cls, config_dict: ConfigDict) -> Self:

# PS:
# This mutates the rft config and is necessary for the moment
cls_config._observations = create_observation_dataframes(
# Consider changing this pattern
_ = create_observation_dataframes(
obs_configs,
cast(
RFTConfig | None,
ensemble_config.response_configs.get("rft", None),
),
cast(RFTConfig | None, ensemble_config.response_configs.get("rft")),
)
except PydanticValidationError as err:
raise ConfigValidationError.from_pydantic(err) from err
Expand Down
10 changes: 7 additions & 3 deletions tests/ert/performance_tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ert.__main__ import run_convert_observations
from ert.analysis import enif_update, smoother_update
from ert.config import ErtConfig, ESSettings, ObservationSettings
from ert.config._create_observation_dataframes import create_observation_dataframes
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
from ert.namespace import Namespace
from ert.sample_prior import sample_prior
Expand Down Expand Up @@ -109,18 +110,21 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None:
path = Path(poly_template) / "ensembles"
with open_storage(path, mode="w") as storage:
ens_config = ert_config.ensemble_config
observations = create_observation_dataframes(
ert_config.observation_declarations, None
)
experiment_id = storage.create_experiment(
parameters=ens_config.parameter_configuration,
responses=ens_config.response_configuration,
observations=ert_config.observations,
observations=observations,
name="test-experiment",
)
source = storage.create_ensemble(experiment_id, name="prior", ensemble_size=100)

realizations = list(range(ert_config.runpath_config.num_realizations))
for real in realizations:
gendatas = []
gen_obs = ert_config.observations["gen_data"]
gen_obs = observations["gen_data"]
for response_key, df in gen_obs.group_by("response_key"):
gendata_df = make_gen_data(df["index"].max() + 1)
gendata_df = gendata_df.insert_column(
Expand All @@ -142,7 +146,7 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None:
for i in range((refcase_end - refcase_start).days + 1)
]

summary_keys = ert_config.observations["summary"]["response_key"].unique(
summary_keys = observations["summary"]["response_key"].unique(
maintain_order=True
)

Expand Down
5 changes: 4 additions & 1 deletion tests/ert/ui_tests/cli/test_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ert.analysis import smoother_update
from ert.config import ErtConfig, ESSettings, ObservationSettings
from ert.config._create_observation_dataframes import create_observation_dataframes
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
from ert.storage import open_storage

Expand Down Expand Up @@ -435,7 +436,9 @@ def test_field_param_update_using_heat_equation_zero_var_params_and_adaptive_loc
new_experiment = storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
responses=config.ensemble_config.response_configuration,
observations=config.observations,
observations=create_observation_dataframes(
config.observation_declarations, None
),
name="exp-zero-var",
)
new_prior = storage.create_ensemble(
Expand Down
5 changes: 4 additions & 1 deletion tests/ert/ui_tests/gui/test_manage_experiments_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from ert.config import ErtConfig, SummaryConfig
from ert.config._create_observation_dataframes import create_observation_dataframes
from ert.gui.ertnotifier import ErtNotifier
from ert.gui.tools.manage_experiments import ManageExperimentsPanel
from ert.gui.tools.manage_experiments.storage_info_widget import (
Expand Down Expand Up @@ -162,7 +163,9 @@ def test_that_init_updates_the_info_tab(qtbot):
ensemble = storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
responses=config.ensemble_config.response_configuration,
observations=config.observations,
observations=create_observation_dataframes(
config.observation_declarations, None
),
name="my-experiment",
).create_ensemble(
ensemble_size=config.runpath_config.num_realizations, name="default"
Expand Down
5 changes: 4 additions & 1 deletion tests/ert/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
QueueSystem,
RFTConfig,
)
from ert.config._create_observation_dataframes import create_observation_dataframes
from ert.config.ert_config import _split_string_into_sections, create_forward_model_json
from ert.config.forward_model_step import (
ForwardModelStepPlugin,
Expand Down Expand Up @@ -2733,7 +2734,9 @@ def test_that_breakthrough_observations_can_be_internalized_in_ert_config():
""",
)

breakthrough_observations = config.observations["breakthrough"]
breakthrough_observations = create_observation_dataframes(
config.observation_declarations, None
)["breakthrough"]
assert breakthrough_observations["observation_key"].to_list() == ["BRT_OBS"]
assert breakthrough_observations["response_key"].to_list() == [
"BREAKTHROUGH:WWCT:OP_1:0.1"
Expand Down
Loading
Loading