Skip to content

Commit dd2d273

Browse files
authored
Merge pull request #671 from PolicyEngine/fix/pipeline-resilience
Improve calibration: AGI-conditional geography, expanded targets, pipeline fixes
2 parents 2611f8d + 4d7c227 commit dd2d273

File tree

10 files changed

+310
-260
lines changed

10 files changed

+310
-260
lines changed

modal_app/local_area.py

Lines changed: 38 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -140,35 +140,27 @@ def get_version() -> str:
140140

141141

142142
def partition_work(
143-
states: List[str],
144-
districts: List[str],
145-
cities: List[str],
143+
work_items: List[Dict],
146144
num_workers: int,
147145
completed: set,
148146
) -> List[List[Dict]]:
149-
"""Partition work items across N workers."""
150-
remaining = []
151-
152-
for s in states:
153-
item_id = f"state:{s}"
154-
if item_id not in completed:
155-
remaining.append({"type": "state", "id": s, "weight": 5})
156-
157-
for d in districts:
158-
item_id = f"district:{d}"
159-
if item_id not in completed:
160-
remaining.append({"type": "district", "id": d, "weight": 1})
147+
"""Partition work items across N workers using LPT scheduling."""
148+
remaining = [
149+
item for item in work_items if f"{item['type']}:{item['id']}" not in completed
150+
]
151+
remaining.sort(key=lambda x: -x["weight"])
161152

162-
for c in cities:
163-
item_id = f"city:{c}"
164-
if item_id not in completed:
165-
remaining.append({"type": "city", "id": c, "weight": 3})
153+
n_workers = min(num_workers, len(remaining))
154+
if n_workers == 0:
155+
return []
166156

167-
remaining.sort(key=lambda x: -x["weight"])
157+
heap = [(0, i) for i in range(n_workers)]
158+
chunks = [[] for _ in range(n_workers)]
168159

169-
chunks = [[] for _ in range(num_workers)]
170-
for i, item in enumerate(remaining):
171-
chunks[i % num_workers].append(item)
160+
for item in remaining:
161+
load, idx = heapq.heappop(heap)
162+
chunks[idx].append(item)
163+
heapq.heappush(heap, (load + item["weight"], idx))
172164

173165
return [c for c in chunks if c]
174166

@@ -197,9 +189,7 @@ def get_completed_from_volume(version_dir: Path) -> set:
197189

