1+ import h5py
12import yaml
23from importlib .resources import files
34
@@ -432,6 +433,20 @@ def preprocess_puf(puf: pd.DataFrame) -> pd.DataFrame:
432433 0.0
433434 )
434435 puf ["business_is_sstb" ] = rng .binomial (n = 1 , p = pr_sstb )
436+ is_sstb = puf ["business_is_sstb" ].astype (bool )
437+
438+ # The current PUF pipeline only imputes an all-or-nothing SSTB flag.
439+ # Use that to split Schedule C self-employment and allocable W-2/UBIA
440+ # inputs for policyengine-us without pretending to observe mixed cases.
441+ legacy_self_employment_income = puf ["self_employment_income" ].fillna (0 )
442+ puf ["sstb_self_employment_income" ] = np .where (
443+ is_sstb , legacy_self_employment_income , 0.0
444+ )
445+ puf ["self_employment_income" ] = np .where (
446+ is_sstb , 0.0 , legacy_self_employment_income
447+ )
448+ puf ["sstb_w2_wages_from_qualified_business" ] = np .where (is_sstb , w2 , 0.0 )
449+ puf ["sstb_unadjusted_basis_qualified_property" ] = np .where (is_sstb , ubia , 0.0 )
435450
436451 reit_params = QBI_PARAMS ["reit_ptp_income_distribution" ]
437452 p_reit_ptp = reit_params ["probability_of_receiving" ]
@@ -526,6 +541,9 @@ def preprocess_puf(puf: pd.DataFrame) -> pd.DataFrame:
526541 "w2_wages_from_qualified_business" ,
527542 "unadjusted_basis_qualified_property" ,
528543 "business_is_sstb" ,
544+ "sstb_self_employment_income" ,
545+ "sstb_w2_wages_from_qualified_business" ,
546+ "sstb_unadjusted_basis_qualified_property" ,
529547 "deductible_mortgage_interest" ,
530548 "partnership_s_corp_income" ,
531549 "partnership_se_income" ,
@@ -538,6 +556,164 @@ class PUF(Dataset):
538556 time_period = None
539557 data_format = Dataset .ARRAYS
540558
@staticmethod
def _replace_array(file_handle, key: str, values: np.ndarray) -> None:
    """Store ``values`` under ``key``, replacing any entry already there.

    Args:
        file_handle: Writable container supporting ``in``, ``del`` and
            ``create_dataset`` (the visible caller passes an open
            ``h5py.File``).
        key: Name of the dataset to (re)write.
        values: Array written as the new dataset contents.
    """
    stale_entry_exists = key in file_handle
    if stale_entry_exists:
        # Drop the old entry first so the name is free for the new data.
        del file_handle[key]
    file_handle.create_dataset(key, data=values)
564+
def _sstb_split_overrides(self) -> dict[str, np.ndarray]:
    """Work out which stored arrays disagree with the all-or-nothing SSTB split.

    The dataset file is opened read-only and never modified here.

    Returns:
        Mapping of dataset key to the corrected array for every key that
        is missing or differs from the expected split. Empty when the
        file does not exist, when it has no ``business_is_sstb`` array,
        or when everything is already consistent.
    """
    if not self.file_path.exists():
        return {}

    # (source, target) pairs whose target should equal the source for SSTB
    # records and zero for all other records.
    allocable_pairs = (
        (
            "w2_wages_from_qualified_business",
            "sstb_w2_wages_from_qualified_business",
        ),
        (
            "unadjusted_basis_qualified_property",
            "sstb_unadjusted_basis_qualified_property",
        ),
    )

    with h5py.File(self.file_path, "r") as store:
        if "business_is_sstb" not in store:
            return {}

        present = set(store.keys())
        sstb_mask = np.asarray(store["business_is_sstb"]).astype(bool)
        fixes = {}

        if "self_employment_income" in present:
            stored_total = np.asarray(store["self_employment_income"])
            has_sstb_column = "sstb_self_employment_income" in present
            if has_sstb_column:
                stored_sstb = np.asarray(store["sstb_self_employment_income"])
            else:
                stored_sstb = np.zeros_like(stored_total)

            # For SSTB records keep an already-split (non-zero) SSTB value;
            # otherwise move the unsplit income over. Non-SSTB records get 0.
            wanted_sstb = np.where(
                sstb_mask,
                np.where(stored_sstb != 0, stored_sstb, stored_total),
                0.0,
            )
            # The non-SSTB column must be zero wherever the flag is set.
            wanted_total = np.where(sstb_mask, 0.0, stored_total)

            already_split = (
                has_sstb_column
                and np.array_equal(stored_sstb, wanted_sstb)
                and np.array_equal(stored_total, wanted_total)
            )
            if not already_split:
                # Both columns are emitted together so they stay in step.
                fixes["sstb_self_employment_income"] = wanted_sstb
                fixes["self_employment_income"] = wanted_total

        for source_key, target_key in allocable_pairs:
            if source_key in present:
                wanted_target = np.where(
                    sstb_mask, np.asarray(store[source_key]), 0.0
                )
                if target_key in present and np.array_equal(
                    np.asarray(store[target_key]),
                    wanted_target,
                ):
                    continue
                fixes[target_key] = wanted_target

    return fixes
