Skip to content

Commit bd15588

Browse files
authored
Add CTC artifact comparison diagnostics to national validation (#731)
* Add CTC artifact comparison diagnostics * Format CTC comparison tests
1 parent 6ac616b commit bd15588

File tree

3 files changed

+376
-1
lines changed

3 files changed

+376
-1
lines changed

changelog.d/725.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add comparison-mode CTC diagnostics to `validate_national_h5`, including child-count and child-age drift reporting between national artifacts.

policyengine_us_data/calibration/validate_national_h5.py

Lines changed: 220 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@
6161
}
6262

6363
DEFAULT_HF_PATH = "hf://policyengine/policyengine-us-data/national/US.h5"
64+
# Variables whose summed totals are compared between the reference and
# candidate artifacts by ``build_artifact_ctc_summary``; the list order is the
# row order of that summary table.
ARTIFACT_CTC_SUMMARY_VARIABLES = [
    "ctc_qualifying_children",
    "ctc",
    "refundable_ctc",
    "non_refundable_ctc",
]
6470

6571
COUNT_VARS = {
6672
"person_count",
@@ -216,11 +222,183 @@ def build_canonical_ctc_reform_summary(
216222

217223
def _format_canonical_ctc_reform_summary(table: pd.DataFrame) -> str:
218224
display = table.copy()
219-
for column in ("baseline", "reformed", "delta"):
225+
numeric_columns = [
226+
column
227+
for column in display.columns
228+
if column != "variable" and pd.api.types.is_numeric_dtype(display[column])
229+
]
230+
for column in numeric_columns:
220231
display[column] = display[column].map(lambda value: f"${value / 1e9:,.1f}B")
221232
return display.to_string(index=False)
222233

223234

235+
def build_artifact_ctc_summary(
236+
reference_sim,
237+
candidate_sim,
238+
*,
239+
period: int = 2025,
240+
) -> pd.DataFrame:
241+
rows = []
242+
for variable in ARTIFACT_CTC_SUMMARY_VARIABLES:
243+
reference = float(reference_sim.calculate(variable, period=period).sum())
244+
candidate = float(candidate_sim.calculate(variable, period=period).sum())
245+
rows.append(
246+
{
247+
"variable": variable,
248+
"reference": reference,
249+
"candidate": candidate,
250+
"delta": candidate - reference,
251+
}
252+
)
253+
return pd.DataFrame(rows)
254+
255+
256+
def _format_artifact_ctc_summary(table: pd.DataFrame) -> str:
257+
display = table.copy()
258+
for column in ("reference", "candidate", "delta"):
259+
display[column] = display.apply(
260+
lambda row: (
261+
f"{row[column] / 1e6:,.2f}M"
262+
if row["variable"] in COUNT_VARS
263+
else f"${row[column] / 1e9:,.1f}B"
264+
),
265+
axis=1,
266+
)
267+
return display.to_string(index=False)
268+
269+
270+
def get_artifact_ctc_comparison_outputs(
    reference_sim,
    candidate_sim,
    *,
    period: int = 2025,
) -> dict[str, str]:
    """Build the printable current-law CTC comparison sections.

    The first section is the formatted totals table from
    ``build_artifact_ctc_summary``. The remaining sections come from running
    ``create_ctc_diagnostic_tables`` on each simulation and passing the two
    results to ``_subtract_diagnostic_tables`` (reference first, candidate
    second). Only tables whose key has a known section title are included,
    in the order the subtraction helper returns them.

    Args:
        reference_sim: Simulation for the comparison dataset.
        candidate_sim: Simulation for the dataset being validated.
        period: Period forwarded to every underlying calculation.

    Returns:
        Mapping of section title to formatted table text.
    """
    totals_table = build_artifact_ctc_summary(
        reference_sim,
        candidate_sim,
        period=period,
    )
    report = {
        "CURRENT-LAW CTC TOTAL DELTAS VS COMPARISON DATASET": (
            _format_artifact_ctc_summary(totals_table)
        )
    }

    reference_tables = create_ctc_diagnostic_tables(reference_sim, period=period)
    candidate_tables = create_ctc_diagnostic_tables(candidate_sim, period=period)
    drift_tables = _subtract_diagnostic_tables(reference_tables, candidate_tables)

    titles = {
        "by_agi_band": "CURRENT-LAW CTC DIAGNOSTIC DELTAS BY AGI BAND",
        "by_filing_status": "CURRENT-LAW CTC DIAGNOSTIC DELTAS BY FILING STATUS",
        "by_agi_band_and_filing_status": (
            "CURRENT-LAW CTC DIAGNOSTIC DELTAS BY AGI BAND AND FILING STATUS"
        ),
        "by_child_count": "CURRENT-LAW CTC DIAGNOSTIC DELTAS BY QUALIFYING-CHILD COUNT",
        "by_child_age": "CURRENT-LAW CTC DIAGNOSTIC DELTAS BY QUALIFYING-CHILD AGE",
    }
    # Tables without a title are dropped rather than printed under a raw key.
    report.update(
        (titles[key], format_ctc_diagnostic_table(frame))
        for key, frame in drift_tables.items()
        if key in titles
    )
    return report
306+
307+
308+
def _build_canonical_ctc_reform_comparison_summary(
309+
reference_summary: pd.DataFrame,
310+
candidate_summary: pd.DataFrame,
311+
) -> pd.DataFrame:
312+
merged = reference_summary.merge(
313+
candidate_summary,
314+
on="variable",
315+
suffixes=("_reference", "_candidate"),
316+
)
317+
comparison = pd.DataFrame(
318+
{
319+
"variable": merged["variable"],
320+
"reference_baseline": merged["baseline_reference"],
321+
"candidate_baseline": merged["baseline_candidate"],
322+
"baseline_delta": (
323+
merged["baseline_candidate"] - merged["baseline_reference"]
324+
),
325+
"reference_reformed": merged["reformed_reference"],
326+
"candidate_reformed": merged["reformed_candidate"],
327+
"reformed_delta": (
328+
merged["reformed_candidate"] - merged["reformed_reference"]
329+
),
330+
"reference_delta": merged["delta_reference"],
331+
"candidate_delta": merged["delta_candidate"],
332+
"delta_drift": merged["delta_candidate"] - merged["delta_reference"],
333+
}
334+
)
335+
return comparison
336+
337+
338+
def get_canonical_ctc_reform_comparison_outputs(
339+
reference_dataset_path: str | None = None,
340+
candidate_dataset_path: str | None = None,
341+
*,
342+
reference_baseline_sim=None,
343+
candidate_baseline_sim=None,
344+
reference_reformed_sim=None,
345+
candidate_reformed_sim=None,
346+
period: int = 2025,
347+
) -> dict[str, str]:
348+
from policyengine_us import Microsimulation
349+
350+
if reference_baseline_sim is None:
351+
if reference_dataset_path is None:
352+
raise ValueError(
353+
"reference_dataset_path is required when reference_baseline_sim is not provided"
354+
)
355+
reference_baseline_sim = Microsimulation(dataset=reference_dataset_path)
356+
if candidate_baseline_sim is None:
357+
if candidate_dataset_path is None:
358+
raise ValueError(
359+
"candidate_dataset_path is required when candidate_baseline_sim is not provided"
360+
)
361+
candidate_baseline_sim = Microsimulation(dataset=candidate_dataset_path)
362+
363+
canonical_reform = _create_canonical_ctc_reform()
364+
if reference_reformed_sim is None:
365+
if reference_dataset_path is None:
366+
raise ValueError(
367+
"reference_dataset_path is required when reference_reformed_sim is not provided"
368+
)
369+
reference_reformed_sim = Microsimulation(
370+
dataset=reference_dataset_path,
371+
reform=canonical_reform,
372+
)
373+
if candidate_reformed_sim is None:
374+
if candidate_dataset_path is None:
375+
raise ValueError(
376+
"candidate_dataset_path is required when candidate_reformed_sim is not provided"
377+
)
378+
candidate_reformed_sim = Microsimulation(
379+
dataset=candidate_dataset_path,
380+
reform=canonical_reform,
381+
)
382+
383+
comparison = _build_canonical_ctc_reform_comparison_summary(
384+
build_canonical_ctc_reform_summary(
385+
reference_baseline_sim,
386+
reference_reformed_sim,
387+
period=period,
388+
),
389+
build_canonical_ctc_reform_summary(
390+
candidate_baseline_sim,
391+
candidate_reformed_sim,
392+
period=period,
393+
),
394+
)
395+
return {
396+
"CANONICAL CTC REFORM DRIFT VS COMPARISON DATASET": (
397+
_format_canonical_ctc_reform_summary(comparison)
398+
)
399+
}
400+
401+
224402
def _subtract_diagnostic_tables(
225403
baseline_tables: dict[str, pd.DataFrame],
226404
reformed_tables: dict[str, pd.DataFrame],
@@ -337,15 +515,35 @@ def main(argv=None):
337515
default=DEFAULT_HF_PATH,
338516
help=f"HF path to US.h5 (default: {DEFAULT_HF_PATH})",
339517
)
518+
parser.add_argument(
519+
"--compare-h5-path",
520+
default=None,
521+
help="Optional local path to comparison US.h5",
522+
)
523+
parser.add_argument(
524+
"--compare-hf-path",
525+
default=None,
526+
help="Optional HF path to comparison US.h5",
527+
)
340528
args = parser.parse_args(argv)
341529

342530
dataset_path = args.h5_path or args.hf_path
343531
resolved_dataset_path = resolve_dataset_path(dataset_path)
532+
comparison_dataset_path = args.compare_h5_path or args.compare_hf_path
533+
resolved_comparison_dataset_path = (
534+
resolve_dataset_path(comparison_dataset_path)
535+
if comparison_dataset_path is not None
536+
else None
537+
)
344538

345539
from policyengine_us import Microsimulation
346540

347541
print(f"Loading {dataset_path}...")
348542
sim = Microsimulation(dataset=resolved_dataset_path)
543+
comparison_sim = None
544+
if resolved_comparison_dataset_path is not None:
545+
print(f"Loading comparison dataset {comparison_dataset_path}...")
546+
comparison_sim = Microsimulation(dataset=resolved_comparison_dataset_path)
349547

350548
n_hh = sim.calculate("household_id", map_to="household").shape[0]
351549
print(f"Households in file: {n_hh:,}")
@@ -417,6 +615,27 @@ def main(argv=None):
417615
print("=" * 70)
418616
print(section_output)
419617

618+
if comparison_sim is not None:
619+
for section_name, section_output in get_artifact_ctc_comparison_outputs(
620+
comparison_sim,
621+
sim,
622+
).items():
623+
print("\n" + "=" * 70)
624+
print(section_name)
625+
print("=" * 70)
626+
print(section_output)
627+
628+
for section_name, section_output in get_canonical_ctc_reform_comparison_outputs(
629+
reference_dataset_path=resolved_comparison_dataset_path,
630+
candidate_dataset_path=resolved_dataset_path,
631+
reference_baseline_sim=comparison_sim,
632+
candidate_baseline_sim=sim,
633+
).items():
634+
print("\n" + "=" * 70)
635+
print(section_name)
636+
print("=" * 70)
637+
print(section_output)
638+
420639
print("\n" + "=" * 70)
421640
print("STRUCTURAL CHECKS")
422641
print("=" * 70)

0 commit comments

Comments
 (0)