Skip to content

Commit 9d58ee2

Browse files
committed
remove observations from ErtConfig
1 parent 19b8b27 commit 9d58ee2

4 files changed

Lines changed: 51 additions & 81 deletions

File tree

src/ert/config/ert_config.py

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
from pathlib import Path
1313
from typing import TYPE_CHECKING, Any, ClassVar, Self, cast, overload
1414

15-
import polars as pl
1615
from numpy.random import SeedSequence
17-
from pydantic import BaseModel, Field, PrivateAttr, model_validator
16+
from pydantic import BaseModel, Field, model_validator
1817
from pydantic import ValidationError as PydanticValidationError
1918

19+
from ert.config._create_observation_dataframes import create_observation_dataframes
2020
from ert.substitutions import Substitutions
2121

22-
from ._create_observation_dataframes import create_observation_dataframes
2322
from ._design_matrix_validator import DesignMatrixValidator
2423
from ._observations import (
2524
GeneralObservation,
@@ -714,32 +713,6 @@ class ErtConfig(BaseModel):
714713
config_path: str = Field(init=False, default="")
715714
observation_declarations: list[Observation] = Field(default_factory=list)
716715
zonemap: dict[int, list[str]] = Field(default_factory=dict)
717-
_observations: dict[str, pl.DataFrame] | None = PrivateAttr(None)
718-
719-
@property
720-
def observations(self) -> dict[str, pl.DataFrame]:
721-
if self._observations is None:
722-
has_rft_observations = any(
723-
isinstance(o, RFTObservation) for o in self.observation_declarations
724-
)
725-
if (
726-
has_rft_observations
727-
and "rft" not in self.ensemble_config.response_configs
728-
):
729-
self.ensemble_config.response_configs["rft"] = RFTConfig(
730-
input_files=[self.runpath_config.eclbase_format_string],
731-
data_to_read={},
732-
locations=[],
733-
zonemap=self.zonemap,
734-
)
735-
self._observations = create_observation_dataframes(
736-
self.observation_declarations,
737-
cast(
738-
RFTConfig | None,
739-
self.ensemble_config.response_configs.get("rft", None),
740-
),
741-
)
742-
return self._observations
743716

744717
@model_validator(mode="after")
745718
def set_fields(self) -> Self:
@@ -834,28 +807,6 @@ def validate_observations_against_responses(self) -> Self:
834807

835808
return self
836809

837-
def __eq__(self, other: object) -> bool:
838-
if not isinstance(other, ErtConfig):
839-
return False
840-
841-
for attr in vars(self):
842-
if attr == "observations":
843-
if self.observations.keys() != other.observations.keys():
844-
return False
845-
846-
if not all(
847-
self.observations[k].equals(other.observations[k])
848-
for k in self.observations
849-
):
850-
return False
851-
852-
continue
853-
854-
if getattr(self, attr) != getattr(other, attr):
855-
return False
856-
857-
return True
858-
859810
@staticmethod
860811
def with_plugins(runtime_plugins: ErtRuntimePlugins) -> type[ErtConfig]:
861812
class ErtConfigWithPlugins(ErtConfig):
@@ -1157,12 +1108,10 @@ def from_dict(cls, config_dict: ConfigDict) -> Self:
11571108

11581109
# PS:
11591110
# This mutates the rft config and is necessary for the moment
1160-
cls_config._observations = create_observation_dataframes(
1111+
# Consider changing this pattern
1112+
_ = create_observation_dataframes(
11611113
obs_configs,
1162-
cast(
1163-
RFTConfig | None,
1164-
ensemble_config.response_configs.get("rft", None),
1165-
),
1114+
cast(RFTConfig | None, ensemble_config.response_configs.get("rft")),
11661115
)
11671116
except PydanticValidationError as err:
11681117
raise ConfigValidationError.from_pydantic(err) from err

tests/ert/unit_tests/config/test_observations.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import datetime, timedelta
44
from pathlib import Path
55
from textwrap import dedent
6+
from typing import cast
67

78
import hypothesis.strategies as st
89
import polars as pl
@@ -25,6 +26,7 @@
2526
ObservationConfigError,
2627
ObservationType,
2728
)
29+
from ert.config.rft_config import RFTConfig
2830
from ert.namespace import Namespace
2931

3032
pytestmark = pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key")
@@ -103,7 +105,7 @@ def make_refcase_observations(
103105
run_convert_observations(Namespace(config="config.ert"))
104106

105107
migrated_config = ErtConfig.from_file("config.ert")
106-
return migrated_config.observations
108+
return create_observation_dataframes(migrated_config.observation_declarations, None)
107109

108110

109111
@pytest.mark.usefixtures("use_tmpdir")
@@ -228,7 +230,9 @@ def test_that_summary_observations_can_use_restart_for_index_if_refcase_is_given
228230
run_convert_observations(Namespace(config=str(config_file)))
229231

230232
migrated_config = ErtConfig.from_file("config.ert")
231-
observations = migrated_config.observations["summary"]
233+
observations = create_observation_dataframes(
234+
migrated_config.observation_declarations, None
235+
)["summary"]
232236

233237
assert len(observations["time"]) == 1
234238
assert list(observations["observations"]) == pytest.approx([value])
@@ -279,7 +283,9 @@ def test_that_summary_observations_can_use_restart_for_index_if_time_map_is_give
279283
run_convert_observations(Namespace(config=str(config_file)))
280284

281285
migrated_config = ErtConfig.from_file("config.ert")
282-
observations = migrated_config.observations["summary"]
286+
observations = create_observation_dataframes(
287+
migrated_config.observation_declarations, None
288+
)["summary"]
283289

284290
# RESTART is a 1-based index; Python lists are 0-based.
285291
assert list(observations["time"]) == [datetime.fromisoformat(time_map[restart])]
@@ -308,8 +314,10 @@ def test_that_rft_config_is_created_from_observations():
308314
),
309315
}
310316
)
311-
312-
observations = ert_config.observations["rft"]
317+
rft_config = cast(RFTConfig, ert_config.ensemble_config.response_configs["rft"])
318+
observations = create_observation_dataframes(
319+
ert_config.observation_declarations, rft_config
320+
)["rft"]
313321
assert_frame_equal(
314322
observations,
315323
pl.DataFrame(
@@ -325,14 +333,13 @@ def test_that_rft_config_is_created_from_observations():
325333
}
326334
),
327335
)
328-
rft_config = ert_config.ensemble_config.response_configs["rft"]
329336
assert rft_config.data_to_read == {"well": {"2013-03-31": ["PRESSURE"]}}
330337
assert rft_config.locations == [(30.0, 71.0, 2000.0)]
331338

