Skip to content

Commit 1af6c01

Browse files
riemanli authored and The Meridian Authors committed
Adds KPI invariance check to EDA engine.
PiperOrigin-RevId: 825871774
1 parent 7ada9c0 commit 1af6c01

File tree

10 files changed

+253
-12
lines changed

10 files changed

+253
-12
lines changed

meridian/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
134134
CONTROLS_SCALED = 'controls_scaled'
135135
KPI_SCALED = f'{KPI}_scaled'
136+
POPULATION_SCALED_KPI = f'{POPULATION}_scaled_{KPI}'
136137
RF_IMPRESSIONS_SCALED = f'{RF_IMPRESSIONS}_scaled'
137138

138139
# Non-media treatments baseline value constants.

meridian/model/eda/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@
1515
"""The Meridian API module that performs EDA checks."""
1616

1717
from meridian.model.eda import eda_engine
18+
from meridian.model.eda import eda_outcome
19+
from meridian.model.eda import eda_spec
1820
from meridian.model.eda import meridian_eda

meridian/model/eda/eda_engine.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
"""Meridian EDA Engine."""
1616

17+
from __future__ import annotations
18+
1719
import dataclasses
1820
import functools
21+
import typing
1922
from typing import Optional, Sequence
23+
2024
from meridian import constants
21-
from meridian.model import model
2225
from meridian.model import transformers
2326
from meridian.model.eda import eda_outcome
2427
from meridian.model.eda import eda_spec
@@ -30,6 +33,9 @@
3033
import xarray as xr
3134

3235

36+
if typing.TYPE_CHECKING:
37+
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
38+
3339
_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
3440
_CORRELATION_COL_NAME = 'correlation'
3541
_STACK_VAR_COORD_NAME = 'var'
@@ -727,6 +733,27 @@ def geo_population_da(self) -> xr.DataArray | None:
727733
name=constants.POPULATION,
728734
)
729735

736+
@functools.cached_property
def _population_scaled_kpi_artifact(
    self,
) -> eda_outcome.KpiInvarianceArtifact:
  """Returns an artifact containing population-scaled KPI data.

  The population-scaled KPI values, mean and standard deviation are all read
  from the model's `kpi_transformer`. The values are wrapped with
  `_data_array_like`, passing the input-data KPI array as `da` (presumably the
  template for dims/coords -- confirm against `_data_array_like`), and the
  resulting array is named `constants.POPULATION_SCALED_KPI`.

  Returns:
    A `KpiInvarianceArtifact` at the `OVERALL` analysis level.
  """
  kpi_transformer = self._meridian.kpi_transformer

  population_scaled_kpi_da = _data_array_like(
      da=self._meridian.input_data.kpi,
      values=kpi_transformer.population_scaled_kpi,
  )
  population_scaled_kpi_da.name = constants.POPULATION_SCALED_KPI

  return eda_outcome.KpiInvarianceArtifact(
      level=eda_outcome.AnalysisLevel.OVERALL,
      population_scaled_kpi_da=population_scaled_kpi_da,
      # float() so the artifact stores plain Python scalars, matching the
      # `float` annotations on the artifact fields.
      population_scaled_mean=float(kpi_transformer.population_scaled_mean),
      population_scaled_stdev=float(kpi_transformer.population_scaled_stdev),
  )
756+
730757
@functools.cached_property
731758
def kpi_scaled_da(self) -> xr.DataArray:
732759
scaled_kpi_da = _data_array_like(
@@ -1643,3 +1670,35 @@ def check_national_vif(
16431670
findings=findings,
16441671
analysis_artifacts=[national_vif_artifact],
16451672
)
1673+
1674+
def kpi_has_variability(self) -> bool:
  """Returns True if the KPI has variability across geos and times.

  Variability is judged solely from the KPI transformer's
  `population_scaled_stdev`: any nonzero value counts as variability.
  """
  # float() collapses whatever scalar type the transformer exposes into a
  # plain Python float so the comparison yields a plain `bool`.
  stdev = float(self._meridian.kpi_transformer.population_scaled_stdev)
  return stdev != 0
1678+
1679+
def check_overall_kpi_invariance(self) -> eda_outcome.EDAOutcome:
  """Checks if the KPI is constant across all geos and times.

  Returns:
    An `EDAOutcome` of type `KPI_INVARIANCE` holding exactly one finding:
    `ERROR` severity when `kpi_has_variability()` is False, `INFO` severity
    otherwise. The population-scaled KPI artifact is attached in both cases.
  """
  if not self.kpi_has_variability():
    # The message names the raw KPI for national models and the
    # population-scaled KPI for geo models.
    kpi = 'kpi' if self._meridian.is_national else 'population_scaled_kpi'

    eda_finding = eda_outcome.EDAFinding(
        severity=eda_outcome.EDASeverity.ERROR,
        explanation=(
            f'`{kpi}` is constant across all geos and times, indicating no'
            ' signal in the data. Please fix this data error.'
        ),
    )
  else:
    eda_finding = eda_outcome.EDAFinding(
        severity=eda_outcome.EDASeverity.INFO,
        explanation=(
            'The KPI has variability across geos and times, indicating'
            ' variability in the data.'
        ),
    )

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.KPI_INVARIANCE,
      findings=[eda_finding],
      analysis_artifacts=[self._population_scaled_kpi_artifact],
  )

