Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions src/ert/config/parameter_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,8 @@ def group_name(self) -> str:
def transform_data(self) -> Callable[[float], float]:
return lambda x: x

def sample_value(
self,
global_seed: str,
realization: int,
def sample_values(
self, global_seed: str, active_realizations: list[int], num_realizations: int
) -> npt.NDArray[np.double]:
"""
Generate a sample value for each key in a parameter group.
Expand Down Expand Up @@ -162,9 +160,7 @@ def sample_value(
seed = np.frombuffer(key_hash.digest(), dtype="uint32")
rng = np.random.default_rng(seed)

# Advance the RNG state to the realization point
rng.standard_normal(realization)

# Generate a single sample
value = rng.standard_normal(1)
return np.array([value[0]])
# Generate samples for all active realizations
all_values = rng.standard_normal(num_realizations)
idx = np.asarray(active_realizations, dtype=int)
return all_values[idx]
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def initialize_from_scratch(_: bool) -> None:
active_realizations=active_realizations,
parameters=parameters,
random_seed=self.ert_config.random_seed,
num_realizations=self.ert_config.runpath_config.num_realizations,
design_matrix_df=(
self.ert_config.analysis_config.design_matrix.design_matrix_df
if self.ert_config.analysis_config.design_matrix
Expand Down
1 change: 1 addition & 0 deletions src/ert/run_models/initial_ensemble_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _sample_and_evaluate_ensemble(
np.where(self.active_realizations)[0],
parameters=[param.name for param in self.parameter_configuration],
random_seed=self.random_seed,
num_realizations=self.runpath_config.num_realizations,
design_matrix_df=self.design_matrix.to_polars()
if self.design_matrix is not None
else None,
Expand Down
26 changes: 12 additions & 14 deletions src/ert/sample_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def sample_prior(
ensemble: Ensemble,
active_realizations: Iterable[int],
random_seed: int,
num_realizations: int,
parameters: list[str] | None = None,
design_matrix_df: pl.DataFrame | None = None,
) -> None:
Expand Down Expand Up @@ -65,21 +66,18 @@ def sample_prior(
f"Sampling parameter {config_node.name} "
f"for realizations {active_realizations}"
)
datasets = [
Ensemble.sample_parameter(
config_node,
realization_nr,
random_seed=random_seed,
)
for realization_nr in active_realizations
]
if datasets:
dataset = pl.concat(datasets, how="vertical")
dataset = Ensemble.sample_parameter(
config_node,
list(active_realizations),
random_seed=random_seed,
num_realizations=num_realizations,
)
if not (dataset is None or dataset.is_empty()):
if complete_dataset is None:
complete_dataset = dataset
elif dataset is not None:
complete_dataset = complete_dataset.join(dataset, on="realization")

if complete_dataset is None:
complete_dataset = dataset
elif dataset is not None:
complete_dataset = complete_dataset.join(dataset, on="realization")
else:
for realization_nr in active_realizations:
ds = config_node.read_from_runpath(Path(), realization_nr, 0)
Expand Down
19 changes: 10 additions & 9 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,20 +737,21 @@ def load_observation_scaling_factors(
@staticmethod
def sample_parameter(
parameter: ParameterConfig,
real_nr: int,
active_realizations: list[int],
random_seed: int,
num_realizations: int,
) -> pl.DataFrame:
parameter_value = parameter.sample_value(
str(random_seed),
real_nr,
parameter_values = parameter.sample_values(
str(random_seed), active_realizations, num_realizations=num_realizations
)

parameter_dict = {parameter.name: parameter_value[0]}
parameter_dict["realization"] = real_nr
return pl.DataFrame(
parameter_dict,
schema={parameter.name: pl.Float64, "realization": pl.Int64},
parameters = pl.DataFrame(
{parameter.name: parameter_values},
schema={parameter.name: pl.Float64},
)
realizations_series = pl.Series("realization", list(active_realizations))

return parameters.with_columns(realizations_series)

def load_responses(self, key: str, realizations: tuple[int, ...]) -> pl.DataFrame:
"""Load responses for key and realizations into xarray Dataset.
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/performance_tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None:
real,
)

sample_prior(source, realizations, 42, ens_config.parameters)
sample_prior(source, realizations, 42, len(realizations), ens_config.parameters)

