Skip to content

Commit 3e94227

Browse files
lukmazThe Meridian Authors
authored andcommitted
Implement channel recommendation checks.
PiperOrigin-RevId: 918015128
1 parent a60be99 commit 3e94227

10 files changed

Lines changed: 1060 additions & 4 deletions

File tree

meridian/analysis/review/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Meridian Model Quality Module."""
1616

17-
from meridian.analysis.review import checks
1817
from meridian.analysis.review import configs
1918
from meridian.analysis.review import results
19+
from meridian.analysis.review import checks
2020
from meridian.analysis.review import reviewer

meridian/analysis/review/checks.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,3 +815,286 @@ def run(self) -> results.PriorPosteriorShiftCheckResult:
815815
media_results = self._run_for_channel_type(constants.MEDIA_CHANNEL)
816816
rf_results = self._run_for_channel_type(constants.RF_CHANNEL)
817817
return self._aggregate(media_results, rf_results)
818+
819+
820+
# ==============================================================================
821+
# Check: Unrealistic ROI
822+
# ==============================================================================
823+
class UnrealisticROICheck(
824+
BaseCheck[configs.UnrealisticROIConfig, results.UnrealisticROICheckResult]
825+
):
826+
"""Checks for paid channels with unrealistic posterior ROI estimates."""
827+
828+
def run(self) -> results.UnrealisticROICheckResult:
829+
# 1. Get spend and calculate spend share
830+
spend = self._model_context.input_data.get_total_spend()
831+
if spend.ndim == 3:
832+
spend = np.sum(spend, axis=(0, 1))
833+
834+
total_spend_sum = np.sum(spend)
835+
if total_spend_sum > 0:
836+
spend_share = spend / total_spend_sum
837+
else:
838+
spend_share = np.zeros_like(spend)
839+
840+
# 2. Get posterior ROI and channels
841+
posterior_rois = []
842+
channels = []
843+
if constants.MEDIA_CHANNEL in self._inference_data.posterior.coords:
844+
posterior_rois.append(self._inference_data.posterior.roi_m.values)
845+
channels.extend(
846+
self._inference_data.posterior.media_channel.values.tolist()
847+
)
848+
if constants.RF_CHANNEL in self._inference_data.posterior.coords:
849+
posterior_rois.append(self._inference_data.posterior.roi_rf.values)
850+
channels.extend(
851+
self._inference_data.posterior.rf_channel.values.tolist()
852+
)
853+
854+
if not posterior_rois:
855+
raise ValueError("No posterior ROI data found in inference_data.")
856+
857+
posterior_roi_concat = np.concatenate(posterior_rois, axis=-1)
858+
roi_medians = np.median(posterior_roi_concat, axis=(0, 1))
859+
860+
# 3. Evaluate checks
861+
channel_results = []
862+
high_roi_channels = []
863+
low_roi_channels = []
864+
865+
for i, channel in enumerate(channels):
866+
med = roi_medians[i]
867+
share = spend_share[i]
868+
869+
spend_weighted_roi = med * share
870+
if share > 0:
871+
reciprocal_spend_weighted_roi = med / share
872+
else:
873+
reciprocal_spend_weighted_roi = np.nan
874+
875+
if spend_weighted_roi > self._config.roi_upper_bound:
876+
case = results.UnrealisticROIChannelCases.ROI_HIGH
877+
high_roi_channels.append(channel)
878+
elif (
879+
share > 0
880+
and reciprocal_spend_weighted_roi < self._config.roi_lower_bound
881+
):
882+
case = results.UnrealisticROIChannelCases.ROI_LOW
883+
low_roi_channels.append(channel)
884+
else:
885+
case = results.UnrealisticROIChannelCases.ROI_PASS
886+
887+
channel_results.append(
888+
results.UnrealisticROIChannelResult(
889+
case=case,
890+
channel_name=channel,
891+
spend_share=share,
892+
roi_median=med,
893+
spend_weighted_roi=spend_weighted_roi,
894+
)
895+
)
896+
897+
if high_roi_channels or low_roi_channels:
898+
agg_case = results.UnrealisticROIAggregateCases.REVIEW
899+
else:
900+
agg_case = results.UnrealisticROIAggregateCases.PASS
901+
902+
return results.UnrealisticROICheckResult(
903+
case=agg_case,
904+
channel_results=channel_results,
905+
high_roi_channels=high_roi_channels,
906+
low_roi_channels=low_roi_channels,
907+
)
908+
909+
910+
# ==============================================================================
911+
# Check: High Variance ROI
912+
# ==============================================================================
913+
class HighVarianceCheck(
914+
BaseCheck[configs.HighVarianceConfig, results.HighVarianceCheckResult]
915+
):
916+
"""Checks for paid channels with high uncertainty (variance) in posterior ROI."""
917+
918+
def run(self) -> results.HighVarianceCheckResult:
919+
# 1. Get spend and calculate spend share
920+
spend = self._model_context.input_data.get_total_spend()
921+
if spend.ndim == 3:
922+
spend = np.sum(spend, axis=(0, 1))
923+
924+
total_spend_sum = np.sum(spend)
925+
if total_spend_sum > 0:
926+
spend_share = spend / total_spend_sum
927+
else:
928+
spend_share = np.zeros_like(spend)
929+
930+
# 2. Get posterior ROI and channels
931+
posterior_rois = []
932+
channels = []
933+
if constants.MEDIA_CHANNEL in self._inference_data.posterior.coords:
934+
posterior_rois.append(self._inference_data.posterior.roi_m.values)
935+
channels.extend(
936+
self._inference_data.posterior.media_channel.values.tolist()
937+
)
938+
if constants.RF_CHANNEL in self._inference_data.posterior.coords:
939+
posterior_rois.append(self._inference_data.posterior.roi_rf.values)
940+
channels.extend(
941+
self._inference_data.posterior.rf_channel.values.tolist()
942+
)
943+
944+
if not posterior_rois:
945+
raise ValueError("No posterior ROI data found in inference_data.")
946+
947+
posterior_roi_concat = np.concatenate(posterior_rois, axis=-1)
948+
roi_medians = np.median(posterior_roi_concat, axis=(0, 1))
949+
950+
# 3. Compute credible intervals using az.hdi
951+
hdi = az.hdi(posterior_roi_concat, hdi_prob=0.8)
952+
hdi_lower = hdi[:, 0]
953+
hdi_upper = hdi[:, 1]
954+
955+
rel_width_post = np.divide(
956+
hdi_upper - hdi_lower,
957+
np.abs(roi_medians),
958+
out=np.zeros_like(roi_medians, dtype=float),
959+
where=(roi_medians != 0),
960+
)
961+
962+
# 4. Compute high variance check
963+
# Rel_Width_prior ≈ 2.07377
964+
relative_width_ratio = rel_width_post / 2.07377
965+
spend_weighted_ratio = relative_width_ratio * spend_share
966+
967+
channel_results = []
968+
high_variance_channels = []
969+
970+
for i, channel in enumerate(channels):
971+
share = spend_share[i]
972+
ratio = relative_width_ratio[i]
973+
weighted_ratio = spend_weighted_ratio[i]
974+
975+
if weighted_ratio > self._config.high_variance_threshold:
976+
case = results.HighVarianceChannelCases.HIGH_VARIANCE
977+
high_variance_channels.append(channel)
978+
else:
979+
case = results.HighVarianceChannelCases.ROI_PASS
980+
981+
channel_results.append(
982+
results.HighVarianceChannelResult(
983+
case=case,
984+
channel_name=channel,
985+
spend_share=share,
986+
relative_width_ratio=ratio,
987+
)
988+
)
989+
990+
if high_variance_channels:
991+
agg_case = results.HighVarianceAggregateCases.REVIEW
992+
else:
993+
agg_case = results.HighVarianceAggregateCases.PASS
994+
995+
return results.HighVarianceCheckResult(
996+
case=agg_case,
997+
channel_results=channel_results,
998+
high_variance_channels=high_variance_channels,
999+
)
1000+
1001+
1002+
# Alias for HighVarianceCheck to match some designs
1003+
LargeVarianceCheck = HighVarianceCheck
1004+
1005+
1006+
# ==============================================================================
1007+
# Check: Control Bias
1008+
# ==============================================================================
1009+
class ControlBiasCheck(
1010+
BaseCheck[configs.ControlBiasConfig, results.ControlBiasCheckResult]
1011+
):
1012+
"""Checks correlation between paid channels and control variables to flag potential confounding."""
1013+
1014+
def run(self) -> results.ControlBiasCheckResult:
1015+
# 1. Get channels names
1016+
channels = self._model_context.input_data.get_all_paid_channels().tolist()
1017+
1018+
# 2. If no controls are included in the model, correlation is 0 for all
1019+
controls = self._model_context.input_data.controls
1020+
if controls is None or self._model_context.n_controls == 0:
1021+
channel_results = [
1022+
results.ControlBiasChannelResult(
1023+
case=results.ControlBiasChannelCases.LOW_CORRELATION,
1024+
channel_name=channel,
1025+
max_correlation=0.0,
1026+
)
1027+
for channel in channels
1028+
]
1029+
return results.ControlBiasCheckResult(
1030+
case=results.ControlBiasAggregateCases.NO_CONTROLS,
1031+
channel_results=channel_results,
1032+
low_correlation_channels=channels,
1033+
)
1034+
1035+
# 3. Retrieve channel and control data
1036+
media_data = self._model_context.input_data.get_all_media_and_rf()
1037+
n_times = self._model_context.n_times
1038+
media_aligned = media_data[:, -n_times:, :]
1039+
1040+
controls_data = controls.values
1041+
1042+
# 4. Vectorized Pearson correlation over time (axis 1) per geo
1043+
media_centered = media_aligned - np.mean(
1044+
media_aligned, axis=1, keepdims=True
1045+
)
1046+
controls_centered = controls_data - np.mean(
1047+
controls_data, axis=1, keepdims=True
1048+
)
1049+
1050+
numerator = np.einsum("gtc,gtz->gcz", media_centered, controls_centered)
1051+
1052+
sum_sq_media = np.sum(media_centered**2, axis=1, keepdims=True)
1053+
sum_sq_controls = np.sum(controls_centered**2, axis=1, keepdims=True)
1054+
1055+
denom_media = np.transpose(sum_sq_media, (0, 2, 1))
1056+
denom_controls = np.transpose(sum_sq_controls, (0, 1, 2))
1057+
1058+
denominator = np.sqrt(denom_media * denom_controls)
1059+
1060+
correlation = np.divide(
1061+
numerator,
1062+
denominator,
1063+
out=np.zeros_like(numerator, dtype=float),
1064+
where=(denominator != 0),
1065+
)
1066+
1067+
# 5. Max absolute correlation across all geos and controls
1068+
abs_correlation = np.abs(correlation)
1069+
max_correlations = np.max(abs_correlation, axis=(0, 2))
1070+
1071+
# 6. Evaluate checks
1072+
channel_results = []
1073+
low_correlation_channels = []
1074+
1075+
for i, channel in enumerate(channels):
1076+
max_corr = max_correlations[i]
1077+
if max_corr < self._config.correlation_threshold:
1078+
case = results.ControlBiasChannelCases.LOW_CORRELATION
1079+
low_correlation_channels.append(channel)
1080+
else:
1081+
case = results.ControlBiasChannelCases.ROI_PASS
1082+
1083+
channel_results.append(
1084+
results.ControlBiasChannelResult(
1085+
case=case,
1086+
channel_name=channel,
1087+
max_correlation=max_corr,
1088+
)
1089+
)
1090+
1091+
if low_correlation_channels:
1092+
agg_case = results.ControlBiasAggregateCases.REVIEW
1093+
else:
1094+
agg_case = results.ControlBiasAggregateCases.PASS
1095+
1096+
return results.ControlBiasCheckResult(
1097+
case=agg_case,
1098+
channel_results=channel_results,
1099+
low_correlation_channels=low_correlation_channels,
1100+
)

0 commit comments

Comments
 (0)