636+
def _ensure_sstb_split_inputs(self) -> dict[str, np.ndarray]:
    """Compute the SSTB split corrections and try to persist them.

    The corrections come from ``_sstb_split_overrides``. Writing them back
    to the dataset file is best-effort: if the file cannot be opened or
    written (``OSError``), the failure is logged and the corrections are
    still returned so callers can apply them in memory.

    Returns:
        Mapping of dataset key to corrected array; empty when nothing
        needs correcting.
    """
    import logging

    overrides = self._sstb_split_overrides()
    if not overrides:
        return {}

    try:
        with h5py.File(self.file_path, "r+") as file_handle:
            for key, values in overrides.items():
                self._replace_array(file_handle, key, values)
    except OSError as error:
        # Deliberately non-fatal, but no longer silent: without the
        # write-back the same corrections are recomputed on later calls.
        logging.getLogger(__name__).warning(
            "Could not persist SSTB split inputs to %s (%s); "
            "using in-memory overrides only.",
            self.file_path,
            error,
        )

    return overrides
650+
class _OverrideView:
    """Read-only mapping-style view that layers overrides on a backing store.

    Lookups consult ``overrides`` first and fall back to ``backing``. Any
    attribute not defined here is forwarded to ``backing``.
    """

    def __init__(self, backing, overrides: dict[str, np.ndarray]):
        """Wrap ``backing`` so that keys in ``overrides`` take precedence."""
        self._backing = backing
        self._overrides = overrides

    def __getitem__(self, key):
        """Return the override for ``key`` if present, else the backing value."""
        if key in self._overrides:
            return self._overrides[key]
        return self._backing[key]

    def __contains__(self, key):
        """Report whether ``key`` exists in the overrides or the backing store."""
        return key in self._overrides or key in self._backing

    def keys(self):
        """Return a tuple of backing keys followed by override-only keys."""
        if hasattr(self._backing, "keys"):
            # dict.fromkeys de-duplicates while preserving first-seen order.
            return tuple(dict.fromkeys((*self._backing.keys(), *self._overrides)))
        return tuple(self._overrides)

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` when it is absent."""
        if key in self:
            return self[key]
        return default

    def items(self):
        """Yield ``(key, value)`` pairs with overrides applied."""
        for key in self.keys():
            yield key, self[key]

    def values(self):
        """Yield values in ``keys()`` order with overrides applied."""
        for key in self.keys():
            yield self[key]

    def __iter__(self):
        """Iterate over the combined keys."""
        return iter(self.keys())

    def close(self):
        """Close the backing store if it exposes ``close``."""
        if hasattr(self._backing, "close"):
            self._backing.close()

    def __enter__(self):
        """Enter the backing store's context, if any, and return this view."""
        if hasattr(self._backing, "__enter__"):
            self._backing.__enter__()
        return self

    def __exit__(self, exc_type, exc, traceback):
        """Forward context exit to the backing store when it supports it."""
        if hasattr(self._backing, "__exit__"):
            return self._backing.__exit__(exc_type, exc, traceback)
        return None

    def __getattr__(self, name):
        """Forward unknown attributes to the backing store.

        Raises:
            AttributeError: If the view's own fields are not set yet, or
                the backing store lacks ``name``.
        """
        # __getattr__ only runs after normal lookup has failed. If the field
        # being looked up is one of our own, the instance has not been
        # initialised (e.g. created via __new__ by copy/pickle); forwarding
        # would re-enter this method until RecursionError.
        if name in ("_backing", "_overrides"):
            raise AttributeError(name)
        return getattr(self._backing, name)
701+
def load(self, key=None, mode="r"):
    """Load from the dataset with the SSTB split corrections applied on read.

    Args:
        key: Name of a single array to load, or ``None`` for the whole
            store.
        mode: File mode passed to the parent loader. Corrections are only
            applied when it is ``"r"``.

    Returns:
        The corrected array when ``key`` is one of the corrected keys; an
        ``_OverrideView`` over the parent's result when ``key`` is ``None``
        and corrections exist; otherwise the parent's result unchanged.
    """
    if mode != "r":
        return super().load(key=key, mode=mode)

    patched = self._ensure_sstb_split_inputs()
    if key in patched:
        return patched[key]

    loaded = super().load(key=key, mode=mode)
    if key is None and patched:
        # Layer the corrections over the whole store so readers see them
        # even if the write-back to disk did not succeed.
        return self._OverrideView(loaded, patched)
    return loaded
710+
def load_dataset(self):
    """Load every array via the parent loader, then apply SSTB corrections.

    Returns:
        The object returned by the parent ``load_dataset``, updated in
        place with any corrected arrays.
    """
    # Run the repair first: it may rewrite the file the parent reads next.
    repaired = self._ensure_sstb_split_inputs()
    dataset = super().load_dataset()
    if repaired:
        dataset.update(repaired)
    return dataset
716+
541717 def generate (self ):
542718 from policyengine_us .system import system
543719
0 commit comments