meridian/model/eda/eda_engine_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4605,6 +4605,100 @@ def test_check_national_vif_has_correct_vif_value(self):
46054605
]
46064606
self.assertAllClose(national_artifact.vif_da.values, expected_national_vif)
46074607

4608+
@parameterized.named_parameters(
    dict(
        testcase_name="has_variability",
        population_scaled_stdev=1.0,
        expected_result=True,
    ),
    dict(
        testcase_name="no_variability",
        population_scaled_stdev=0.0,
        expected_result=False,
    ),
)
def test_kpi_has_variability(self, population_scaled_stdev, expected_result):
  """`kpi_has_variability` is True iff the population-scaled stdev is nonzero."""
  # Only the transformer's stdev is read by the method under test, so a
  # spec'd mock with that single attribute set is sufficient.
  meridian = mock.Mock(spec=model.Meridian)
  meridian.kpi_transformer.population_scaled_stdev = population_scaled_stdev
  engine = eda_engine.EDAEngine(meridian)
  self.assertEqual(engine.kpi_has_variability(), expected_result)
4625+
4626+
@parameterized.named_parameters(
    dict(
        testcase_name="geo",
        is_national=False,
        expected_kpi_name="population_scaled_kpi",
    ),
    dict(
        testcase_name="national",
        is_national=True,
        expected_kpi_name="kpi",
    ),
)
def test_check_overall_kpi_invariance_no_variability(
    self, is_national, expected_kpi_name
):
  """A zero stdev yields one ERROR finding that names the right KPI."""
  meridian = mock.Mock(spec=model.Meridian)
  meridian.is_national = is_national
  meridian.kpi_transformer.population_scaled_stdev = 0.0
  engine = eda_engine.EDAEngine(meridian)

  # The artifact property is replaced so the check's output can be compared
  # against a known object without building real KPI data.
  mock_artifact = eda_outcome.KpiInvarianceArtifact(
      level=eda_outcome.AnalysisLevel.OVERALL,
      population_scaled_kpi_da=mock.Mock(),
      population_scaled_mean=0.0,
      population_scaled_stdev=0.0,
  )
  self._mock_eda_engine_property(
      "_population_scaled_kpi_artifact", mock_artifact
  )

  outcome = engine.check_overall_kpi_invariance()

  self.assertEqual(
      outcome.check_type, eda_outcome.EDACheckType.KPI_INVARIANCE
  )
  self.assertLen(outcome.findings, 1)
  self.assertEqual(
      outcome.findings[0].severity, eda_outcome.EDASeverity.ERROR
  )
  self.assertIn(
      f"`{expected_kpi_name}` is constant across all geos and times",
      outcome.findings[0].explanation,
  )
  self.assertLen(outcome.analysis_artifacts, 1)
  self.assertEqual(outcome.analysis_artifacts[0], mock_artifact)
4671+
4672+
def test_check_overall_kpi_invariance_has_variability(self):
  """A nonzero stdev yields one INFO finding and the KPI artifact."""
  meridian = mock.Mock(spec=model.Meridian)
  meridian.is_national = False
  meridian.kpi_transformer.population_scaled_stdev = 1.0
  engine = eda_engine.EDAEngine(meridian)

  # The artifact property is replaced so the check's output can be compared
  # against a known object without building real KPI data.
  mock_artifact = eda_outcome.KpiInvarianceArtifact(
      level=eda_outcome.AnalysisLevel.OVERALL,
      population_scaled_kpi_da=mock.Mock(),
      population_scaled_mean=1.0,
      population_scaled_stdev=1.0,
  )
  self._mock_eda_engine_property(
      "_population_scaled_kpi_artifact", mock_artifact
  )

  outcome = engine.check_overall_kpi_invariance()

  self.assertEqual(
      outcome.check_type, eda_outcome.EDACheckType.KPI_INVARIANCE
  )
  self.assertLen(outcome.findings, 1)
  self.assertEqual(outcome.findings[0].severity, eda_outcome.EDASeverity.INFO)
  self.assertIn(
      "The KPI has variability across geos and times",
      outcome.findings[0].explanation,
  )
  self.assertLen(outcome.analysis_artifacts, 1)
  self.assertEqual(outcome.analysis_artifacts[0], mock_artifact)
