Skip to content

Commit 3a17861

Browse files
riemanli authored and The Meridian Authors committed
Adds KPI invariance check to EDA engine.
PiperOrigin-RevId: 825871774
1 parent 7ada9c0 commit 3a17861

File tree

10 files changed

+269
-13
lines changed

10 files changed

+269
-13
lines changed

meridian/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
134134
CONTROLS_SCALED = 'controls_scaled'
135135
KPI_SCALED = f'{KPI}_scaled'
136+
POPULATION_SCALED_KPI = f'{POPULATION}_scaled_{KPI}'
136137
RF_IMPRESSIONS_SCALED = f'{RF_IMPRESSIONS}_scaled'
137138

138139
# Non-media treatments baseline value constants.

meridian/model/eda/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@
1515
"""The Meridian API module that performs EDA checks."""
1616

1717
from meridian.model.eda import eda_engine
18+
from meridian.model.eda import eda_outcome
19+
from meridian.model.eda import eda_spec
1820
from meridian.model.eda import meridian_eda

meridian/model/eda/eda_engine.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414

1515
"""Meridian EDA Engine."""
1616

17+
from __future__ import annotations
18+
1719
import dataclasses
1820
import functools
21+
import typing
1922
from typing import Optional, Sequence
23+
24+
from meridian import backend
2025
from meridian import constants
21-
from meridian.model import model
2226
from meridian.model import transformers
2327
from meridian.model.eda import eda_outcome
2428
from meridian.model.eda import eda_spec
@@ -30,6 +34,9 @@
3034
import xarray as xr
3135

3236

37+
if typing.TYPE_CHECKING:
38+
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
39+
3340
_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
3441
_CORRELATION_COL_NAME = 'correlation'
3542
_STACK_VAR_COORD_NAME = 'var'
@@ -134,7 +141,7 @@ class ReachFrequencyData:
134141

135142