198190
def run_phase(
199191
phase_name: str,
200-
states: List[str],
201-
districts: List[str],
202-
cities: List[str],
192+
work_items: List[Dict],
203193
num_workers: int,
204194
completed: set,
205195
branch: str,
@@ -216,7 +206,7 @@ def run_phase(
216206
and crashes, and validation_rows is a list of per-target
217207
validation result dicts.
218208
"""
219-
work_chunks = partition_work(states, districts, cities, num_workers, completed)
209+
work_chunks = partition_work(work_items, num_workers, completed)
220210
total_remaining = sum(len(c) for c in work_chunks)
221211

222212
print(f"\n--- Phase: {phase_name} ---")
@@ -228,7 +218,8 @@ def run_phase(
228218

229219
handles = []
230220
for i, chunk in enumerate(work_chunks):
231-
print(f" Worker {i}: {len(chunk)} items")
221+
total_weight = sum(item["weight"] for item in chunk)
222+
print(f" Worker {i}: {len(chunk)} items, weight {total_weight}")
232223
handle = build_areas_worker.spawn(
233224
branch=branch,
234225
version=version,
@@ -753,7 +744,7 @@ def coordinate_publish(
753744
cds = get_all_cds_from_database(db_uri)
754745
states = list(STATE_CODES.values())
755746
districts = [get_district_friendly_name(cd) for cd in cds]
756-
print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"]}}))
747+
print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"], "cds": cds}}))
757748
""",
758749
],
759750
capture_output=True,
@@ -769,6 +760,22 @@ def coordinate_publish(
769760
districts = work_info["districts"]
770761
cities = work_info["cities"]
771762

763+
from collections import Counter
764+
from policyengine_us_data.calibration.calibration_utils import STATE_CODES
765+
766+
raw_cds = work_info["cds"]
767+
cds_per_state = Counter(STATE_CODES.get(int(cd) // 100, "??") for cd in raw_cds)
768+
769+
CITY_WEIGHTS = {"NYC": 11}
770+
771+
work_items = []
772+
for s in states:
773+
work_items.append({"type": "state", "id": s, "weight": cds_per_state.get(s, 1)})
774+
for d in districts:
775+
work_items.append({"type": "district", "id": d, "weight": 1})
776+
for c in cities:
777+
work_items.append({"type": "city", "id": c, "weight": CITY_WEIGHTS.get(c, 3)})
778+
772779
staging_volume.reload()
773780
completed = get_completed_from_volume(version_dir)
774781
print(f"Found {len(completed)} already-completed items on volume")
@@ -786,32 +793,8 @@ def coordinate_publish(
786793
accumulated_validation_rows = []
787794

788795
completed, phase_errors, v_rows = run_phase(
789-
"States",
790-
states=states,
791-
districts=[],
792-
cities=[],
793-
completed=completed,
794-
**phase_args,
795-
)
796-
accumulated_errors.extend(phase_errors)
797-
accumulated_validation_rows.extend(v_rows)
798-
799-
completed, phase_errors, v_rows = run_phase(
800-
"Districts",
801-
states=[],
802-
districts=districts,
803-
cities=[],
804-
completed=completed,
805-
**phase_args,
806-
)
807-
accumulated_errors.extend(phase_errors)
808-
accumulated_validation_rows.extend(v_rows)
809-
810-
completed, phase_errors, v_rows = run_phase(
811-
"Cities",
812-
states=[],
813-
districts=[],
814-
cities=cities,
796+
"All areas",
797+
work_items=work_items,
815798
completed=completed,
816799
**phase_args,
817800
)

modal_app/worker_script.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,7 @@ def main():
208208

209209
from policyengine_us_data.calibration.publish_local_area import (
210210
build_h5,
211-
NYC_COUNTIES,
212-
NYC_CDS,
211+
NYC_COUNTY_FIPS,
213212
AT_LARGE_DISTRICTS,
214213
)
215214
from policyengine_us_data.calibration.calibration_utils import (
@@ -388,22 +387,14 @@ def main():
388387
)
389388

390389
elif item_type == "city":
391-
cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS]
392-
if not cd_subset:
393-
print(
394-
"No NYC CDs found, skipping",
395-
file=sys.stderr,
396-
)
397-
continue
398390
cities_dir = output_dir / "cities"
399391
cities_dir.mkdir(parents=True, exist_ok=True)
400392
path = build_h5(
401393
weights=weights,
402394
geography=geography,
403395
dataset_path=dataset_path,
404396
output_path=cities_dir / "NYC.h5",
405-
cd_subset=cd_subset,
406-
county_filter=NYC_COUNTIES,
397+
county_fips_filter=NYC_COUNTY_FIPS,
407398
takeup_filter=takeup_filter,
408399
)
409400

policyengine_us_data/calibration/block_assignment.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -580,92 +580,3 @@ def derive_geography_from_blocks(
580580
"zcta": np.array(zcta_list),
581581
"county_index": county_indices,
582582
}
583-
584-
585-
# === County Filter Functions (for city-level datasets) ===
586-
587-
588-
def get_county_filter_probability(
589-
cd_geoid: str,
590-
county_filter: set,
591-
) -> float:
592-
"""
593-
Calculate P(county in filter | CD) using block-level data.
594-
595-
Returns the probability that a household in this CD would be in the
596-
target area (e.g., NYC). Used for weight scaling when building
597-
city-level datasets.
598-
599-
Args:
600-
cd_geoid: Congressional district geoid (e.g., "3610")
601-
county_filter: Set of county enum names that define the target area
602-
603-
Returns:
604-
Probability between 0 and 1
605-
"""
606-
distributions = _get_block_distributions()
607-
cd_key = str(int(cd_geoid))
608-
609-
if cd_key not in distributions:
610-
return 0.0
611-
612-
dist = distributions[cd_key]
613-
614-
# Convert county enum names to FIPS codes for comparison
615-
fips_to_enum = _build_county_fips_to_enum()
616-
enum_to_fips = {v: k for k, v in fips_to_enum.items()}
617-
target_fips = {enum_to_fips.get(name) for name in county_filter}
618-
target_fips.discard(None)
619-
620-
# Sum probabilities of blocks in target counties
621-
return sum(
622-
prob
623-
for block, prob in dist.items()
624-
if get_county_fips_from_block(block) in target_fips
625-
)
626-
627-
628-
def get_filtered_block_distribution(
629-
cd_geoid: str,
630-
county_filter: set,
631-
) -> Dict[str, float]:
632-
"""
633-
Get normalized distribution over blocks in target counties only.
634-
635-
Used when building city-level datasets to assign only blocks in valid
636-
counties while maintaining relative proportions within the target area.
637-
638-
Args:
639-
cd_geoid: Congressional district geoid (e.g., "3610")
640-
county_filter: Set of county enum names that define the target area
641-
642-
Returns:
643-
Dictionary mapping block GEOIDs to normalized probabilities.
644-
Empty dict if CD has no overlap with target area.
645-
"""
646-
distributions = _get_block_distributions()
647-
cd_key = str(int(cd_geoid))
648-
649-
if cd_key not in distributions:
650-
return {}
651-
652-
dist = distributions[cd_key]
653-
654-
# Convert county enum names to FIPS codes for comparison
655-
fips_to_enum = _build_county_fips_to_enum()
656-
enum_to_fips = {v: k for k, v in fips_to_enum.items()}
657-
target_fips = {enum_to_fips.get(name) for name in county_filter}
658-
target_fips.discard(None)
659-
660-
# Filter to blocks in target counties
661-
filtered = {
662-
block: prob
663-
for block, prob in dist.items()
664-
if get_county_fips_from_block(block) in target_fips
665-
}
666-
667-
# Normalize
668-
total = sum(filtered.values())
669-
if total > 0:
670-
return {block: prob / total for block, prob in filtered.items()}
671-
return {}

policyengine_us_data/calibration/clone_and_assign.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,23 @@ def load_global_block_distribution():
6767
return block_geoids, cd_geoids, state_fips, probs
6868

6969

70+
def _build_agi_block_probs(cds, pop_probs, cd_agi_targets):
71+
"""Multiply population block probs by CD AGI target weights."""
72+
agi_weights = np.array([cd_agi_targets.get(cd, 0.0) for cd in cds])
73+
agi_weights = np.maximum(agi_weights, 0.0)
74+
if agi_weights.sum() == 0:
75+
return pop_probs
76+
agi_probs = pop_probs * agi_weights
77+
return agi_probs / agi_probs.sum()
78+
79+
7080
def assign_random_geography(
7181
n_records: int,
7282
n_clones: int = 10,
7383
seed: int = 42,
84+
household_agi: np.ndarray = None,
85+
cd_agi_targets: dict = None,
86+
agi_threshold_pctile: float = 90.0,
7487
) -> GeographyAssignment:
7588
"""Assign random census block geography to cloned
7689
CPS records.
@@ -95,17 +108,48 @@ def assign_random_geography(
95108
n_total = n_records * n_clones
96109
rng = np.random.default_rng(seed)
97110

111+
agi_probs = None
112+
extreme_mask = None
113+
if household_agi is not None and cd_agi_targets is not None:
114+
threshold = np.percentile(household_agi, agi_threshold_pctile)
115+
extreme_mask = household_agi >= threshold
116+
agi_probs = _build_agi_block_probs(cds, probs, cd_agi_targets)
117+
logger.info(
118+
"AGI-conditional assignment: %d extreme HHs (AGI >= $%.0f) "
119+
"use AGI-weighted block probs",
120+
extreme_mask.sum(),
121+
threshold,
122+
)
123+
124+
def _sample(size, mask_slice=None):
125+
"""Sample block indices, using AGI-weighted probs for extreme HHs."""
126+
if (
127+
extreme_mask is not None
128+
and agi_probs is not None
129+
and mask_slice is not None
130+
):
131+
out = np.empty(size, dtype=np.int64)
132+
ext = mask_slice
133+
n_ext = ext.sum()
134+
n_norm = size - n_ext
135+
if n_ext > 0:
136+
out[ext] = rng.choice(len(blocks), size=n_ext, p=agi_probs)
137+
if n_norm > 0:
138+
out[~ext] = rng.choice(len(blocks), size=n_norm, p=probs)
139+
return out
140+
return rng.choice(len(blocks), size=size, p=probs)
141+
98142
indices = np.empty(n_total, dtype=np.int64)
99143

100144
# Clone 0: unrestricted draw
101-
indices[:n_records] = rng.choice(len(blocks), size=n_records, p=probs)
145+
indices[:n_records] = _sample(n_records, extreme_mask)
102146

103147
assigned_cds = np.empty((n_clones, n_records), dtype=object)
104148
assigned_cds[0] = cds[indices[:n_records]]
105149

106150
for clone_idx in range(1, n_clones):
107151
start = clone_idx * n_records
108-
clone_indices = rng.choice(len(blocks), size=n_records, p=probs)
152+
clone_indices = _sample(n_records, extreme_mask)
109153
clone_cds = cds[clone_indices]
110154

111155
collisions = np.zeros(n_records, dtype=bool)
@@ -116,7 +160,20 @@ def assign_random_geography(
116160
n_bad = collisions.sum()
117161
if n_bad == 0:
118162
break
119-
clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs)
163+
bad_mask = collisions
164+
if extreme_mask is not None and agi_probs is not None:
165+
bad_ext = bad_mask & extreme_mask
166+
bad_norm = bad_mask & ~extreme_mask
167+
if bad_ext.sum() > 0:
168+
clone_indices[bad_ext] = rng.choice(
169+
len(blocks), size=bad_ext.sum(), p=agi_probs
170+
)
171+
if bad_norm.sum() > 0:
172+
clone_indices[bad_norm] = rng.choice(
173+
len(blocks), size=bad_norm.sum(), p=probs
174+
)
175+
else:
176+
clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs)
120177
clone_cds = cds[clone_indices]
121178
collisions = np.zeros(n_records, dtype=bool)
122179
for prev in range(clone_idx):

0 commit comments

Comments
 (0)