Skip to content

Commit 90507ba

Browse files
ez96The Meridian Authors
authored andcommitted
Add ranked_geos property to InputData.
PiperOrigin-RevId: 822204516
1 parent c005d8f commit 90507ba

2 files changed

Lines changed: 150 additions & 0 deletions

File tree

meridian/data/input_data.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,66 @@ def media_time(self) -> xr.DataArray:
363363
else:
364364
return self.reach[constants.MEDIA_TIME]
365365

366+
@functools.cached_property
367+
def ranked_geos(self) -> list[str]:
368+
"""Ranks geos by total spend then by total KPI."""
369+
n_geos = len(self.geo)
370+
if n_geos == 1:
371+
return [self.geo[0]]
372+
else:
373+
spend_da = None
374+
375+
if self.media_spend is not None:
376+
# Sum across media_channel and any other non-geo/time dimensions
377+
media_spend_sum = self.media_spend.sum(
378+
dim=[
379+
d
380+
for d in self.media_spend.dims
381+
if d not in (constants.GEO, constants.TIME)
382+
]
383+
)
384+
spend_da = media_spend_sum
385+
386+
if self.rf_spend is not None:
387+
# Sum across rf_channel and any other non-geo/time dimensions
388+
rf_spend_sum = self.rf_spend.sum(
389+
dim=[
390+
d
391+
for d in self.rf_spend.dims
392+
if d not in (constants.GEO, constants.TIME)
393+
]
394+
)
395+
396+
if spend_da is None:
397+
spend_da = rf_spend_sum
398+
else:
399+
# Add media and rf spend. xarray aligns on geo/time.
400+
# Use fillna(0) for safe addition.
401+
spend_da = spend_da.fillna(0) + rf_spend_sum.fillna(0)
402+
403+
if spend_da is None:
404+
raise ValueError(
405+
"It is required to have at least one of media spend or rf spend."
406+
)
407+
408+
# 2. Calculate Total Spend and KPI per Geo
409+
geo_spend_sum_da = spend_da.sum(dim=constants.TIME)
410+
geo_kpi_sum_da = self.kpi.sum(dim=constants.TIME)
411+
412+
# 3. Get the underlying NumPy arrays and coordinates
413+
geo_coords = geo_spend_sum_da[constants.GEO].values
414+
spend_values = geo_spend_sum_da.values
415+
kpi_values = geo_kpi_sum_da.values
416+
417+
# 4. Use NumPy's lexsort for multi-criteria sorting (descending)
418+
sort_indices = np.lexsort((
419+
-kpi_values, # Secondary sort key (minor)
420+
-spend_values, # Primary sort key (major)
421+
))
422+
423+
# 5. Apply the sorted indices to the geo coordinates
424+
return geo_coords[sort_indices].tolist()
425+
366426
@functools.cached_property
367427
def media_time_coordinates(self) -> tc.TimeCoordinates:
368428
"""Returns the media time dimension in a `TimeCoordinates` wrapper."""

meridian/data/input_data_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,96 @@ def test_scaled_centered_kpi_supports_dtype_int(self):
16421642
data.population = data.population.astype(int)
16431643
self.assertNotEmpty(data.scaled_centered_kpi)
16441644

1645+
def test_rank_geos_only_media_spend(self):
1646+
data = input_data.InputData(
1647+
kpi=self.not_lagged_kpi,
1648+
kpi_type=constants.NON_REVENUE,
1649+
population=self.population,
1650+
media=self.lagged_media,
1651+
media_spend=self.media_spend,
1652+
)
1653+
self.assertListEqual(
1654+
data.ranked_geos,
1655+
[
1656+
"geo_1",
1657+
"geo_5",
1658+
"geo_4",
1659+
"geo_6",
1660+
"geo_7",
1661+
"geo_8",
1662+
"geo_2",
1663+
"geo_3",
1664+
"geo_9",
1665+
"geo_0",
1666+
],
1667+
)
1668+
1669+
def test_rank_geos_only_rf_spend(self):
1670+
data = input_data.InputData(
1671+
kpi=self.not_lagged_kpi,
1672+
kpi_type=constants.NON_REVENUE,
1673+
population=self.population,
1674+
reach=self.lagged_reach,
1675+
frequency=self.lagged_frequency,
1676+
rf_spend=self.rf_spend,
1677+
)
1678+
self.assertListEqual(
1679+
data.ranked_geos,
1680+
[
1681+
"geo_3",
1682+
"geo_5",
1683+
"geo_0",
1684+
"geo_4",
1685+
"geo_8",
1686+
"geo_6",
1687+
"geo_7",
1688+
"geo_1",
1689+
"geo_9",
1690+
"geo_2",
1691+
],
1692+
)
1693+
1694+
def test_rank_geos_media_spend_and_rf_spend(self):
1695+
data = input_data.InputData(
1696+
kpi=self.not_lagged_kpi,
1697+
kpi_type=constants.NON_REVENUE,
1698+
population=self.population,
1699+
media=self.lagged_media,
1700+
media_spend=self.media_spend,
1701+
reach=self.lagged_reach,
1702+
frequency=self.lagged_frequency,
1703+
rf_spend=self.rf_spend,
1704+
)
1705+
self.assertListEqual(
1706+
data.ranked_geos,
1707+
[
1708+
"geo_5",
1709+
"geo_4",
1710+
"geo_1",
1711+
"geo_6",
1712+
"geo_8",
1713+
"geo_3",
1714+
"geo_7",
1715+
"geo_0",
1716+
"geo_2",
1717+
"geo_9",
1718+
],
1719+
)
1720+
1721+
def test_rank_geos_national(self):
1722+
data = test_utils.sample_input_data_from_dataset(
1723+
test_utils.random_dataset(
1724+
n_geos=1,
1725+
n_times=20,
1726+
n_media_times=20,
1727+
n_controls=2,
1728+
n_media_channels=5,
1729+
),
1730+
"non_revenue",
1731+
)
1732+
actual_ranked = data.ranked_geos
1733+
self.assertListEqual(actual_ranked, ["geo_0"])
1734+
16451735

16461736
class NonpaidInputDataTest(parameterized.TestCase):
16471737
"""Tests for non-paid InputData."""

0 commit comments

Comments
 (0)