@@ -815,3 +815,286 @@ def run(self) -> results.PriorPosteriorShiftCheckResult:
815815 media_results = self ._run_for_channel_type (constants .MEDIA_CHANNEL )
816816 rf_results = self ._run_for_channel_type (constants .RF_CHANNEL )
817817 return self ._aggregate (media_results , rf_results )
818+
819+
820+ # ==============================================================================
821+ # Check: Unrealistic ROI
822+ # ==============================================================================
823+ class UnrealisticROICheck (
824+ BaseCheck [configs .UnrealisticROIConfig , results .UnrealisticROICheckResult ]
825+ ):
826+ """Checks for paid channels with unrealistic posterior ROI estimates."""
827+
828+ def run (self ) -> results .UnrealisticROICheckResult :
829+ # 1. Get spend and calculate spend share
830+ spend = self ._model_context .input_data .get_total_spend ()
831+ if spend .ndim == 3 :
832+ spend = np .sum (spend , axis = (0 , 1 ))
833+
834+ total_spend_sum = np .sum (spend )
835+ if total_spend_sum > 0 :
836+ spend_share = spend / total_spend_sum
837+ else :
838+ spend_share = np .zeros_like (spend )
839+
840+ # 2. Get posterior ROI and channels
841+ posterior_rois = []
842+ channels = []
843+ if constants .MEDIA_CHANNEL in self ._inference_data .posterior .coords :
844+ posterior_rois .append (self ._inference_data .posterior .roi_m .values )
845+ channels .extend (
846+ self ._inference_data .posterior .media_channel .values .tolist ()
847+ )
848+ if constants .RF_CHANNEL in self ._inference_data .posterior .coords :
849+ posterior_rois .append (self ._inference_data .posterior .roi_rf .values )
850+ channels .extend (
851+ self ._inference_data .posterior .rf_channel .values .tolist ()
852+ )
853+
854+ if not posterior_rois :
855+ raise ValueError ("No posterior ROI data found in inference_data." )
856+
857+ posterior_roi_concat = np .concatenate (posterior_rois , axis = - 1 )
858+ roi_medians = np .median (posterior_roi_concat , axis = (0 , 1 ))
859+
860+ # 3. Evaluate checks
861+ channel_results = []
862+ high_roi_channels = []
863+ low_roi_channels = []
864+
865+ for i , channel in enumerate (channels ):
866+ med = roi_medians [i ]
867+ share = spend_share [i ]
868+
869+ spend_weighted_roi = med * share
870+ if share > 0 :
871+ reciprocal_spend_weighted_roi = med / share
872+ else :
873+ reciprocal_spend_weighted_roi = np .nan
874+
875+ if spend_weighted_roi > self ._config .roi_upper_bound :
876+ case = results .UnrealisticROIChannelCases .ROI_HIGH
877+ high_roi_channels .append (channel )
878+ elif (
879+ share > 0
880+ and reciprocal_spend_weighted_roi < self ._config .roi_lower_bound
881+ ):
882+ case = results .UnrealisticROIChannelCases .ROI_LOW
883+ low_roi_channels .append (channel )
884+ else :
885+ case = results .UnrealisticROIChannelCases .ROI_PASS
886+
887+ channel_results .append (
888+ results .UnrealisticROIChannelResult (
889+ case = case ,
890+ channel_name = channel ,
891+ spend_share = share ,
892+ roi_median = med ,
893+ spend_weighted_roi = spend_weighted_roi ,
894+ )
895+ )
896+
897+ if high_roi_channels or low_roi_channels :
898+ agg_case = results .UnrealisticROIAggregateCases .REVIEW
899+ else :
900+ agg_case = results .UnrealisticROIAggregateCases .PASS
901+
902+ return results .UnrealisticROICheckResult (
903+ case = agg_case ,
904+ channel_results = channel_results ,
905+ high_roi_channels = high_roi_channels ,
906+ low_roi_channels = low_roi_channels ,
907+ )
908+
909+
910+ # ==============================================================================
911+ # Check: High Variance ROI
912+ # ==============================================================================
913+ class HighVarianceCheck (
914+ BaseCheck [configs .HighVarianceConfig , results .HighVarianceCheckResult ]
915+ ):
916+ """Checks for paid channels with high uncertainty (variance) in posterior ROI."""
917+
918+ def run (self ) -> results .HighVarianceCheckResult :
919+ # 1. Get spend and calculate spend share
920+ spend = self ._model_context .input_data .get_total_spend ()
921+ if spend .ndim == 3 :
922+ spend = np .sum (spend , axis = (0 , 1 ))
923+
924+ total_spend_sum = np .sum (spend )
925+ if total_spend_sum > 0 :
926+ spend_share = spend / total_spend_sum
927+ else :
928+ spend_share = np .zeros_like (spend )
929+
930+ # 2. Get posterior ROI and channels
931+ posterior_rois = []
932+ channels = []
933+ if constants .MEDIA_CHANNEL in self ._inference_data .posterior .coords :
934+ posterior_rois .append (self ._inference_data .posterior .roi_m .values )
935+ channels .extend (
936+ self ._inference_data .posterior .media_channel .values .tolist ()
937+ )
938+ if constants .RF_CHANNEL in self ._inference_data .posterior .coords :
939+ posterior_rois .append (self ._inference_data .posterior .roi_rf .values )
940+ channels .extend (
941+ self ._inference_data .posterior .rf_channel .values .tolist ()
942+ )
943+
944+ if not posterior_rois :
945+ raise ValueError ("No posterior ROI data found in inference_data." )
946+
947+ posterior_roi_concat = np .concatenate (posterior_rois , axis = - 1 )
948+ roi_medians = np .median (posterior_roi_concat , axis = (0 , 1 ))
949+
950+ # 3. Compute credible intervals using az.hdi
951+ hdi = az .hdi (posterior_roi_concat , hdi_prob = 0.8 )
952+ hdi_lower = hdi [:, 0 ]
953+ hdi_upper = hdi [:, 1 ]
954+
955+ rel_width_post = np .divide (
956+ hdi_upper - hdi_lower ,
957+ np .abs (roi_medians ),
958+ out = np .zeros_like (roi_medians , dtype = float ),
959+ where = (roi_medians != 0 ),
960+ )
961+
962+ # 4. Compute high variance check
963+ # Rel_Width_prior ≈ 2.07377
964+ relative_width_ratio = rel_width_post / 2.07377
965+ spend_weighted_ratio = relative_width_ratio * spend_share
966+
967+ channel_results = []
968+ high_variance_channels = []
969+
970+ for i , channel in enumerate (channels ):
971+ share = spend_share [i ]
972+ ratio = relative_width_ratio [i ]
973+ weighted_ratio = spend_weighted_ratio [i ]
974+
975+ if weighted_ratio > self ._config .high_variance_threshold :
976+ case = results .HighVarianceChannelCases .HIGH_VARIANCE
977+ high_variance_channels .append (channel )
978+ else :
979+ case = results .HighVarianceChannelCases .ROI_PASS
980+
981+ channel_results .append (
982+ results .HighVarianceChannelResult (
983+ case = case ,
984+ channel_name = channel ,
985+ spend_share = share ,
986+ relative_width_ratio = ratio ,
987+ )
988+ )
989+
990+ if high_variance_channels :
991+ agg_case = results .HighVarianceAggregateCases .REVIEW
992+ else :
993+ agg_case = results .HighVarianceAggregateCases .PASS
994+
995+ return results .HighVarianceCheckResult (
996+ case = agg_case ,
997+ channel_results = channel_results ,
998+ high_variance_channels = high_variance_channels ,
999+ )
1000+
1001+
1002+ # Alias for HighVarianceCheck to match some designs
1003+ LargeVarianceCheck = HighVarianceCheck
1004+
1005+
1006+ # ==============================================================================
1007+ # Check: Control Bias
1008+ # ==============================================================================
1009+ class ControlBiasCheck (
1010+ BaseCheck [configs .ControlBiasConfig , results .ControlBiasCheckResult ]
1011+ ):
1012+ """Checks correlation between paid channels and control variables to flag potential confounding."""
1013+
1014+ def run (self ) -> results .ControlBiasCheckResult :
1015+ # 1. Get channels names
1016+ channels = self ._model_context .input_data .get_all_paid_channels ().tolist ()
1017+
1018+ # 2. If no controls are included in the model, correlation is 0 for all
1019+ controls = self ._model_context .input_data .controls
1020+ if controls is None or self ._model_context .n_controls == 0 :
1021+ channel_results = [
1022+ results .ControlBiasChannelResult (
1023+ case = results .ControlBiasChannelCases .LOW_CORRELATION ,
1024+ channel_name = channel ,
1025+ max_correlation = 0.0 ,
1026+ )
1027+ for channel in channels
1028+ ]
1029+ return results .ControlBiasCheckResult (
1030+ case = results .ControlBiasAggregateCases .NO_CONTROLS ,
1031+ channel_results = channel_results ,
1032+ low_correlation_channels = channels ,
1033+ )
1034+
1035+ # 3. Retrieve channel and control data
1036+ media_data = self ._model_context .input_data .get_all_media_and_rf ()
1037+ n_times = self ._model_context .n_times
1038+ media_aligned = media_data [:, - n_times :, :]
1039+
1040+ controls_data = controls .values
1041+
1042+ # 4. Vectorized Pearson correlation over time (axis 1) per geo
1043+ media_centered = media_aligned - np .mean (
1044+ media_aligned , axis = 1 , keepdims = True
1045+ )
1046+ controls_centered = controls_data - np .mean (
1047+ controls_data , axis = 1 , keepdims = True
1048+ )
1049+
1050+ numerator = np .einsum ("gtc,gtz->gcz" , media_centered , controls_centered )
1051+
1052+ sum_sq_media = np .sum (media_centered ** 2 , axis = 1 , keepdims = True )
1053+ sum_sq_controls = np .sum (controls_centered ** 2 , axis = 1 , keepdims = True )
1054+
1055+ denom_media = np .transpose (sum_sq_media , (0 , 2 , 1 ))
1056+ denom_controls = np .transpose (sum_sq_controls , (0 , 1 , 2 ))
1057+
1058+ denominator = np .sqrt (denom_media * denom_controls )
1059+
1060+ correlation = np .divide (
1061+ numerator ,
1062+ denominator ,
1063+ out = np .zeros_like (numerator , dtype = float ),
1064+ where = (denominator != 0 ),
1065+ )
1066+
1067+ # 5. Max absolute correlation across all geos and controls
1068+ abs_correlation = np .abs (correlation )
1069+ max_correlations = np .max (abs_correlation , axis = (0 , 2 ))
1070+
1071+ # 6. Evaluate checks
1072+ channel_results = []
1073+ low_correlation_channels = []
1074+
1075+ for i , channel in enumerate (channels ):
1076+ max_corr = max_correlations [i ]
1077+ if max_corr < self ._config .correlation_threshold :
1078+ case = results .ControlBiasChannelCases .LOW_CORRELATION
1079+ low_correlation_channels .append (channel )
1080+ else :
1081+ case = results .ControlBiasChannelCases .ROI_PASS
1082+
1083+ channel_results .append (
1084+ results .ControlBiasChannelResult (
1085+ case = case ,
1086+ channel_name = channel ,
1087+ max_correlation = max_corr ,
1088+ )
1089+ )
1090+
1091+ if low_correlation_channels :
1092+ agg_case = results .ControlBiasAggregateCases .REVIEW
1093+ else :
1094+ agg_case = results .ControlBiasAggregateCases .PASS
1095+
1096+ return results .ControlBiasCheckResult (
1097+ case = agg_case ,
1098+ channel_results = channel_results ,
1099+ low_correlation_channels = low_correlation_channels ,
1100+ )
0 commit comments