4701+
46084702

46094703
if __name__ == "__main__":
46104704
absltest.main()

meridian/model/eda/eda_outcome.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,30 @@ class VIFArtifact(AnalysisArtifact):
122122
outlier_df: pd.DataFrame
123123

124124

125+
@dataclasses.dataclass(frozen=True)
class KpiInvarianceArtifact(AnalysisArtifact):
  """Encapsulates artifacts from a KPI invariance analysis.

  Instances are immutable (`frozen=True`).

  Attributes:
    population_scaled_kpi_da: DataArray of the population-scaled KPI.
    population_scaled_mean: The mean of the population-scaled KPI.
    population_scaled_stdev: The standard deviation of the population-scaled
      KPI.
  """

  population_scaled_kpi_da: xr.DataArray
  population_scaled_mean: float
  population_scaled_stdev: float
139+
140+
125141
@enum.unique
class EDACheckType(enum.Enum):
  """Enumeration for the type of an EDA check."""

  # Values come from `enum.auto()`, i.e. they follow declaration order.
  # Append new members at the end so existing members keep their values.
  PAIRWISE_CORR = enum.auto()
  STD = enum.auto()
  VIF = enum.auto()
  KPI_INVARIANCE = enum.auto()
132149

133150

134151
ArtifactType = typing.TypeVar('ArtifactType', bound='AnalysisArtifact')

meridian/model/eda/meridian_eda.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@
1313
# limitations under the License.
1414

1515
"""Module containing Meridian related exploratory data analysis (EDA) functionalities."""
16+
from __future__ import annotations
17+
18+
from typing import TYPE_CHECKING
1619

1720
import altair as alt
18-
from meridian.model import model
21+
22+
if TYPE_CHECKING:
23+
from meridian.model import model
24+
1925

2026
__all__ = [
21-
'MeridianEDA',
27+
"MeridianEDA",
2228
]
2329

2430

meridian/model/model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@
3434
from meridian.model import prior_sampler
3535
from meridian.model import spec
3636
from meridian.model import transformers
37+
from meridian.model.eda import eda_engine
38+
from meridian.model.eda import eda_spec as eda_spec_module
3739
import numpy as np
3840