storage.create_ensemble(
source.experiment_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def setup_es_benchmark(tmp_path, request):
prior,
range(config.num_realizations),
42,
config.num_realizations,
[c.name for c in info.gen_kw_configs],
)
posterior = experiment.create_ensemble(
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def test_that_prior_is_not_overwritten_in_ensemble_experiment(
ensemble = storage.create_ensemble(
experiment_id, name="iter-0", ensemble_size=num_realizations
)
sample_prior(ensemble, prior_mask, ert_config.random_seed)
sample_prior(ensemble, prior_mask, ert_config.random_seed, num_realizations)
experiment = storage.get_experiment_by_name("test-experiment")
prior_values = experiment.get_ensemble_by_name(
ensemble.name
Expand Down
30 changes: 12 additions & 18 deletions tests/ert/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,13 @@ def test_update_handles_precision_loss_in_std_dev(tmp_path):
],
)
prior = storage.create_ensemble(experiment.id, ensemble_size=23, name="prior")
datasets = [
Ensemble.sample_parameter(
gen_kw,
realization_nr,
random_seed=1234,
)
for realization_nr in range(prior.ensemble_size)
]
prior.save_parameters(pl.concat(datasets, how="vertical"))
datasets = Ensemble.sample_parameter(
gen_kw,
list(range(prior.ensemble_size)),
random_seed=1234,
num_realizations=23,
)
prior.save_parameters(datasets)

prior.save_response(
"gen_data",
Expand Down Expand Up @@ -379,15 +377,11 @@ def test_update_raises_on_singular_matrix(tmp_path):
],
)
prior = storage.create_ensemble(experiment.id, ensemble_size=2, name="prior")
datasets = [
Ensemble.sample_parameter(
gen_kw,
realization_nr,
random_seed=1234,
)
for realization_nr in range(prior.ensemble_size)
]
prior.save_parameters(pl.concat(datasets, how="vertical"))
datasets = Ensemble.sample_parameter(
gen_kw, [0, 1], random_seed=1234, num_realizations=2
)

prior.save_parameters(datasets)