136143
def _data_array_like(
137-
*, da: xr.DataArray, values: np.ndarray | tf.Tensor
144+
*, da: xr.DataArray, values: np.ndarray | backend.Tensor
138145
) -> xr.DataArray:
139146
"""Returns a DataArray from `values` with the same structure as `da`.
140147
@@ -727,6 +734,27 @@ def geo_population_da(self) -> xr.DataArray | None:
727734
name=constants.POPULATION,
728735
)
729736

737+
@functools.cached_property
def _population_scaled_kpi_artifact(
    self,
) -> eda_outcome.KpiInvarianceArtifact:
  """Builds the artifact attached to the KPI invariance check outcome.

  The population-scaled KPI values held by the model's KPI transformer are
  wrapped in a DataArray that has the same structure as the input KPI data and
  is named `constants.POPULATION_SCALED_KPI`. The transformer's
  population-scaled mean and standard deviation are stored as Python floats.

  Returns:
    A `KpiInvarianceArtifact` at the `OVERALL` analysis level.
  """
  transformer = self._meridian.kpi_transformer

  scaled_kpi_da = _data_array_like(
      da=self._meridian.input_data.kpi,
      values=transformer.population_scaled_kpi,
  )
  scaled_kpi_da.name = constants.POPULATION_SCALED_KPI

  # Converted with `float()` so the artifact carries plain Python numbers.
  scaled_mean = float(transformer.population_scaled_mean)
  scaled_stdev = float(transformer.population_scaled_stdev)

  return eda_outcome.KpiInvarianceArtifact(
      level=eda_outcome.AnalysisLevel.OVERALL,
      population_scaled_kpi_da=scaled_kpi_da,
      population_scaled_mean=scaled_mean,
      population_scaled_stdev=scaled_stdev,
  )
757+
730758
@functools.cached_property
731759
def kpi_scaled_da(self) -> xr.DataArray:
732760
scaled_kpi_da = _data_array_like(
@@ -1643,3 +1671,35 @@ def check_national_vif(
16431671
findings=findings,
16441672
analysis_artifacts=[national_vif_artifact],
16451673
)
1674+
1675+
def kpi_has_variability(self) -> bool:
  """Reports whether the population-scaled KPI varies at all.

  Returns:
    `False` when the KPI transformer's `population_scaled_stdev`, converted to
    a Python float, is exactly zero; `True` for any other value.
  """
  transformer = self._meridian.kpi_transformer
  return float(transformer.population_scaled_stdev) != 0
1679+
1680+
def check_overall_kpi_invariance(self) -> eda_outcome.EDAOutcome:
  """Runs the overall KPI invariance check.

  Returns:
    An `EDAOutcome` of type `KPI_INVARIANCE` holding exactly one finding and
    the population-scaled KPI artifact. The finding has `ERROR` severity when
    `kpi_has_variability()` is false and `INFO` severity otherwise.
  """
  if self.kpi_has_variability():
    severity = eda_outcome.EDASeverity.INFO
    explanation = (
        'The KPI has variability across geos and times, indicating'
        ' variability in the data.'
    )
  else:
    # The message names the raw KPI for national models and the
    # population-scaled KPI otherwise.
    kpi = 'kpi' if self._meridian.is_national else 'population_scaled_kpi'
    severity = eda_outcome.EDASeverity.ERROR
    explanation = (
        f'`{kpi}` is constant across all geos and times, indicating no'
        ' signal in the data. Please fix this data error.'
    )

  finding = eda_outcome.EDAFinding(severity=severity, explanation=explanation)
  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.KPI_INVARIANCE,
      findings=[finding],
      analysis_artifacts=[self._population_scaled_kpi_artifact],
  )

meridian/model/eda/eda_engine_test.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from unittest import mock
1616
from absl.testing import absltest
1717
from absl.testing import parameterized
18+
from meridian import backend
1819
from meridian import constants
1920
from meridian.model import model
2021
from meridian.model import model_test_data
@@ -4605,6 +4606,113 @@ def test_check_national_vif_has_correct_vif_value(self):
46054606
]
46064607
self.assertAllClose(national_artifact.vif_da.values, expected_national_vif)
46074608

4609+
@parameterized.named_parameters(
    ("has_variability", 1.0, True),
    ("no_variability", 0.0, False),
)
def test_kpi_has_variability(self, population_scaled_stdev, expected_result):
  """`kpi_has_variability` is True only for a non-zero scaled stdev."""
  mock_meridian = mock.Mock(spec=model.Meridian)
  mock_meridian.kpi_transformer.population_scaled_stdev = (
      population_scaled_stdev
  )

  has_variability = eda_engine.EDAEngine(mock_meridian).kpi_has_variability()

  self.assertEqual(has_variability, expected_result)
4626+
4627+
@parameterized.named_parameters(
    {
        "testcase_name": "geo",
        "is_national": False,
        "expected_kpi_name": "population_scaled_kpi",
    },
    {
        "testcase_name": "national",
        "is_national": True,
        "expected_kpi_name": "kpi",
    },
)
def test_check_overall_kpi_invariance_no_variability(
    self, is_national, expected_kpi_name
):
  """A zero scaled stdev yields an ERROR finding naming the constant KPI."""
  constant_kpi = backend.zeros(
      (self._N_GEOS, self._N_TIMES), dtype=backend.float32
  )
  kpi_transformer = mock.Mock()
  kpi_transformer.population_scaled_stdev = 0.0
  kpi_transformer.population_scaled_mean = 100.0
  kpi_transformer.population_scaled_kpi = constant_kpi

  mock_meridian = mock.Mock(spec=model.Meridian)
  mock_meridian.is_national = is_national
  mock_meridian.input_data.kpi = self.input_data_with_media_only.kpi
  mock_meridian.kpi_transformer = kpi_transformer

  outcome = eda_engine.EDAEngine(mock_meridian).check_overall_kpi_invariance()

  self.assertEqual(
      outcome.check_type, eda_outcome.EDACheckType.KPI_INVARIANCE
  )

  # Exactly one finding, flagged as an error and naming the right KPI.
  self.assertLen(outcome.findings, 1)
  finding = outcome.findings[0]
  self.assertEqual(finding.severity, eda_outcome.EDASeverity.ERROR)
  self.assertIn(
      f"`{expected_kpi_name}` is constant across all geos and times",
      finding.explanation,
  )

  # Exactly one artifact, mirroring the transformer's statistics and values.
  self.assertLen(outcome.analysis_artifacts, 1)
  artifact = outcome.analysis_artifacts[0]
  self.assertIsInstance(artifact, eda_outcome.KpiInvarianceArtifact)
  self.assertEqual(artifact.level, eda_outcome.AnalysisLevel.OVERALL)
  self.assertEqual(artifact.population_scaled_stdev, 0.0)
  self.assertEqual(artifact.population_scaled_mean, 100.0)
  self.assertAllClose(
      artifact.population_scaled_kpi_da.values,
      backend.to_tensor(constant_kpi),
  )
4677+
4678+
def test_check_overall_kpi_invariance_has_variability(self):
  """A non-zero scaled stdev yields an INFO finding plus the artifact."""
  varying_kpi = backend.to_tensor(
      np.arange(self._N_GEOS * self._N_TIMES).reshape(
          self._N_GEOS, self._N_TIMES
      ),
      dtype=backend.float32,
  )
  kpi_transformer = mock.Mock()
  kpi_transformer.population_scaled_stdev = 1.0
  kpi_transformer.population_scaled_mean = 100.0
  kpi_transformer.population_scaled_kpi = varying_kpi

  mock_meridian = mock.Mock(spec=model.Meridian)
  mock_meridian.is_national = False
  mock_meridian.input_data.kpi = self.input_data_with_media_only.kpi
  mock_meridian.kpi_transformer = kpi_transformer

  outcome = eda_engine.EDAEngine(mock_meridian).check_overall_kpi_invariance()

  self.assertEqual(
      outcome.check_type, eda_outcome.EDACheckType.KPI_INVARIANCE
  )

  # Exactly one finding, informational only.
  self.assertLen(outcome.findings, 1)
  finding = outcome.findings[0]
  self.assertEqual(finding.severity, eda_outcome.EDASeverity.INFO)
  self.assertIn(
      "The KPI has variability across geos and times",
      finding.explanation,
  )

  # Exactly one artifact, mirroring the transformer's statistics and values.
  self.assertLen(outcome.analysis_artifacts, 1)
  artifact = outcome.analysis_artifacts[0]
  self.assertIsInstance(artifact, eda_outcome.KpiInvarianceArtifact)
  self.assertEqual(artifact.level, eda_outcome.AnalysisLevel.OVERALL)
  self.assertEqual(artifact.population_scaled_stdev, 1.0)
  self.assertEqual(artifact.population_scaled_mean, 100.0)
  self.assertAllClose(
      artifact.population_scaled_kpi_da.values,
      backend.to_tensor(varying_kpi),
  )
4715+
46084716

46094717
if __name__ == "__main__":
46104718
absltest.main()

meridian/model/eda/eda_outcome.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,30 @@ class VIFArtifact(AnalysisArtifact):
122122
outlier_df: pd.DataFrame
123123

124124

125+
@dataclasses.dataclass(frozen=True)
class KpiInvarianceArtifact(AnalysisArtifact):
  """Immutable result data produced by the KPI invariance check.

  Attributes:
    population_scaled_kpi_da: The population-scaled KPI as a DataArray.
    population_scaled_mean: Mean of the population-scaled KPI.
    population_scaled_stdev: Standard deviation of the population-scaled KPI.
  """

  # Field order is part of the generated `__init__` signature; keep it stable.
  population_scaled_kpi_da: xr.DataArray
  population_scaled_mean: float
  population_scaled_stdev: float
139+
140+
125141
@enum.unique
class EDACheckType(enum.Enum):
  """The kinds of EDA check an outcome can be tagged with."""

  # Values are assigned by `enum.auto()` and therefore follow declaration
  # order; `enum.unique` rejects accidental aliases.
  PAIRWISE_CORR = enum.auto()
  STD = enum.auto()
  VIF = enum.auto()
  KPI_INVARIANCE = enum.auto()
132149

133150

134151
ArtifactType = typing.TypeVar('ArtifactType', bound='AnalysisArtifact')

meridian/model/eda/meridian_eda.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@
1313
# limitations under the License.
1414

1515
"""Module containing Meridian related exploratory data analysis (EDA) functionalities."""
16+
from __future__ import annotations
17+
18+
from typing import TYPE_CHECKING
1619

1720
import altair as alt
18-
from meridian.model import model
21+
22+
if TYPE_CHECKING:
23+
from meridian.model import model
24+
1925

2026
__all__ = [
21-
'MeridianEDA',
27+
"MeridianEDA",
2228
]
2329

2430

meridian/model/model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@
3434
from meridian.model import prior_sampler
3535
from meridian.model import spec
3636
from meridian.model import transformers
37+
from meridian.model.eda import eda_engine
38+
from meridian.model.eda import eda_spec as eda_spec_module
3739
import numpy as np
3840

39-
4041
__all__ = [
4142
"MCMCSamplingError",
4243
"MCMCOOMError",
@@ -91,6 +92,8 @@ class Meridian:
9192
model_spec: A `ModelSpec` object containing the model specification.
9293
inference_data: A _mutable_ `arviz.InferenceData` object containing the
9394
resulting data from fitting the model.
95+
eda_engine: An `EDAEngine` object containing the EDA engine.
96+
eda_spec: An `EDASpec` object containing the EDA specification.
9497
n_geos: Number of geos in the data.
9598
n_media_channels: Number of media channels in the data.
9699
n_rf_channels: Number of reach and frequency (RF) channels in the data.
@@ -154,12 +157,14 @@ def __init__(
154157
inference_data: (
155158
az.InferenceData | None
156159
) = None, # for deserializer use only
160+
eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
157161
):
158162
self._input_data = input_data
159163
self._model_spec = model_spec if model_spec else spec.ModelSpec()
160164
self._inference_data = (
161165
inference_data if inference_data else az.InferenceData()
162166
)
167+
self._eda_spec = eda_spec
163168

164169
self._validate_data_dependent_model_spec()
165170
self._validate_injected_inference_data()
@@ -190,6 +195,14 @@ def model_spec(self) -> spec.ModelSpec:
190195
def inference_data(self) -> az.InferenceData:
191196
return self._inference_data
192197

198+
@functools.cached_property
def eda_engine(self) -> eda_engine.EDAEngine:
  """The `EDAEngine` for this model, built on first access and then cached.

  The engine is constructed from this model instance and the EDA spec stored
  at initialization.
  """
  # Inside the method body `eda_engine` resolves to the imported module, not
  # to this property.
  engine = eda_engine.EDAEngine(self, spec=self._eda_spec)
  return engine
201+
202+
@property
def eda_spec(self) -> eda_spec_module.EDASpec:
  """The `EDASpec` stored on this model at initialization."""
  spec_in_use = self._eda_spec
  return spec_in_use
205+
193206
@functools.cached_property
def media_tensors(self) -> media.MediaTensors:
  """Media tensors built from the input data and model spec; cached."""
  tensors = media.build_media_tensors(self.input_data, self.model_spec)
  return tensors
@@ -1142,13 +1155,9 @@ def _check_if_no_time_variation(
11421155
" time."
11431156
)
11441157

1145-
def _kpi_has_variability(self):
1146-
"""Returns True if the KPI has variability across geos and times."""
1147-
return self.kpi_transformer.population_scaled_stdev != 0
1148-
11491158
def _validate_kpi_transformer(self):
11501159
"""Validates the KPI transformer."""
1151-
if self._kpi_has_variability():
1160+
if self.eda_engine.kpi_has_variability():
11521161
return
11531162

11541163
kpi = "kpi" if self.is_national else "population_scaled_kpi"

0 commit comments

Comments
 (0)