39-
4041
__all__ = [
4142
"MCMCSamplingError",
4243
"MCMCOOMError",
@@ -91,6 +92,8 @@ class Meridian:
9192
model_spec: A `ModelSpec` object containing the model specification.
9293
inference_data: A _mutable_ `arviz.InferenceData` object containing the
9394
resulting data from fitting the model.
95+
eda_engine: The `EDAEngine` for this model, constructed with `eda_spec`.
96+
eda_spec: An `EDASpec` object containing the EDA specification.
9497
n_geos: Number of geos in the data.
9598
n_media_channels: Number of media channels in the data.
9699
n_rf_channels: Number of reach and frequency (RF) channels in the data.
@@ -154,12 +157,14 @@ def __init__(
154157
inference_data: (
155158
az.InferenceData | None
156159
) = None, # for deserializer use only
160+
eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
157161
):
158162
self._input_data = input_data
159163
self._model_spec = model_spec if model_spec else spec.ModelSpec()
160164
self._inference_data = (
161165
inference_data if inference_data else az.InferenceData()
162166
)
167+
self._eda_spec = eda_spec
163168

164169
self._validate_data_dependent_model_spec()
165170
self._validate_injected_inference_data()
@@ -190,6 +195,14 @@ def model_spec(self) -> spec.ModelSpec:
190195
def inference_data(self) -> az.InferenceData:
191196
return self._inference_data
192197

198+
@functools.cached_property
def eda_engine(self) -> eda_engine.EDAEngine:
  """The `EDAEngine` for this model, configured with this model's EDA spec.

  Built on first access and cached on the instance thereafter.
  """
  # Inside the method body `eda_engine` resolves to the imported module, not
  # to this property: class-level names are not in scope within methods.
  return eda_engine.EDAEngine(self, spec=self._eda_spec)
201+
202+
@property
def eda_spec(self) -> eda_spec_module.EDASpec:
  """The `EDASpec` this model was constructed with."""
  return self._eda_spec
205+
193206
@functools.cached_property
194207
def media_tensors(self) -> media.MediaTensors:
195208
return media.build_media_tensors(self.input_data, self.model_spec)
@@ -1142,13 +1155,9 @@ def _check_if_no_time_variation(
11421155
" time."
11431156
)
11441157

1145-
def _kpi_has_variability(self):
1146-
"""Returns True if the KPI has variability across geos and times."""
1147-
return self.kpi_transformer.population_scaled_stdev != 0
1148-
11491158
def _validate_kpi_transformer(self):
11501159
"""Validates the KPI transformer."""
1151-
if self._kpi_has_variability():
1160+
if self.eda_engine.kpi_has_variability():
11521161
return
11531162

11541163
kpi = "kpi" if self.is_national else "population_scaled_kpi"

meridian/model/model_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from meridian.model import model_test_data
3333
from meridian.model import prior_distribution
3434
from meridian.model import spec
35+
from meridian.model.eda import eda_engine
36+
from meridian.model.eda import eda_spec as eda_spec_module
3537
import numpy as np
3638
import xarray as xr
3739

@@ -412,6 +414,34 @@ def test_init_with_default_parameters_works(self):
412414
# Compare model spec.
413415
self.assertEqual(repr(meridian.model_spec), repr(sample_spec))
414416

417+
@parameterized.named_parameters(
    dict(
        testcase_name="with_default_spec",
        eda_spec_kwargs={},
        expected_eda_spec=eda_spec_module.EDASpec(),
    ),
    dict(
        testcase_name="with_custom_spec",
        eda_spec_kwargs={
            "eda_spec": eda_spec_module.EDASpec(
                vif_spec=eda_spec_module.VIFSpec(geo_threshold=500.0)
            )
        },
        expected_eda_spec=eda_spec_module.EDASpec(
            vif_spec=eda_spec_module.VIFSpec(geo_threshold=500.0)
        ),
    ),
)
def test_eda_engine_and_spec_initialization(
    self, eda_spec_kwargs, expected_eda_spec
):
  """`Meridian` exposes an `EDAEngine` and the `EDASpec` it was given."""
  # An empty kwargs dict exercises the constructor's default `eda_spec`.
  meridian = model.Meridian(
      input_data=self.input_data_with_media_only, **eda_spec_kwargs
  )

  self.assertIsInstance(meridian.eda_engine, eda_engine.EDAEngine)
  self.assertEqual(meridian.eda_spec, expected_eda_spec)
444+
415445
def test_init_with_default_national_parameters_works(self):
416446
data = self.national_input_data_media_only
417447
meridian = model.Meridian(input_data=data)
@@ -586,6 +616,7 @@ def test_base_geo_properties(self):
586616
self.assertFalse(meridian.is_national)
587617
self.assertIsNotNone(meridian.prior_broadcast)
588618
self.assertIsNotNone(meridian.inference_data)
619+
self.assertIsNotNone(meridian.eda_engine)
589620
self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs)
590621
self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs)
591622

@@ -598,6 +629,7 @@ def test_base_national_properties(self):
598629
self.assertTrue(meridian.is_national)
599630
self.assertIsNotNone(meridian.prior_broadcast)
600631
self.assertIsNotNone(meridian.inference_data)
632+
self.assertIsNotNone(meridian.eda_engine)
601633
self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs)
602634
self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs)
603635

@@ -1653,6 +1685,7 @@ def test_base_geo_properties(self):
16531685
self.assertFalse(meridian.is_national)
16541686
self.assertIsNotNone(meridian.prior_broadcast)
16551687
self.assertIsNotNone(meridian.inference_data)
1688+
self.assertIsNotNone(meridian.eda_engine)
16561689
self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs)
16571690
self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs)
16581691

@@ -1667,6 +1700,7 @@ def test_base_national_properties(self):
16671700
self.assertTrue(meridian.is_national)
16681701
self.assertIsNotNone(meridian.prior_broadcast)
16691702
self.assertIsNotNone(meridian.inference_data)
1703+
self.assertIsNotNone(meridian.eda_engine)
16701704
self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs)
16711705
self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs)
16721706

0 commit comments

Comments
 (0)