Skip to content

Commit 7ada9c0

Browse files
riemanliThe Meridian Authors
authored andcommitted
Refactor EDAEngine to use allocated spend properties.
PiperOrigin-RevId: 825830523
1 parent 2ecfd53 commit 7ada9c0

File tree

2 files changed

+225
-103
lines changed

2 files changed

+225
-103
lines changed

meridian/model/eda/eda_engine.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -384,32 +384,30 @@ def media_scaled_da(self) -> xr.DataArray | None:
384384

385385
@functools.cached_property
386386
def media_spend_da(self) -> xr.DataArray | None:
387-
if self._meridian.input_data.media_spend is None:
388-
return None
389-
media_spend_da = _data_array_like(
390-
da=self._meridian.input_data.media_spend,
391-
values=self._meridian.media_tensors.media_spend,
392-
)
393-
media_spend_da.name = constants.MEDIA_SPEND
387+
"""Returns media spend.
388+
389+
If the input spend is aggregated, it is allocated across geo and time
390+
proportionally to media units.
391+
"""
394392
# No need to truncate the media time for media spend.
395-
return media_spend_da
393+
da = self._meridian.input_data.allocated_media_spend
394+
if da is None:
395+
return None
396+
da = da.copy()
397+
da.name = constants.MEDIA_SPEND
398+
return da
396399

397400
@functools.cached_property
398401
def national_media_spend_da(self) -> xr.DataArray | None:
399402
"""Returns the national media spend data array."""
400-
if self._meridian.input_data.media_spend is None:
403+
if self.media_spend_da is None:
401404
return None
402405
if self._meridian.is_national:
403-
if self.media_spend_da is None:
404-
# This case should be impossible given the check above.
405-
raise RuntimeError(
406-
'media_spend_da is None when media_spend is not None.'
407-
)
408406
national_da = self.media_spend_da.squeeze(constants.GEO, drop=True)
409407
national_da.name = constants.NATIONAL_MEDIA_SPEND
410408
else:
411409
national_da = self._aggregate_and_scale_geo_da(
412-
self._meridian.input_data.media_spend,
410+
self._meridian.input_data.allocated_media_spend,
413411
constants.NATIONAL_MEDIA_SPEND,
414412
None,
415413
)
@@ -540,29 +538,31 @@ def national_non_media_scaled_da(self) -> xr.DataArray | None:
540538

541539
@functools.cached_property
542540
def rf_spend_da(self) -> xr.DataArray | None:
543-
if self._meridian.input_data.rf_spend is None:
541+
"""Returns RF spend.
542+
543+
If the input spend is aggregated, it is allocated across geo and time
544+
proportionally to RF impressions (reach * frequency).
545+
"""
546+
da = self._meridian.input_data.allocated_rf_spend
547+
if da is None:
544548
return None
545-
rf_spend_da = _data_array_like(
546-
da=self._meridian.input_data.rf_spend,
547-
values=self._meridian.rf_tensors.rf_spend,
548-
)
549-
rf_spend_da.name = constants.RF_SPEND
550-
return rf_spend_da
549+
da = da.copy()
550+
da.name = constants.RF_SPEND
551+
return da
551552

552553
@functools.cached_property
553554
def national_rf_spend_da(self) -> xr.DataArray | None:
554555
"""Returns the national RF spend data array."""
555-
if self._meridian.input_data.rf_spend is None:
556+
if self.rf_spend_da is None:
556557
return None
557558
if self._meridian.is_national:
558-
if self.rf_spend_da is None:
559-
# This case should be impossible given the check above.
560-
raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
561559
national_da = self.rf_spend_da.squeeze(constants.GEO, drop=True)
562560
national_da.name = constants.NATIONAL_RF_SPEND
563561
else:
564562
national_da = self._aggregate_and_scale_geo_da(
565-
self._meridian.input_data.rf_spend, constants.NATIONAL_RF_SPEND, None
563+
self._meridian.input_data.allocated_rf_spend,
564+
constants.NATIONAL_RF_SPEND,
565+
None,
566566
)
567567
return national_da
568568

0 commit comments

Comments
 (0)