for i, v in enumerate(
[
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/config/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_write_to_runpath_produces_the_transformed_field_in_storage(
experiment_id, name="prior", ensemble_size=5
)
active_realizations = [0, 3, 4]
sample_prior(prior_ensemble, active_realizations, 123)
sample_prior(prior_ensemble, active_realizations, 123, 5)
permx_field = ensemble_config["PERMX"]
assert (permx_field.nx, permx_field.ny, permx_field.nz) == (10, 10, 5)
assert permx_field.truncation_min is None
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/config/test_gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_gen_kw_is_log_or_not(
prior_ensemble = storage.create_ensemble(
experiment_id, name="prior", ensemble_size=1
)
sample_prior(prior_ensemble, [0], 123)
sample_prior(prior_ensemble, [0], 123, 1)
create_run_path(
run_args=run_args(ert_config, prior_ensemble),
ensemble=prior_ensemble,
Expand Down
42 changes: 36 additions & 6 deletions tests/ert/unit_tests/scenarios/test_summary_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def create_responses(prior_ensemble, response_times):

@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key")
def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble):
sample_prior(prior_ensemble, range(prior_ensemble.ensemble_size), 123)
sample_prior(
prior_ensemble,
range(prior_ensemble.ensemble_size),
123,
prior_ensemble.ensemble_size,
)

create_responses(
prior_ensemble,
Expand All @@ -110,7 +115,12 @@ def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble):

@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key")
def test_that_mismatched_responses_give_error(ert_config, storage, prior_ensemble):
sample_prior(prior_ensemble, range(prior_ensemble.ensemble_size), 123)
sample_prior(
prior_ensemble,
range(prior_ensemble.ensemble_size),
123,
prior_ensemble.ensemble_size,
)

response_times = [
[datetime(2014, 9, 9)],
Expand Down Expand Up @@ -144,7 +154,12 @@ def test_that_different_length_is_ok_as_long_as_observation_time_exists(
storage,
prior_ensemble,
):
sample_prior(prior_ensemble, range(prior_ensemble.ensemble_size), 123)
sample_prior(
prior_ensemble,
range(prior_ensemble.ensemble_size),
123,
prior_ensemble.ensemble_size,
)
response_times = [
[datetime(2014, 9, 9)],
[datetime(2014, 9, 9)],
Expand Down Expand Up @@ -193,7 +208,12 @@ def test_that_duplicate_summary_time_steps_does_not_fail(
storage,
prior_ensemble,
):
sample_prior(prior_ensemble, range(prior_ensemble.ensemble_size), 123)
sample_prior(
prior_ensemble,
range(prior_ensemble.ensemble_size),
123,
prior_ensemble.ensemble_size,
)
response_times = [
[datetime(2014, 9, 9)],
[datetime(2014, 9, 9)],
Expand Down Expand Up @@ -224,7 +244,12 @@ def test_that_duplicate_summary_time_steps_does_not_fail(
@pytest.mark.flaky(reruns=5)
@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key")
def test_that_mismatched_responses_gives_nan_measured_data(prior_ensemble):
sample_prior(prior_ensemble, range(prior_ensemble.ensemble_size), 123)
sample_prior(
prior_ensemble,
range(prior_ensemble.ensemble_size),
123,
prior_ensemble.ensemble_size,
)

response_times = [
[datetime(2014, 9, 9)],
Expand Down Expand Up @@ -253,7 +278,12 @@ def test_that_mismatched_responses_gives_nan_measured_data(prior_ensemble):

@pytest.mark.filterwarnings("ignore:Config contains a SUMMARY key")
def test_reading_past_2263_is_ok(ert_config, prior_ensemble):
sample_prior(prior_ensemble, range(prior_ensemble.ensemble_size), 123)
sample_prior(
prior_ensemble,
range(prior_ensemble.ensemble_size),
123,
prior_ensemble.ensemble_size,
)

create_responses(
prior_ensemble,
Expand Down
1 change: 1 addition & 0 deletions tests/ert/unit_tests/storage/create_runpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def create_runpath(
ensemble,
[i for i, active in enumerate(active_mask) if active],
random_seed=ert_config.random_seed,
num_realizations=ert_config.runpath_config.num_realizations,
)
create_run_path(
run_args=run_args,
Expand Down
9 changes: 8 additions & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ def test_sample_parameter_with_design_matrix(tmp_path, reals, expect_error):
ensemble,
reals,
random_seed=123,
num_realizations=ensemble_size,
parameters=[
param.name for param in design_matrix.parameter_configurations
],
Expand All @@ -991,6 +992,7 @@ def test_sample_parameter_with_design_matrix(tmp_path, reals, expect_error):
ensemble,
reals,
random_seed=123,
num_realizations=ensemble_size,
parameters=[
param.name for param in design_matrix.parameter_configurations
],
Expand Down Expand Up @@ -1037,7 +1039,12 @@ def test_load_gen_kw_not_sorted(storage, tmpdir, snapshot):
experiment_id, name="default", ensemble_size=ensemble_size
)

sample_prior(ensemble, range(ensemble_size), random_seed=1234)
sample_prior(
ensemble,
range(ensemble_size),
random_seed=1234,
num_realizations=ensemble_size,
)
data = ensemble.load_scalars()
snapshot.assert_match(data.write_csv(float_precision=12), "gen_kw_unsorted")

Expand Down
12 changes: 10 additions & 2 deletions tests/ert/unit_tests/storage/test_parameter_sample_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,12 @@ def test_that_sampling_is_fixed_from_name(
name="prior",
ensemble_size=num_realisations,
)
sample_prior(fs, range(num_realisations), random_seed=1234)
sample_prior(
fs,
range(num_realisations),
random_seed=1234,
num_realizations=num_realisations,
)

key_hash = sha256(b"1234" + b"KW_NAME:MY_KEYWORD")
seed = np.frombuffer(key_hash.digest(), dtype="uint32")
Expand Down Expand Up @@ -365,7 +370,10 @@ def test_that_sub_sample_maintains_order(tmpdir, storage, mask, expected):
ensemble_size=5,
)
sample_prior(
fs, [i for i, active in enumerate(mask) if active], random_seed=1234
fs,
[i for i, active in enumerate(mask) if active],
random_seed=1234,
num_realizations=5,
)

df = fs.load_parameters("KW_NAME")
Expand Down
Loading
Loading