Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
**/*.h5
**/*.npy
**/*.csv
**/*.csv.gz
**/_build
**/*.pkl
**/*.db
Expand Down
1 change: 1 addition & 0 deletions changelog.d/708.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Save calibration geography as a pipeline artifact and add ``--resume-from`` checkpoint support for long-running calibration fits.
30 changes: 30 additions & 0 deletions docs/calibration.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ python -m policyengine_us_data.calibration.unified_calibration \
--package-path storage/calibration/calibration_package.pkl \
--epochs 500 --device cuda

# Resume a previous fit for 500 more epochs:
python -m policyengine_us_data.calibration.unified_calibration \
--package-path storage/calibration/calibration_package.pkl \
--resume-from storage/calibration/calibration_weights.npy \
--epochs 500 --device cuda

# Full pipeline with PUF (build + fit in one shot):
make calibrate
```
Expand Down Expand Up @@ -88,6 +94,30 @@ python -m policyengine_us_data.calibration.unified_calibration \
You can re-run Step 2 as many times as you want with different hyperparameters. The expensive matrix
build only happens once.

Every fit now also writes a checkpoint next to the weights output
(`calibration_weights.checkpoint.pt` by default). To continue the same fit,
pass `--resume-from` with the weights file or checkpoint path. If a sibling
checkpoint exists next to the weights file, it is used automatically so the
L0 gate state is restored as well.

```bash
python -m policyengine_us_data.calibration.unified_calibration \
--package-path storage/calibration/calibration_package.pkl \
--epochs 2000 \
--beta 0.65 \
--lambda-l0 1e-4 \
--lambda-l2 1e-12 \
--log-freq 500 \
--target-config policyengine_us_data/calibration/target_config.yaml \
--device cpu \
--output policyengine_us_data/storage/calibration/national/weights.npy \
--resume-from policyengine_us_data/storage/calibration/national/weights.npy
```

When `--resume-from` points to a checkpoint, `--epochs` means additional epochs
to run beyond the saved checkpoint epoch count. If only a `.npy` weights file
exists, the run warm-starts from those weights.

### 2. Full pipeline with PUF

Adding `--puf-dataset` doubles the record count (~24K base records x 430 clones = ~10.3M columns) by
Expand Down
7 changes: 7 additions & 0 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ def build_areas_worker(
"--output-dir",
str(output_dir),
]
if "geography" in calibration_inputs:
worker_cmd.extend(["--geography-path", calibration_inputs["geography"]])
if "n_clones" in calibration_inputs:
worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])])
if "seed" in calibration_inputs:
Expand Down Expand Up @@ -659,6 +661,7 @@ def coordinate_publish(
Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
)
weights_path = artifacts / "calibration_weights.npy"
geography_path = artifacts / "geography_assignment.npz"
db_path = artifacts / "policy_data.db"
dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
config_json_path = artifacts / "unified_run_config.json"
Expand All @@ -678,6 +681,7 @@ def coordinate_publish(

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
Expand Down Expand Up @@ -943,6 +947,7 @@ def coordinate_national_publish(
Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
)
weights_path = artifacts / "national_calibration_weights.npy"
geography_path = artifacts / "national_geography_assignment.npz"
db_path = artifacts / "policy_data.db"
dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
config_json_path = artifacts / "national_unified_run_config.json"
Expand All @@ -962,6 +967,7 @@ def coordinate_national_publish(

calibration_inputs = {
"weights": str(weights_path),
"geography": str(geography_path),
"dataset": str(dataset_path),
"database": str(db_path),
"n_clones": n_clones,
Expand All @@ -972,6 +978,7 @@ def coordinate_national_publish(
artifacts,
filename_remap={
"calibration_weights.npy": "national_calibration_weights.npy",
"geography_assignment.npz": "national_geography_assignment.npz",
},
)
run_dir = staging_dir / run_id
Expand Down
10 changes: 10 additions & 0 deletions modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,11 @@ def run_pipeline(
BytesIO(regional_result["weights"]),
f"{artifacts_rel}/calibration_weights.npy",
)
if regional_result.get("geography"):
batch.put_file(
BytesIO(regional_result["geography"]),
f"{artifacts_rel}/geography_assignment.npz",
)
if regional_result.get("config"):
batch.put_file(
BytesIO(regional_result["config"]),
Expand All @@ -856,6 +861,11 @@ def run_pipeline(
BytesIO(national_result["weights"]),
f"{artifacts_rel}/national_calibration_weights.npy",
)
if national_result.get("geography"):
batch.put_file(
BytesIO(national_result["geography"]),
f"{artifacts_rel}/national_geography_assignment.npz",
)
if national_result.get("config"):
batch.put_file(
BytesIO(national_result["config"]),
Expand Down
25 changes: 25 additions & 0 deletions modal_app/remote_calibration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,15 @@ def _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq
def _collect_outputs(cal_lines):
"""Extract weights and log bytes from calibration output lines."""
output_path = None
geography_path = None
log_path = None
cal_log_path = None
config_path = None
for line in cal_lines:
if "OUTPUT_PATH:" in line:
output_path = line.split("OUTPUT_PATH:")[1].strip()
elif "GEOGRAPHY_PATH:" in line:
geography_path = line.split("GEOGRAPHY_PATH:")[1].strip()
elif "CONFIG_PATH:" in line:
config_path = line.split("CONFIG_PATH:")[1].strip()
elif "CAL_LOG_PATH:" in line:
Expand All @@ -84,6 +87,11 @@ def _collect_outputs(cal_lines):
with open(output_path, "rb") as f:
weights_bytes = f.read()

geography_bytes = None
if geography_path:
with open(geography_path, "rb") as f:
geography_bytes = f.read()

log_bytes = None
if log_path:
with open(log_path, "rb") as f:
Expand All @@ -101,6 +109,7 @@ def _collect_outputs(cal_lines):

return {
"weights": weights_bytes,
"geography": geography_bytes,
"log": log_bytes,
"cal_log": cal_log_bytes,
"config": config_bytes,
Expand Down Expand Up @@ -975,6 +984,10 @@ def main(
f" - calibration/{prefix}calibration_weights.npy",
flush=True,
)
print(
f" - calibration/{prefix}geography_assignment.npz",
flush=True,
)
print(
f" - calibration/logs/{prefix}* (diagnostics, "
"config, calibration log)",
Expand Down Expand Up @@ -1006,6 +1019,12 @@ def main(
f.write(result["log"])
print(f"Diagnostics log saved to: {log_output}")

geography_output = f"{prefix}geography_assignment.npz"
if result.get("geography"):
with open(geography_output, "wb") as f:
f.write(result["geography"])
print(f"Geography saved to: {geography_output}")

cal_log_output = f"{prefix}calibration_log.csv"
if result.get("cal_log"):
with open(cal_log_output, "wb") as f:
Expand All @@ -1027,6 +1046,11 @@ def main(
BytesIO(result["weights"]),
f"artifacts/{prefix}calibration_weights.npy",
)
if result.get("geography"):
batch.put_file(
BytesIO(result["geography"]),
f"artifacts/{prefix}geography_assignment.npz",
)
if result.get("config"):
batch.put_file(
BytesIO(result["config"]),
Expand All @@ -1042,6 +1066,7 @@ def main(

upload_calibration_artifacts(
weights_path=output,
geography_path=(geography_output if result.get("geography") else None),
log_dir=".",
prefix=prefix,
)
Expand Down
33 changes: 16 additions & 17 deletions modal_app/worker_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def main():
parser.add_argument("--dataset-path", required=True)
parser.add_argument("--db-path", required=True)
parser.add_argument("--output-dir", required=True)
parser.add_argument(
"--geography-path",
default=None,
help="Optional explicit path to geography_assignment.npz",
)
parser.add_argument(
"--n-clones",
type=int,
Expand Down Expand Up @@ -210,13 +215,11 @@ def main():
build_h5,
NYC_COUNTY_FIPS,
AT_LARGE_DISTRICTS,
load_calibration_geography,
)
from policyengine_us_data.calibration.calibration_utils import (
STATE_CODES,
)
from policyengine_us_data.calibration.clone_and_assign import (
assign_random_geography,
)

weights = np.load(weights_path)

Expand All @@ -226,15 +229,18 @@ def main():
n_records = len(_sim.calculate("household_id", map_to="household").values)
del _sim

geography = assign_random_geography(
geography = load_calibration_geography(
weights_path=weights_path,
n_records=n_records,
n_clones=args.n_clones,
seed=args.seed,
geography_path=(
Path(args.geography_path) if args.geography_path is not None else None
),
)
cds_to_calibrate = sorted(set(geography.cd_geoid.astype(str)))
geo_labels = cds_to_calibrate
print(
f"Generated geography: "
f"Loaded geography: "
f"{geography.n_clones} clones x "
f"{geography.n_records} records",
file=sys.stderr,
Expand Down Expand Up @@ -403,19 +409,12 @@ def main():
national_dir.mkdir(parents=True, exist_ok=True)
n_clones_from_weights = weights.shape[0] // n_records
if n_clones_from_weights != geography.n_clones:
print(
raise ValueError(
f"National weights have {n_clones_from_weights} clones "
f"but geography has {geography.n_clones}; "
f"regenerating geography",
file=sys.stderr,
f"but geography has {geography.n_clones}. "
"Use the matching saved geography artifact."
)
national_geo = assign_random_geography(
n_records=n_records,
n_clones=n_clones_from_weights,
seed=args.seed,
)
else:
national_geo = geography
national_geo = geography
path = build_h5(
weights=weights,
geography=national_geo,
Expand Down
99 changes: 99 additions & 0 deletions policyengine_us_data/calibration/clone_and_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,105 @@ def _sample(size, mask_slice=None):
)


def save_geography(geography: GeographyAssignment, path) -> None:
    """Serialize a GeographyAssignment to a compressed .npz file.

    Args:
        geography: The geography assignment to save.
        path: Output file path (should end in .npz).
    """
    from pathlib import Path

    # Scalar counts are wrapped in length-1 arrays so np.savez can store
    # them alongside the per-record arrays; load_geography unwraps them.
    arrays = {
        "block_geoid": geography.block_geoid,
        "cd_geoid": geography.cd_geoid,
        "county_fips": geography.county_fips,
        "state_fips": geography.state_fips,
        "n_records": np.array([geography.n_records]),
        "n_clones": np.array([geography.n_clones]),
    }
    np.savez_compressed(Path(path), **arrays)


def load_geography(path) -> GeographyAssignment:
    """Load a GeographyAssignment from a .npz file.

    Args:
        path: Path to the .npz file saved by save_geography.

    Returns:
        GeographyAssignment with all fields restored.

    Raises:
        KeyError: If the archive is missing one of the expected arrays.
    """
    from pathlib import Path

    path = Path(path)
    # allow_pickle is deliberately NOT enabled: save_geography writes only
    # plain string/integer arrays, and refusing pickled payloads avoids
    # arbitrary-code-execution risk from a tampered artifact. The context
    # manager closes the underlying zip file handle (NpzFile keeps it open
    # until closed).
    with np.load(path) as data:
        # Scalar counts were stored as length-1 arrays by save_geography.
        return GeographyAssignment(
            block_geoid=data["block_geoid"],
            cd_geoid=data["cd_geoid"],
            county_fips=data["county_fips"],
            state_fips=data["state_fips"],
            n_records=int(data["n_records"][0]),
            n_clones=int(data["n_clones"][0]),
        )


@lru_cache(maxsize=1)
def load_sorted_block_cd_lookup():
    """Return (blocks, cds) arrays sorted by block GEOID, cached once.

    Supports binary-search lookups when mapping legacy block artifacts
    back to congressional districts.
    """
    all_blocks, all_cds, _, _ = load_global_block_distribution()
    sort_idx = np.argsort(all_blocks)
    return all_blocks[sort_idx], all_cds[sort_idx]


def reconstruct_geography_from_blocks(
    block_geoids: np.ndarray,
    n_records: int,
    n_clones: int,
) -> GeographyAssignment:
    """Rebuild a GeographyAssignment from saved block GEOIDs.

    Args:
        block_geoids: Flat array of block GEOID strings, one entry per
            (record, clone) pair.
        n_records: Number of base records.
        n_clones: Number of clones per record.

    Returns:
        GeographyAssignment with CD, county, and state fields derived
        from each block GEOID.

    Raises:
        ValueError: If the array length is not n_records * n_clones.
        KeyError: If any block GEOID is absent from the global lookup.
    """
    blocks = np.asarray(block_geoids, dtype=str)
    n_expected = n_records * n_clones
    if len(blocks) != n_expected:
        raise ValueError(
            f"Expected {n_expected} block GEOIDs for "
            f"{n_records} records x {n_clones} clones, got {len(blocks)}"
        )

    # Binary-search each block in the sorted global lookup; a position is a
    # real match only when the found entry equals the queried GEOID.
    lookup_blocks, lookup_cds = load_sorted_block_cd_lookup()
    positions = np.searchsorted(lookup_blocks, blocks)
    in_range = positions < len(lookup_blocks)
    is_match = np.zeros(len(blocks), dtype=bool)
    is_match[in_range] = lookup_blocks[positions[in_range]] == blocks[in_range]

    if not is_match.all():
        examples = np.unique(blocks[~is_match])[:5]
        raise KeyError(
            "Could not recover congressional districts for some blocks. "
            f"Examples: {examples.tolist()}"
        )

    # A block GEOID encodes geography positionally: chars [0:2] are the
    # state FIPS and chars [0:5] the county FIPS.
    counties = np.fromiter(
        (geoid[:5] for geoid in blocks),
        dtype="U5",
        count=len(blocks),
    )
    states = np.fromiter(
        (int(geoid[:2]) for geoid in blocks),
        dtype=np.int32,
        count=len(blocks),
    )
    return GeographyAssignment(
        block_geoid=blocks,
        cd_geoid=lookup_cds[positions],
        county_fips=counties,
        state_fips=states,
        n_records=n_records,
        n_clones=n_clones,
    )


def double_geography_for_puf(
geography: GeographyAssignment,
) -> GeographyAssignment:
Expand Down
Loading
Loading