332339

333340
def test_that_rft_observations_with_unknown_zones_errors():
334341
with pytest.raises(ConfigValidationError, match="no such zone"):
335-
_ = ErtConfig.from_dict(
342+
ErtConfig.from_dict(
336343
{
337344
"ECLBASE": "ECLIPSE_CASE",
338345
"OBS_CONFIG": (
@@ -388,7 +395,9 @@ def test_that_the_date_keyword_sets_the_summary_index_without_time_map_or_refcas
388395
run_convert_observations(Namespace(config="config.ert"))
389396

390397
migrated = ErtConfig.from_file("config.ert")
391-
observations = migrated.observations["summary"]
398+
observations = create_observation_dataframes(
399+
migrated.observation_declarations, None
400+
)["summary"]
392401

393402
assert list(observations["time"]) == [datetime.fromisoformat(date)]
394403

@@ -457,12 +466,10 @@ def test_that_the_date_keyword_sets_the_general_index_by_looking_up_time_map():
457466
Path("config.ert").write_text(config_content, encoding="utf-8")
458467

459468
run_convert_observations(Namespace(config="config.ert"))
460-
assert (
461-
ErtConfig.from_file("config.ert")
462-
.observations["gen_data"]
463-
.to_dicts()[0]["report_step"]
464-
== restart
469+
observations = create_observation_dataframes(
470+
ErtConfig.from_file("config.ert").observation_declarations, None
465471
)
472+
assert observations["gen_data"].to_dicts()[0]["report_step"] == restart
466473

467474

468475
@given(summary=summaries(), data=st.data())
@@ -512,12 +519,10 @@ def test_that_the_date_keyword_sets_the_report_step_by_looking_up_refcase(
512519
)
513520
Path("config.ert").write_text(config_content, encoding="utf-8")
514521
run_convert_observations(Namespace(config="config.ert"))
515-
assert (
516-
ErtConfig.from_file("config.ert")
517-
.observations["gen_data"]
518-
.to_dicts()[0]["report_step"]
519-
== restart
522+
observations = create_observation_dataframes(
523+
ErtConfig.from_file("config.ert").observation_declarations, None
520524
)
525+
assert observations["gen_data"].to_dicts()[0]["report_step"] == restart
521526

522527

523528
@pytest.mark.parametrize("std", [-1.0, 0, 0.0])
@@ -1228,7 +1233,9 @@ def test_that_history_observations_values_are_fetched_from_refcase(
12281233
Path("config.ert").write_text(config_content, encoding="utf-8")
12291234

12301235
run_convert_observations(Namespace(config="config.ert"))
1231-
observations = ErtConfig.from_file("config.ert").observations["summary"]
1236+
observations = create_observation_dataframes(
1237+
ErtConfig.from_file("config.ert").observation_declarations, None
1238+
)["summary"]
12321239

12331240
steps = len(unsmry.steps)
12341241
assert list(observations["response_key"]) == ["FOPR"] * steps
@@ -1407,7 +1414,9 @@ def test_that_history_observation_errors_are_calculated_correctly(tmpdir):
14071414
Path("config.ert").write_text(config_content, encoding="utf-8")
14081415

14091416
run_convert_observations(Namespace(config="config.ert"))
1410-
observations = ErtConfig.from_file("config.ert").observations["summary"]
1417+
observations = create_observation_dataframes(
1418+
ErtConfig.from_file("config.ert").observation_declarations, None
1419+
)["summary"]
14111420

14121421
assert list(observations["response_key"]) == ["FGPR", "FOPR", "FWPR"]
14131422
assert list(observations["observations"]) == pytest.approx([15, 20, 25])
@@ -1448,7 +1457,9 @@ def test_that_segment_defaults_are_applied(tmpdir):
14481457
Path("config.ert").write_text(config_content, encoding="utf-8")
14491458

14501459
run_convert_observations(Namespace(config="config.ert"))
1451-
observations = ErtConfig.from_file("config.ert").observations["summary"]
1460+
observations = create_observation_dataframes(
1461+
ErtConfig.from_file("config.ert").observation_declarations, None
1462+
)["summary"]
14521463

14531464
# default error_min is 0.1
14541465
# default error method is RELMIN

tests/ert/unit_tests/config/test_summary_config.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
InvalidResponseFile,
1717
SummaryConfig,
1818
)
19-
from ert.config._create_observation_dataframes import DEFAULT_LOCALIZATION_RADIUS
19+
from ert.config._create_observation_dataframes import (
20+
DEFAULT_LOCALIZATION_RADIUS,
21+
create_observation_dataframes,
22+
)
2023

2124

2225
@settings(max_examples=10)
@@ -102,7 +105,9 @@ def create_summary_observation(loc_config_lines):
102105
Path("prior.txt").write_text("MY_KEYWORD NORMAL 0 1", encoding="utf-8")
103106

104107
ert_config = ErtConfig.from_file("config.ert")
105-
return ert_config.observations["summary"]
108+
return create_observation_dataframes(ert_config.observation_declarations, None)[
109+
"summary"
110+
]
106111

107112

108113
@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key but no forward model")
@@ -198,7 +203,9 @@ def test_that_adding_one_localized_observation_to_snake_oil_case_can_be_internal
198203
obs_lines.insert(observation_index + 2 + i, line)
199204
new_obs_content = "\n".join(obs_lines)
200205
Path("observations/observations.txt").write_text(new_obs_content, encoding="utf-8")
201-
summary = ErtConfig.from_file("snake_oil.ert").observations["summary"]
206+
summary = create_observation_dataframes(
207+
ErtConfig.from_file("snake_oil.ert").observation_declarations, None
208+
)["summary"]
202209
assert summary["east"].dtype == pl.Float32
203210
assert summary["north"].dtype == pl.Float32
204211
assert summary["radius"].dtype == pl.Float32

tests/ert/unit_tests/scenarios/test_summary_response.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ert.analysis import ErtAnalysisError, smoother_update
1313
from ert.config import ErtConfig, ESSettings, ObservationSettings
14+
from ert.config._create_observation_dataframes import create_observation_dataframes
1415
from ert.data import MeasuredData
1516
from ert.sample_prior import sample_prior
1617
from ert.storage.local_ensemble import load_parameters_and_responses_from_runpath
@@ -21,7 +22,9 @@ def prior_ensemble(storage, ert_config):
2122
return storage.create_experiment(
2223
parameters=ert_config.ensemble_config.parameter_configuration,
2324
responses=ert_config.ensemble_config.response_configuration,
24-
observations=ert_config.observations,
25+
observations=create_observation_dataframes(
26+
ert_config.observation_declarations, None
27+
),
2528
).create_ensemble(ensemble_size=3, name="prior")
2629

2730

0 commit comments

Comments
 (0)