From 17d2bd5df524734949a1bc7148f44ea65027ad5e Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Tue, 27 Jan 2026 13:26:04 -0500 Subject: [PATCH 1/8] Add full database schema, national targets ETL, and metadata utilities Migrate critical database infrastructure from junkyard repo: - Expand create_database_tables.py with Source, VariableGroup, and VariableMetadata tables, ConstraintOperation enum, and improved definition hash that includes parent_stratum_id - Add etl_national_targets.py for loading ~40 national calibration targets from CBO, Treasury/JCT, CMS, and other federal sources - Add utils/db_metadata.py with get_or_create helpers for sources, variable groups, and variable metadata - Add DATABASE_GUIDE.md documenting schema, stratum groups, ETL patterns, and SQL query examples - Standardize all ETL scripts to use calibration/policy_data.db path - Update Makefile database target to include national targets step Co-Authored-By: Claude Haiku 4.5 --- Makefile | 1 + policyengine_us_data/db/DATABASE_GUIDE.md | 364 ++++++++++ .../db/create_database_tables.py | 244 ++++++- .../db/create_initial_strata.py | 4 +- policyengine_us_data/db/etl_age.py | 4 +- policyengine_us_data/db/etl_irs_soi.py | 4 +- policyengine_us_data/db/etl_medicaid.py | 4 +- .../db/etl_national_targets.py | 647 ++++++++++++++++++ policyengine_us_data/db/etl_snap.py | 8 +- policyengine_us_data/db/validate_database.py | 4 +- policyengine_us_data/utils/db_metadata.py | 147 ++++ 11 files changed, 1400 insertions(+), 31 deletions(-) create mode 100644 policyengine_us_data/db/DATABASE_GUIDE.md create mode 100644 policyengine_us_data/db/etl_national_targets.py create mode 100644 policyengine_us_data/utils/db_metadata.py diff --git a/Makefile b/Makefile index fd212a08e..270717c39 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,7 @@ documentation-dev: database: python policyengine_us_data/db/create_database_tables.py python policyengine_us_data/db/create_initial_strata.py + python policyengine_us_data/db/etl_national_targets.py python policyengine_us_data/db/etl_age.py python policyengine_us_data/db/etl_medicaid.py python policyengine_us_data/db/etl_snap.py diff --git a/policyengine_us_data/db/DATABASE_GUIDE.md b/policyengine_us_data/db/DATABASE_GUIDE.md new file mode 100644 index 000000000..93657ef5f --- /dev/null +++ b/policyengine_us_data/db/DATABASE_GUIDE.md @@ -0,0 +1,364 @@ +# PolicyEngine US Data - Database Getting Started Guide + +## Current Task: Matrix Generation for Calibration Targets + +### Objective +Create a comprehensive matrix of calibration targets with the following requirements: +1. **Rows grouped by target type** - All age targets together, all income targets together, etc. +2. **Known counts per group** - Each group has a predictable number of entries (e.g., 18 age groups, 9 income brackets) +3. **Source selection** - Ability to specify which data source to use when multiple exist +4. 
**Geographic filtering** - Ability to select specific geographic levels (national, state, or congressional district) + +### Implementation Strategy +The `stratum_group_id` field now categorizes strata by conceptual type, making matrix generation straightforward: +- Query by `stratum_group_id` to get all related targets together +- Each demographic group appears consistently across all 488 geographic areas +- Join with `sources` table to filter/identify data provenance +- Use parent-child relationships to navigate geographic hierarchy + +### Example Matrix Query +```sql +-- Generate matrix for a specific geography (e.g., national level) +SELECT + CASE s.stratum_group_id + WHEN 2 THEN 'Age' + WHEN 3 THEN 'Income' + WHEN 4 THEN 'SNAP' + WHEN 5 THEN 'Medicaid' + WHEN 6 THEN 'EITC' + END AS group_name, + s.notes AS stratum_description, + t.variable, + t.value, + src.name AS source +FROM strata s +JOIN targets t ON s.stratum_id = t.stratum_id +JOIN sources src ON t.source_id = src.source_id +WHERE s.parent_stratum_id = 1 -- National level (or any specific geography) + AND s.stratum_group_id > 1 -- Exclude geographic strata +ORDER BY s.stratum_group_id, s.stratum_id; +``` + +## Overview +This database uses a hierarchical stratum-based model to organize US demographic and economic data for PolicyEngine calibration. The core concept is that data is organized into "strata" - population subgroups defined by constraints. + +## Key Concepts + +### Strata Hierarchy +The database uses a parent-child hierarchy: +``` +United States (national) +├── States (51 including DC) +│ ├── Congressional Districts (436 total) +│ │ ├── Age groups (18 brackets per geographic area) +│ │ ├── Income groups (AGI stubs) +│ │ └── Other demographic strata (EITC recipients, SNAP, Medicaid, etc.) +``` + +### Stratum Groups +The `stratum_group_id` field categorizes strata by their conceptual type: +- `1`: Geographic boundaries (US, states, congressional districts) +- `2`: Age-based strata (18 age groups per geography) +- `3`: Income/AGI-based strata (9 income brackets per geography) +- `4`: SNAP recipient strata (1 per geography) +- `5`: Medicaid enrollment strata (1 per geography) +- `6`: EITC recipient strata (4 groups by qualifying children per geography) + +### UCGID Translation +The Census Bureau uses UCGIDs (Universal Census Geographic IDs) in their API responses: +- `0100000US`: National level +- `0400000USXX`: State (XX = state FIPS code) +- `5001800USXXDD`: Congressional district (XX = state FIPS, DD = district number) + +We parse these into our internal model using `state_fips` and `congressional_district_geoid`. + +### Constraint Operations +All constraints use standardized operators: +- `==`: Equals +- `!=`: Not equals +- `>`: Greater than +- `>=`: Greater than or equal +- `<`: Less than +- `<=`: Less than or equal + +## Database Structure + +### Core Tables +1. **strata**: Main table for population subgroups + - `stratum_id`: Primary key + - `parent_stratum_id`: Links to parent in hierarchy + - `stratum_group_id`: Conceptual category (1=Geographic, 2=Age, 3=Income, 4=SNAP, 5=Medicaid, 6=EITC) + - `definition_hash`: Unique hash of constraints for deduplication + +2. **stratum_constraints**: Defines rules for each stratum + - `constraint_variable`: Variable name (e.g., "age", "state_fips") + - `operation`: Comparison operator (==, >, <, etc.) + - `value`: Constraint value + +3. 
**targets**: Stores actual data values + - `variable`: PolicyEngine US variable name + - `period`: Year + - `value`: Numerical value + - `source_id`: Foreign key to sources table + - `active`: Boolean flag for active/inactive targets + - `tolerance`: Allowed relative error percentage + +### Metadata Tables +4. **sources**: Data source metadata + - `source_id`: Primary key (auto-generated) + - `name`: Source name (e.g., "IRS Statistics of Income") + - `type`: SourceType enum (administrative, survey, hardcoded) + - `vintage`: Year or version of data + - `description`: Detailed description + - `url`: Reference URL + - `notes`: Additional notes + +5. **variable_groups**: Logical groupings of related variables + - `group_id`: Primary key (auto-generated) + - `name`: Unique group name (e.g., "age_distribution", "snap_recipients") + - `category`: High-level category (demographic, benefit, tax, income, expense) + - `is_histogram`: Whether this represents a distribution + - `is_exclusive`: Whether variables are mutually exclusive + - `aggregation_method`: How to aggregate (sum, weighted_avg, etc.) + - `display_order`: Order for display in matrices/reports + - `description`: What this group represents + +6. **variable_metadata**: Display information for variables + - `metadata_id`: Primary key + - `variable`: PolicyEngine variable name + - `group_id`: Foreign key to variable_groups + - `display_name`: Human-readable name + - `display_order`: Order within group + - `units`: Units of measurement (dollars, count, percent) + - `is_primary`: Whether this is a primary vs derived variable + - `notes`: Additional notes + +## Building the Database + +### Step 1: Create Tables +```bash +source ~/envs/sep/bin/activate +cd policyengine_us_data/db +python create_database_tables.py +``` + +### Step 2: Create Geographic Hierarchy +```bash +python create_initial_strata.py +``` +Creates: 1 national + 51 state + 436 congressional district strata + +### Step 3: Load Data (in order) +```bash +# National hardcoded targets +python etl_national_targets.py + +# Age demographics (Census ACS) +python etl_age.py + +# Economic data (IRS SOI) +python etl_irs_soi.py + +# Benefits data +python etl_medicaid.py +python etl_snap.py +``` + +### Step 4: Validate +```bash +python validate_database.py +``` + +Expected output: +- 488 geographic strata +- 8,784 age strata (18 age groups × 488 areas) +- All strata have unique definition hashes + +## Common Utility Functions + +Located in `policyengine_us_data/utils/db.py`: + +- `get_stratum_by_id(session, id)`: Retrieve stratum by ID +- `get_simple_stratum_by_ucgid(session, ucgid)`: Get stratum by UCGID +- `get_root_strata(session)`: Get root strata +- `get_stratum_children(session, id)`: Get child strata +- `get_stratum_parent(session, id)`: Get parent stratum + +Located in `policyengine_us_data/utils/db_metadata.py`: + +- `get_or_create_source(session, ...)`: Get or create a data source +- `get_or_create_variable_group(session, ...)`: Get or create a variable group +- `get_or_create_variable_metadata(session, ...)`: Get or create variable metadata + +## ETL Script Pattern + +Each ETL script follows this pattern: + +1. **Extract**: Pull data from source (Census API, IRS files, etc.) +2. **Transform**: + - Parse UCGIDs to get geographic info + - Map to existing geographic strata + - Create demographic strata as children +3. 
**Load**: + - Check for existing strata to avoid duplicates + - Add constraints and targets + - Commit to database + +## Important Notes + +### Avoiding Duplicates +Always check if a stratum exists before creating: +```python +existing_stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == parent_id, + Stratum.stratum_group_id == group_id, + Stratum.notes == note + ) +).first() +``` + +### Geographic Constraints +- National strata: No geographic constraints needed +- State strata: `state_fips` constraint +- District strata: `congressional_district_geoid` constraint + +### Congressional District Normalization +- District 00 → 01 (at-large districts) +- DC district 98 → 01 (delegate district) + +### IRS AGI Ranges +AGI stubs use >= for lower bound, < for upper bound: +- Stub 3: $10,000 <= AGI < $25,000 +- Stub 4: $25,000 <= AGI < $50,000 +- etc. + +## Troubleshooting + +### "WARNING: Expected 8784 age strata, found 16104" +**Status: RESOLVED** + +The validation script was incorrectly counting all demographic strata (stratum_group_id = 0) as age strata. After implementing the new stratum_group_id scheme (1=Geographic, 2=Age, 3=Income, etc.), the validation correctly identifies 8,784 age strata. + +### Fixed: Synthetic Variable Names +Previously, the IRS SOI ETL was creating invalid variable names like `eitc_tax_unit_count` that don't exist in PolicyEngine. Now correctly uses `tax_unit_count` with appropriate stratum constraints to indicate what's being counted. + +### UCGID strings in notes +Legacy UCGID references have been replaced with human-readable identifiers: +- "US" for national +- "State FIPS X" for states +- "CD XXXX" for congressional districts + +### Mixed operation types +All operations now use standardized symbols (==, >, <, etc.) validated by ConstraintOperation enum. + +## Database Location +`policyengine_us_data/storage/calibration/policy_data.db` + +## Example SQLite Queries with Metadata Features + +### Compare Administrative vs Survey Data for SNAP +```sql +SELECT + s.type AS source_type, + s.name AS source_name, + st.notes AS location, + t.value AS household_count +FROM targets t +JOIN sources s ON t.source_id = s.source_id +JOIN strata st ON t.stratum_id = st.stratum_id +WHERE t.variable = 'household_count' + AND st.notes LIKE '%SNAP%' +ORDER BY s.type, st.notes; +``` + +### Get All Variables in a Group with Their Metadata +```sql +SELECT + vm.display_name, + vm.variable, + vm.units, + vm.display_order, + vg.description AS group_description +FROM variable_metadata vm +JOIN variable_groups vg ON vm.group_id = vg.group_id +WHERE vg.name = 'eitc_recipients' +ORDER BY vm.display_order; +``` + +### Query by Stratum Group +```sql +-- Get all age-related strata and their targets +SELECT + s.stratum_id, + s.notes, + t.variable, + t.value, + src.name AS source +FROM strata s +JOIN targets t ON s.stratum_id = t.stratum_id +JOIN sources src ON t.source_id = src.source_id +WHERE s.stratum_group_id = 2 -- Age strata +LIMIT 20; + +-- Count strata by group +SELECT + stratum_group_id, + CASE stratum_group_id + WHEN 1 THEN 'Geographic' + WHEN 2 THEN 'Age' + WHEN 3 THEN 'Income/AGI' + WHEN 4 THEN 'SNAP' + WHEN 5 THEN 'Medicaid' + WHEN 6 THEN 'EITC' + END AS group_name, + COUNT(*) AS stratum_count +FROM strata +GROUP BY stratum_group_id +ORDER BY stratum_group_id; +``` + +## Key Improvements +1. Removed UCGID as a constraint variable (legacy Census concept) +2. Standardized constraint operations with validation +3. 
Consolidated duplicate code (parse_ucgid, get_geographic_strata) +4. Fixed epsilon hack in IRS AGI ranges +5. Added proper duplicate checking in age ETL +6. Improved human-readable notes without UCGID strings +7. Added metadata tables for sources, variable groups, and variable metadata +8. Fixed synthetic variable name bug (e.g., eitc_tax_unit_count → tax_unit_count) +9. Auto-generated source IDs instead of hardcoding +10. Proper categorization of admin vs survey data for same concepts +11. Implemented conceptual stratum_group_id scheme for better organization and querying + +## Known Issues / TODOs + +### IMPORTANT: stratum_id vs state_fips Codes +**WARNING**: The `stratum_id` is an auto-generated sequential ID and has NO relationship to FIPS codes, despite some confusing coincidences: +- California: stratum_id = 6, state_fips = "06" (coincidental match!) +- North Carolina: stratum_id = 35, state_fips = "37" (no match) +- Ohio: stratum_id = 37, state_fips = "39" (no match) + +When querying for states, ALWAYS use the `state_fips` constraint value, never assume stratum_id matches FIPS. + +Example of correct lookup: +```sql +-- Find North Carolina's stratum_id by FIPS code +SELECT s.stratum_id, s.notes +FROM strata s +JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id +WHERE sc.constraint_variable = 'state_fips' + AND sc.value = '37'; -- Returns stratum_id = 35 +``` + +### Type Conversion for Constraint Values +**DESIGN DECISION**: The `value` column in `stratum_constraints` must store heterogeneous data types as strings. The calibration code deserializes these: +- Numeric strings → int/float (for age, income constraints) +- "True"/"False" → Python booleans (for medicaid_enrolled, snap_enrolled) +- Other strings remain strings (for state_fips, which may have leading zeros) + +### Medicaid Data Structure +- Medicaid uses `person_count` variable (not `medicaid`) because it's structured as a histogram with constraints +- State-level targets use administrative data (T-MSIS source) +- Congressional district level uses survey data (ACS source) +- No national Medicaid target exists (intentionally, to avoid double-counting when using state-level data) diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index 920d1449e..4b526f7ef 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -11,6 +11,7 @@ SQLModel, create_engine, ) +from pydantic import validator from policyengine_us.system import system from policyengine_us_data.storage import STORAGE_FOLDER @@ -29,6 +30,17 @@ ) +class ConstraintOperation(str, Enum): + """Allowed operations for stratum constraints.""" + + EQ = "==" # Equals + NE = "!=" # Not equals + GT = ">" # Greater than + GE = ">=" # Greater than or equal + LT = "<" # Less than + LE = "<=" # Less than or equal + + class Stratum(SQLModel, table=True): """Represents a unique population subgroup (stratum).""" @@ -52,13 +64,15 @@ class Stratum(SQLModel, table=True): default=None, foreign_key="strata.stratum_id", index=True, - description="Identifier for a parent stratum, creating a hierarchy.", + description=("Identifier for a parent stratum, creating a hierarchy."), ) stratum_group_id: Optional[int] = Field( - default=None, description="Identifier for a group of related strata." + default=None, + description="Identifier for a group of related strata.", ) notes: Optional[str] = Field( - default=None, description="Descriptive notes about the stratum." 
+ default=None, + description="Descriptive notes about the stratum.", ) children_rel: List["Stratum"] = Relationship( @@ -88,23 +102,35 @@ class StratumConstraint(SQLModel, table=True): __tablename__ = "stratum_constraints" stratum_id: int = Field(foreign_key="strata.stratum_id", primary_key=True) - constraint_variable: USVariable = Field( + constraint_variable: str = Field( primary_key=True, - description="The variable the constraint applies to (e.g., 'age').", + description=("The variable the constraint applies to (e.g., 'age')."), ) operation: str = Field( primary_key=True, - description="The comparison operator (e.g., 'greater_than_or_equal').", + description=("The comparison operator (==, !=, >, >=, <, <=)."), ) value: str = Field( description="The value for the constraint rule (e.g., '25')." ) notes: Optional[str] = Field( - default=None, description="Optional notes about the constraint." + default=None, + description="Optional notes about the constraint.", ) strata_rel: Stratum = Relationship(back_populates="constraints_rel") + @validator("operation") + def validate_operation(cls, v): + """Validate that the operation is one of the allowed values.""" + allowed_ops = [op.value for op in ConstraintOperation] + if v not in allowed_ops: + raise ValueError( + f"Invalid operation '{v}'. " + f"Must be one of: {', '.join(allowed_ops)}" + ) + return v + class Target(SQLModel, table=True): """Stores the data values for a specific stratum.""" @@ -122,7 +148,9 @@ class Target(SQLModel, table=True): target_id: Optional[int] = Field(default=None, primary_key=True) variable: USVariable = Field( - description="A variable defined in policyengine-us (e.g., 'income_tax')." + description=( + "A variable defined in policyengine-us " "(e.g., 'income_tax')." + ), ) period: int = Field( description="The time period for the data, typically a year." @@ -130,21 +158,28 @@ class Target(SQLModel, table=True): stratum_id: int = Field(foreign_key="strata.stratum_id", index=True) reform_id: int = Field( default=0, - description="Identifier for a policy reform scenario (0 for baseline).", + description=( + "Identifier for a policy reform scenario " "(0 for baseline)." + ), ) value: Optional[float] = Field( - default=None, description="The numerical value of the target variable." + default=None, + description="The numerical value of the target variable.", ) source_id: Optional[int] = Field( - default=None, description="Identifier for the data source." + default=None, + foreign_key="sources.source_id", + description="Identifier for the data source.", ) active: bool = Field( default=True, - description="Flag to indicate if the record is currently active.", + description=("Flag to indicate if the record is currently active."), ) tolerance: Optional[float] = Field( default=None, - description="Allowed relative error as a percent (e.g., 25 for 25%).", + description=( + "Allowed relative error as a percent " "(e.g., 25 for 25%)." 
+ ), ) notes: Optional[str] = Field( default=None, @@ -152,23 +187,179 @@ class Target(SQLModel, table=True): ) strata_rel: Stratum = Relationship(back_populates="targets_rel") + source_rel: Optional["Source"] = Relationship() + + +class SourceType(str, Enum): + """Types of data sources.""" + + ADMINISTRATIVE = "administrative" + SURVEY = "survey" + SYNTHETIC = "synthetic" + DERIVED = "derived" + HARDCODED = "hardcoded" + + +class Source(SQLModel, table=True): + """Metadata about data sources.""" + + __tablename__ = "sources" + __table_args__ = ( + UniqueConstraint("name", "vintage", name="uq_source_name_vintage"), + ) + + source_id: Optional[int] = Field( + default=None, + primary_key=True, + description="Unique identifier for the data source.", + ) + name: str = Field( + description=( + "Name of the data source " "(e.g., 'IRS SOI', 'Census ACS')." + ), + index=True, + ) + type: SourceType = Field( + description=("Type of data source (administrative, survey, etc.)."), + ) + description: Optional[str] = Field( + default=None, + description="Detailed description of the data source.", + ) + url: Optional[str] = Field( + default=None, + description=("URL or reference to the original data source."), + ) + vintage: Optional[str] = Field( + default=None, + description="Version or release date of the data source.", + ) + notes: Optional[str] = Field( + default=None, + description="Additional notes about the source.", + ) + + +class VariableGroup(SQLModel, table=True): + """Groups of related variables that form logical units.""" + + __tablename__ = "variable_groups" + + group_id: Optional[int] = Field( + default=None, + primary_key=True, + description="Unique identifier for the variable group.", + ) + name: str = Field( + description=( + "Name of the variable group " + "(e.g., 'age_distribution', 'snap_recipients')." + ), + index=True, + unique=True, + ) + category: str = Field( + description=( + "High-level category " + "(e.g., 'demographic', 'benefit', 'tax', 'income')." + ), + index=True, + ) + is_histogram: bool = Field( + default=False, + description=( + "Whether this group represents a " "histogram/distribution." + ), + ) + is_exclusive: bool = Field( + default=False, + description=( + "Whether variables in this group are " "mutually exclusive." + ), + ) + aggregation_method: Optional[str] = Field( + default=None, + description=( + "How to aggregate variables in this group " + "(sum, weighted_avg, etc.)." + ), + ) + display_order: Optional[int] = Field( + default=None, + description=( + "Order for displaying this group in " "matrices/reports." 
+ ), + ) + description: Optional[str] = Field( + default=None, + description="Description of what this group represents.", + ) + + +class VariableMetadata(SQLModel, table=True): + """Maps PolicyEngine variables to their groups and provides + metadata.""" + + __tablename__ = "variable_metadata" + __table_args__ = ( + UniqueConstraint("variable", name="uq_variable_metadata_variable"), + ) + + metadata_id: Optional[int] = Field(default=None, primary_key=True) + variable: str = Field( + description="PolicyEngine variable name.", index=True + ) + group_id: Optional[int] = Field( + default=None, + foreign_key="variable_groups.group_id", + description=("ID of the variable group this belongs to."), + ) + display_name: Optional[str] = Field( + default=None, + description=("Human-readable name for display in matrices."), + ) + display_order: Optional[int] = Field( + default=None, + description=("Order within its group for display purposes."), + ) + units: Optional[str] = Field( + default=None, + description=( + "Units of measurement " "(dollars, count, percent, etc.)." + ), + ) + is_primary: bool = Field( + default=True, + description=( + "Whether this is a primary variable vs " "derived/auxiliary." + ), + ) + notes: Optional[str] = Field( + default=None, + description="Additional notes about the variable.", + ) + + group_rel: Optional[VariableGroup] = Relationship() -# This SQLAlchemy event listener works directly with the SQLModel class @event.listens_for(Stratum, "before_insert") @event.listens_for(Stratum, "before_update") def calculate_definition_hash(mapper, connection, target: Stratum): - """ - Calculate and set the definition_hash before saving a Stratum instance. - """ + """Calculate and set the definition_hash before saving a + Stratum instance.""" constraints_history = get_history(target, "constraints_rel") if not ( constraints_history.has_changes() or target.definition_hash is None ): return - if not target.constraints_rel: # Handle cases with no constraints - target.definition_hash = hashlib.sha256(b"").hexdigest() + if not target.constraints_rel: + parent_str = ( + str(target.parent_stratum_id) if target.parent_stratum_id else "" + ) + target.definition_hash = hashlib.sha256( + parent_str.encode("utf-8") + ).hexdigest() return constraint_strings = [ @@ -177,19 +368,24 @@ def calculate_definition_hash(mapper, connection, target: Stratum): ] constraint_strings.sort() - fingerprint_text = "\n".join(constraint_strings) + parent_str = ( + str(target.parent_stratum_id) if target.parent_stratum_id else "" + ) + fingerprint_text = parent_str + "\n" + "\n".join(constraint_strings) h = hashlib.sha256(fingerprint_text.encode("utf-8")) target.definition_hash = h.hexdigest() +DB_PATH = STORAGE_FOLDER / "calibration" / "policy_data.db" + + def create_database( - db_uri: str = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}", + db_uri: str = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}", ): - """ - Creates a SQLite database and all the defined tables. + """Creates a SQLite database and all the defined tables. Args: - db_uri (str): The connection string for the database. + db_uri: The connection string for the database. Returns: An SQLAlchemy Engine instance connected to the database. 
diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 5653948bc..1d7d3b4b2 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -34,7 +34,9 @@ def main(): .reset_index(drop=True) ) - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) # map the ucgid_str 'code' to auto-generated 'stratum_id' diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index d80faf065..9ce8f8a17 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -104,7 +104,9 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): raise ValueError('geo must be one of "National", "State", "District"') # Prepare to load data ----------- - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) if stratum_lookup is None: diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index 6607a5dd6..879bd9a20 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -285,7 +285,9 @@ def transform_soi_data(raw_df): def load_soi_data(long_dfs, year): """Load a list of databases into the db, critically dependent on order""" - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) session = Session(engine) diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 926a0d88c..d420edd0d 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -85,7 +85,9 @@ def transform_medicaid_data(state_admin_df, cd_survey_df, year): def load_medicaid_data(long_state, long_cd, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) stratum_lookup = {} diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py new file mode 100644 index 000000000..0a03add3d --- /dev/null +++ b/policyengine_us_data/db/etl_national_targets.py @@ -0,0 +1,647 @@ +from sqlmodel import Session, create_engine +import pandas as pd + +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, + SourceType, +) +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, +) + + +def extract_national_targets(): + """ + Extract national calibration targets from various sources. 
+ + Returns + ------- + dict + Dictionary containing: + - direct_sum_targets: Variables that can be summed directly + - tax_filer_targets: Tax-related variables requiring filer + constraint + - conditional_count_targets: Enrollment counts requiring + constraints + - cbo_targets: List of CBO projection targets + - treasury_targets: List of Treasury/JCT targets + """ + + from policyengine_us import Microsimulation + + sim = Microsimulation( + dataset="hf://policyengine/policyengine-us-data/cps_2023.h5" + ) + + HARDCODED_YEAR = 2024 + + tax_filer_targets = [ + { + "variable": "salt_deduction", + "value": 21.247e9, + "source": "Joint Committee on Taxation", + "notes": "SALT deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "medical_expense_deduction", + "value": 11.4e9, + "source": "Joint Committee on Taxation", + "notes": "Medical expense deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "charitable_deduction", + "value": 65.301e9, + "source": "Joint Committee on Taxation", + "notes": "Charitable deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "interest_deduction", + "value": 24.8e9, + "source": "Joint Committee on Taxation", + "notes": "Mortgage interest deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "qualified_business_income_deduction", + "value": 63.1e9, + "source": "Joint Committee on Taxation", + "notes": "QBI deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + ] + + direct_sum_targets = [ + { + "variable": "alimony_income", + "value": 13e9, + "source": "Survey-reported (post-TCJA grandfathered)", + "notes": "Alimony received - survey reported, " + "not tax-filer restricted", + "year": HARDCODED_YEAR, + }, + { + "variable": "alimony_expense", + "value": 13e9, + "source": "Survey-reported (post-TCJA grandfathered)", + "notes": "Alimony paid - survey reported, " + "not tax-filer restricted", + "year": HARDCODED_YEAR, + }, + { + "variable": "medicaid", + "value": 871.7e9, + "source": "https://www.cms.gov/files/document/" "highlights.pdf", + "notes": "CMS 2023 highlights document - " + "total Medicaid spending", + "year": HARDCODED_YEAR, + }, + { + "variable": "net_worth", + "value": 160e12, + "source": "Federal Reserve SCF", + "notes": "Total household net worth", + "year": HARDCODED_YEAR, + }, + { + "variable": "health_insurance_premiums_without_" "medicare_part_b", + "value": 385e9, + "source": "MEPS/NHEA", + "notes": "Health insurance premiums excluding " "Medicare Part B", + "year": HARDCODED_YEAR, + }, + { + "variable": "other_medical_expenses", + "value": 278e9, + "source": "MEPS/NHEA", + "notes": "Out-of-pocket medical expenses", + "year": HARDCODED_YEAR, + }, + { + "variable": "medicare_part_b_premiums", + "value": 112e9, + "source": "CMS Medicare data", + "notes": "Medicare Part B premium payments", + "year": HARDCODED_YEAR, + }, + { + "variable": "over_the_counter_health_expenses", + "value": 72e9, + "source": "Consumer Expenditure Survey", + "notes": "OTC health products and supplies", + "year": HARDCODED_YEAR, + }, + { + "variable": "child_support_expense", + "value": 33e9, + "source": "Census Bureau", + "notes": "Child support payments", + "year": HARDCODED_YEAR, + }, + { + "variable": "child_support_received", + "value": 33e9, + "source": "Census Bureau", + "notes": "Child support received", + "year": HARDCODED_YEAR, + }, + { + "variable": "spm_unit_capped_work_childcare_expenses", + "value": 348e9, + "source": "Census Bureau SPM", + "notes": "Work and 
childcare expenses for SPM", + "year": HARDCODED_YEAR, + }, + { + "variable": "spm_unit_capped_housing_subsidy", + "value": 35e9, + "source": "HUD/Census", + "notes": "Housing subsidies", + "year": HARDCODED_YEAR, + }, + { + "variable": "tanf", + "value": 9e9, + "source": "HHS/ACF", + "notes": "TANF cash assistance", + "year": HARDCODED_YEAR, + }, + { + "variable": "real_estate_taxes", + "value": 500e9, + "source": "Census Bureau", + "notes": "Property taxes paid", + "year": HARDCODED_YEAR, + }, + { + "variable": "rent", + "value": 735e9, + "source": "Census Bureau/BLS", + "notes": "Rental payments", + "year": HARDCODED_YEAR, + }, + { + "variable": "tip_income", + "value": 53.2e9, + "source": "IRS Form W-2 Box 7 statistics", + "notes": "Social security tips uprated 40% to account " + "for underreporting", + "year": HARDCODED_YEAR, + }, + ] + + conditional_count_targets = [ + { + "constraint_variable": "medicaid", + "stratum_group_id": 5, + "person_count": 72_429_055, + "source": "CMS/HHS administrative data", + "notes": "Medicaid enrollment count", + "year": HARDCODED_YEAR, + }, + { + "constraint_variable": "aca_ptc", + "stratum_group_id": None, + "person_count": 19_743_689, + "source": "CMS marketplace data", + "notes": "ACA Premium Tax Credit recipients", + "year": HARDCODED_YEAR, + }, + ] + + ssn_none_targets_by_year = [ + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 11.0e6, + "source": "DHS Office of Homeland Security Statistics", + "notes": "Undocumented population estimate " "for Jan 1, 2022", + "year": 2022, + }, + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 12.2e6, + "source": "Center for Migration Studies " + "ACS-based residual estimate", + "notes": "Undocumented population estimate " + "(published May 2025)", + "year": 2023, + }, + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 13.0e6, + "source": "Reuters synthesis of experts", + "notes": "Undocumented population central estimate " + "(~13-14 million)", + "year": 2024, + }, + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 13.0e6, + "source": "Reuters synthesis of experts", + "notes": "Same midpoint carried forward - " + "CBP data show 95% drop in border apprehensions", + "year": 2025, + }, + ] + + conditional_count_targets.extend(ssn_none_targets_by_year) + + CBO_YEAR = 2023 + cbo_vars = [ + "income_tax", + "snap", + "social_security", + "ssi", + "unemployment_compensation", + ] + + cbo_targets = [] + for variable_name in cbo_vars: + try: + value = sim.tax_benefit_system.parameters( + CBO_YEAR + ).calibration.gov.cbo._children[variable_name] + cbo_targets.append( + { + "variable": variable_name, + "value": float(value), + "source": "CBO Budget Projections", + "notes": f"CBO projection for {variable_name}", + "year": CBO_YEAR, + } + ) + except (KeyError, AttributeError) as e: + print( + f"Warning: Could not extract CBO parameter " + f"for {variable_name}: {e}" + ) + + TREASURY_YEAR = 2023 + try: + eitc_value = sim.tax_benefit_system.parameters.calibration.gov.treasury.tax_expenditures.eitc( + TREASURY_YEAR + ) + treasury_targets = [ + { + "variable": "eitc", + "value": float(eitc_value), + "source": "Treasury/JCT Tax Expenditures", + "notes": "EITC tax expenditure", + "year": TREASURY_YEAR, + } + ] + except (KeyError, AttributeError) as e: + 
print(f"Warning: Could not extract Treasury EITC " f"parameter: {e}") + treasury_targets = [] + + return { + "direct_sum_targets": direct_sum_targets, + "tax_filer_targets": tax_filer_targets, + "conditional_count_targets": conditional_count_targets, + "cbo_targets": cbo_targets, + "treasury_targets": treasury_targets, + } + + +def transform_national_targets(raw_targets): + """ + Transform extracted targets into standardized format. + + Parameters + ---------- + raw_targets : dict + Dictionary from extract_national_targets() + + Returns + ------- + tuple + (direct_targets_df, tax_filer_df, conditional_targets) + """ + cbo_non_tax = [ + t for t in raw_targets["cbo_targets"] if t["variable"] != "income_tax" + ] + cbo_tax = [ + t for t in raw_targets["cbo_targets"] if t["variable"] == "income_tax" + ] + + all_direct_targets = raw_targets["direct_sum_targets"] + cbo_non_tax + + all_tax_filer_targets = ( + raw_targets["tax_filer_targets"] + + cbo_tax + + raw_targets["treasury_targets"] + ) + + direct_df = ( + pd.DataFrame(all_direct_targets) + if all_direct_targets + else pd.DataFrame() + ) + tax_filer_df = ( + pd.DataFrame(all_tax_filer_targets) + if all_tax_filer_targets + else pd.DataFrame() + ) + + conditional_targets = raw_targets["conditional_count_targets"] + + return direct_df, tax_filer_df, conditional_targets + + +def load_national_targets( + direct_targets_df, tax_filer_df, conditional_targets +): + """ + Load national targets into the database. + + Parameters + ---------- + direct_targets_df : pd.DataFrame + DataFrame with direct sum target data + tax_filer_df : pd.DataFrame + DataFrame with tax-related targets needing filer constraint + conditional_targets : list + List of conditional count targets requiring strata + """ + + DATABASE_URL = ( + f"sqlite:///" f"{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) + engine = create_engine(DATABASE_URL) + + with Session(engine) as session: + calibration_source = get_or_create_source( + session, + name="PolicyEngine Calibration Targets", + source_type=SourceType.HARDCODED, + vintage="Mixed (2023-2024)", + description="National calibration targets from " + "various authoritative sources", + url=None, + notes="Aggregated from CMS, IRS, CBO, Treasury, " + "and other federal sources", + ) + + us_stratum = ( + session.query(Stratum) + .filter(Stratum.parent_stratum_id == None) # noqa: E711 + .first() + ) + + if not us_stratum: + raise ValueError( + "National stratum not found. " + "Run create_initial_strata.py first." 
+ ) + + for _, target_data in direct_targets_df.iterrows(): + target_year = target_data["year"] + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == us_stratum.stratum_id, + Target.variable == target_data["variable"], + Target.period == target_year, + ) + .first() + ) + + notes_parts = [] + if pd.notna(target_data.get("notes")): + notes_parts.append(target_data["notes"]) + notes_parts.append( + f"Source: {target_data.get('source', 'Unknown')}" + ) + combined_notes = " | ".join(notes_parts) + + if existing_target: + existing_target.value = target_data["value"] + existing_target.notes = combined_notes + print(f"Updated target: {target_data['variable']}") + else: + target = Target( + stratum_id=us_stratum.stratum_id, + variable=target_data["variable"], + period=target_year, + value=target_data["value"], + source_id=calibration_source.source_id, + active=True, + notes=combined_notes, + ) + session.add(target) + print(f"Added target: {target_data['variable']}") + + if not tax_filer_df.empty: + national_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == us_stratum.stratum_id, + Stratum.notes == "United States - Tax Filers", + ) + .first() + ) + + if not national_filer_stratum: + national_filer_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=2, + notes="United States - Tax Filers", + ) + national_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + session.add(national_filer_stratum) + session.flush() + print("Created national filer stratum") + + for _, target_data in tax_filer_df.iterrows(): + target_year = target_data["year"] + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == national_filer_stratum.stratum_id, + Target.variable == target_data["variable"], + Target.period == target_year, + ) + .first() + ) + + notes_parts = [] + if pd.notna(target_data.get("notes")): + notes_parts.append(target_data["notes"]) + notes_parts.append( + f"Source: " f"{target_data.get('source', 'Unknown')}" + ) + combined_notes = " | ".join(notes_parts) + + if existing_target: + existing_target.value = target_data["value"] + existing_target.notes = combined_notes + print( + f"Updated filer target: " f"{target_data['variable']}" + ) + else: + target = Target( + stratum_id=(national_filer_stratum.stratum_id), + variable=target_data["variable"], + period=target_year, + value=target_data["value"], + source_id=calibration_source.source_id, + active=True, + notes=combined_notes, + ) + session.add(target) + print(f"Added filer target: " f"{target_data['variable']}") + + for cond_target in conditional_targets: + constraint_var = cond_target["constraint_variable"] + stratum_group_id = cond_target.get("stratum_group_id") + target_year = cond_target["year"] + + if constraint_var == "medicaid": + stratum_group_id = 5 + stratum_notes = "National Medicaid Enrollment" + constraint_operation = ">" + constraint_value = "0" + elif constraint_var == "aca_ptc": + stratum_group_id = 6 + stratum_notes = "National ACA Premium Tax Credit Recipients" + constraint_operation = ">" + constraint_value = "0" + elif constraint_var == "ssn_card_type": + stratum_group_id = 7 + stratum_notes = "National Undocumented Population" + constraint_operation = "=" + constraint_value = cond_target.get("constraint_value", "NONE") + else: + stratum_notes = f"National {constraint_var} Recipients" + constraint_operation = ">" + constraint_value = "0" + + 
existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == us_stratum.stratum_id, + Stratum.stratum_group_id == stratum_group_id, + Stratum.notes == stratum_notes, + ) + .first() + ) + + if existing_stratum: + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == existing_stratum.stratum_id, + Target.variable == "person_count", + Target.period == target_year, + ) + .first() + ) + + if existing_target: + existing_target.value = cond_target["person_count"] + print( + f"Updated enrollment target " f"for {constraint_var}" + ) + else: + new_target = Target( + stratum_id=existing_stratum.stratum_id, + variable="person_count", + period=target_year, + value=cond_target["person_count"], + source_id=calibration_source.source_id, + active=True, + notes=( + f"{cond_target['notes']} | " + f"Source: {cond_target['source']}" + ), + ) + session.add(new_target) + print(f"Added enrollment target " f"for {constraint_var}") + else: + new_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=stratum_group_id, + notes=stratum_notes, + ) + + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=constraint_var, + operation=constraint_operation, + value=constraint_value, + ) + ] + + new_stratum.targets_rel = [ + Target( + variable="person_count", + period=target_year, + value=cond_target["person_count"], + source_id=calibration_source.source_id, + active=True, + notes=( + f"{cond_target['notes']} | " + f"Source: {cond_target['source']}" + ), + ) + ] + + session.add(new_stratum) + print( + f"Created stratum and target " + f"for {constraint_var} enrollment" + ) + + session.commit() + + total_targets = ( + len(direct_targets_df) + + len(tax_filer_df) + + len(conditional_targets) + ) + print(f"\nSuccessfully loaded {total_targets} " f"national targets") + print(f" - {len(direct_targets_df)} direct sum targets") + print(f" - {len(tax_filer_df)} tax filer targets") + print( + f" - {len(conditional_targets)} enrollment count " + f"targets (as strata)" + ) + + +def main(): + """Main ETL pipeline for national targets.""" + print("Extracting national targets...") + raw_targets = extract_national_targets() + + print("Transforming targets...") + direct_targets_df, tax_filer_df, conditional_targets = ( + transform_national_targets(raw_targets) + ) + + print("Loading targets into database...") + load_national_targets(direct_targets_df, tax_filer_df, conditional_targets) + + print("\nETL pipeline complete!") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index 1fba44a46..7f73c2b78 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -146,7 +146,9 @@ def transform_survey_snap_data(raw_df): def load_administrative_snap_data(df_states, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) stratum_lookup = {} @@ -234,7 +236,9 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): if stratum_lookup is None: raise ValueError("stratum_lookup must be provided") - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) engine = create_engine(DATABASE_URL) with Session(engine) as session: diff --git a/policyengine_us_data/db/validate_database.py 
b/policyengine_us_data/db/validate_database.py index 53ac09852..2fa819f29 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -9,7 +9,9 @@ import pandas as pd from policyengine_us.system import system -conn = sqlite3.connect("policyengine_us_data/storage/policy_data.db") +conn = sqlite3.connect( + "policyengine_us_data/storage/calibration/policy_data.db" +) stratum_constraints_df = pd.read_sql("SELECT * FROM stratum_constraints", conn) targets_df = pd.read_sql("SELECT * FROM targets", conn) diff --git a/policyengine_us_data/utils/db_metadata.py b/policyengine_us_data/utils/db_metadata.py new file mode 100644 index 000000000..b3e63ebed --- /dev/null +++ b/policyengine_us_data/utils/db_metadata.py @@ -0,0 +1,147 @@ +""" +Utility functions for managing database metadata +(sources, variable groups, etc.) +""" + +from typing import Optional +from sqlmodel import Session, select +from policyengine_us_data.db.create_database_tables import ( + Source, + SourceType, + VariableGroup, + VariableMetadata, +) + + +def get_or_create_source( + session: Session, + name: str, + source_type: SourceType, + vintage: Optional[str] = None, + description: Optional[str] = None, + url: Optional[str] = None, + notes: Optional[str] = None, +) -> Source: + """Get an existing source or create a new one. + + Args: + session: Database session + name: Name of the data source + source_type: Type of source (administrative, survey, etc.) + vintage: Version or year of the data + description: Detailed description + url: Reference URL + notes: Additional notes + + Returns: + Source object with source_id populated + """ + query = select(Source).where(Source.name == name) + if vintage: + query = query.where(Source.vintage == vintage) + + source = session.exec(query).first() + + if not source: + source = Source( + name=name, + type=source_type, + vintage=vintage, + description=description, + url=url, + notes=notes, + ) + session.add(source) + session.flush() + + return source + + +def get_or_create_variable_group( + session: Session, + name: str, + category: str, + is_histogram: bool = False, + is_exclusive: bool = False, + aggregation_method: Optional[str] = None, + display_order: Optional[int] = None, + description: Optional[str] = None, +) -> VariableGroup: + """Get an existing variable group or create a new one. + + Args: + session: Database session + name: Unique name of the variable group + category: High-level category + is_histogram: Whether this represents a distribution + is_exclusive: Whether variables are mutually exclusive + aggregation_method: How to aggregate + display_order: Order for display + description: Description of the group + + Returns: + VariableGroup object with group_id populated + """ + group = session.exec( + select(VariableGroup).where(VariableGroup.name == name) + ).first() + + if not group: + group = VariableGroup( + name=name, + category=category, + is_histogram=is_histogram, + is_exclusive=is_exclusive, + aggregation_method=aggregation_method, + display_order=display_order, + description=description, + ) + session.add(group) + session.flush() + + return group + + +def get_or_create_variable_metadata( + session: Session, + variable: str, + group: Optional[VariableGroup] = None, + display_name: Optional[str] = None, + display_order: Optional[int] = None, + units: Optional[str] = None, + is_primary: bool = True, + notes: Optional[str] = None, +) -> VariableMetadata: + """Get existing variable metadata or create new. 
+ + Args: + session: Database session + variable: PolicyEngine variable name + group: Variable group this belongs to + display_name: Human-readable name + display_order: Order within group + units: Units of measurement + is_primary: Whether this is a primary variable + notes: Additional notes + + Returns: + VariableMetadata object + """ + metadata = session.exec( + select(VariableMetadata).where(VariableMetadata.variable == variable) + ).first() + + if not metadata: + metadata = VariableMetadata( + variable=variable, + group_id=group.group_id if group else None, + display_name=display_name or variable, + display_order=display_order, + units=units, + is_primary=is_primary, + notes=notes, + ) + session.add(metadata) + session.flush() + + return metadata From ecb2f4edffeaccf90a29d5dbac59e16859ff5796 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Tue, 27 Jan 2026 14:03:37 -0500 Subject: [PATCH 2/8] Add parse_ucgid and get_geographic_strata to utils/db.py These functions were present in the junkyard repo but missing from the SEP version. Required by ETL scripts like etl_medicaid.py. Co-Authored-By: Claude Haiku 4.5 --- policyengine_us_data/utils/db.py | 77 +++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/policyengine_us_data/utils/db.py b/policyengine_us_data/utils/db.py index a8081db4e..6c7b1a4ed 100644 --- a/policyengine_us_data/utils/db.py +++ b/policyengine_us_data/utils/db.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from sqlmodel import Session, select import sqlalchemy as sa @@ -66,3 +66,78 @@ def get_stratum_parent(session: Session, stratum_id: int) -> Optional[Stratum]: if child_stratum: return child_stratum.parent_rel return None + + +def parse_ucgid(ucgid_str: str) -> Dict: + """Parse UCGID string to extract geographic information. + + Returns: + dict with keys: 'type' ('national', 'state', 'district'), + 'state_fips' (if applicable), + 'district_number' (if applicable), + 'congressional_district_geoid' (if applicable) + """ + if ucgid_str == "0100000US": + return {"type": "national"} + elif ucgid_str.startswith("0400000US"): + state_fips = int(ucgid_str[9:]) + return {"type": "state", "state_fips": state_fips} + elif ucgid_str.startswith("5001800US"): + state_and_district = ucgid_str[9:] + state_fips = int(state_and_district[:2]) + district_number = int(state_and_district[2:]) + if district_number == 0 or ( + state_fips == 11 and district_number == 98 + ): + district_number = 1 + cd_geoid = state_fips * 100 + district_number + return { + "type": "district", + "state_fips": state_fips, + "district_number": district_number, + "congressional_district_geoid": cd_geoid, + } + else: + raise ValueError(f"Unknown UCGID format: {ucgid_str}") + + +def get_geographic_strata(session: Session) -> Dict: + """Fetch existing geographic strata from database. 
+ + Returns: + dict mapping: + - 'national': stratum_id for US + - 'state': {state_fips: stratum_id} + - 'district': {congressional_district_geoid: stratum_id} + """ + strata_map = { + "national": None, + "state": {}, + "district": {}, + } + + stmt = select(Stratum).where(Stratum.stratum_group_id == 1) + geographic_strata = session.exec(stmt).unique().all() + + for stratum in geographic_strata: + constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == stratum.stratum_id + ) + ).all() + + if not constraints: + strata_map["national"] = stratum.stratum_id + else: + constraint_vars = { + c.constraint_variable: c.value for c in constraints + } + + if "congressional_district_geoid" in constraint_vars: + cd_geoid = int(constraint_vars["congressional_district_geoid"]) + strata_map["district"][cd_geoid] = stratum.stratum_id + elif "state_fips" in constraint_vars: + state_fips = int(constraint_vars["state_fips"]) + strata_map["state"][state_fips] = stratum.stratum_id + + return strata_map From a01e5bb03fcff3ac84656ca2c8cb59945b5eb67d Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Tue, 27 Jan 2026 16:53:06 -0500 Subject: [PATCH 3/8] Migrate data pipeline from CPS 2023 to 2024 and remove unused datasets Switch the data target to use 2024 CPS data (March 2025 ASEC) instead of 2023. Add CPS_2024_Full for full-sample generation, update ExtendedCPS_2024 and local area calibration to use it. Remove CPS_2021/2022/2023_Full, PooledCPS, Pooled_3_Year_CPS_2023, ExtendedCPS_2023, dead code, and unused exports. Update database ETL scripts for strata, IRS SOI, Medicaid, and SNAP. Trim cps.py __main__ to generate only CPS_2024_Full. Co-Authored-By: Claude Opus 4.5 --- .gitignore | 3 + Makefile | 14 +- changelog_entry.yaml | 16 + docs/local_area_calibration_setup.ipynb | 145 +----- policyengine_us_data/datasets/__init__.py | 25 - policyengine_us_data/datasets/cps/cps.py | 104 +--- .../datasets/cps/enhanced_cps.py | 45 -- .../datasets/cps/extended_cps.py | 12 +- .../create_stratified_cps.py | 8 +- .../fit_calibration_weights.py | 4 +- policyengine_us_data/db/DATABASE_GUIDE.md | 470 +++++++----------- .../db/create_initial_strata.py | 196 ++++++-- policyengine_us_data/db/etl_age.py | 15 +- policyengine_us_data/db/etl_irs_soi.py | 64 ++- policyengine_us_data/db/etl_medicaid.py | 79 ++- policyengine_us_data/db/etl_snap.py | 49 +- policyengine_us_data/db/validate_database.py | 7 +- .../storage/upload_completed_datasets.py | 6 +- .../test_local_area_calibration/conftest.py | 2 +- policyengine_us_data/utils/census.py | 30 +- policyengine_us_data/utils/raw_cache.py | 37 ++ 21 files changed, 631 insertions(+), 700 deletions(-) create mode 100644 policyengine_us_data/utils/raw_cache.py diff --git a/.gitignore b/.gitignore index 5d183bf63..3f32cd0ac 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,9 @@ node_modules !policyengine_us_data/storage/national_and_district_rents_2023.csv docs/.ipynb_checkpoints/ +## Raw input cache for database pipeline +policyengine_us_data/storage/calibration/raw_inputs/ + ## Batch processing checkpoints completed_*.txt diff --git a/Makefile b/Makefile index 270717c39..a2297de5b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all format test install download upload docker documentation data publish-local-area clean build paper clean-paper presentations +.PHONY: all format test install download upload docker documentation data publish-local-area clean build paper clean-paper presentations database database-refresh 
promote-database all: data test @@ -63,6 +63,18 @@ database: python policyengine_us_data/db/etl_irs_soi.py python policyengine_us_data/db/validate_database.py +database-refresh: + rm -rf policyengine_us_data/storage/calibration/raw_inputs/ + $(MAKE) database + +promote-database: + cp policyengine_us_data/storage/calibration/policy_data.db \ + $(HOME)/devl/huggingface/policyengine-us-data/calibration/policy_data.db + rm -rf $(HOME)/devl/huggingface/policyengine-us-data/calibration/raw_inputs + cp -r policyengine_us_data/storage/calibration/raw_inputs \ + $(HOME)/devl/huggingface/policyengine-us-data/calibration/raw_inputs + @echo "Copied DB and raw_inputs to HF clone. Now cd to HF repo, commit, and push." + data: download python policyengine_us_data/utils/uprating.py python policyengine_us_data/datasets/acs/acs.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..bfc2edd01 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,16 @@ +- bump: minor + changes: + changed: + - Migrated data pipeline from CPS 2023 to CPS 2024 (March 2025 ASEC) + - Updated ExtendedCPS_2024 to use new CPS_2024_Full (full sample) + - Updated local area calibration to use 2024 extended CPS data + - Updated database ETL strata, IRS SOI, Medicaid, and SNAP scripts + removed: + - Removed CPS_2021_Full, CPS_2022_Full, CPS_2023_Full classes + - Removed PooledCPS and Pooled_3_Year_CPS_2023 + - Removed ExtendedCPS_2023 + - Removed dead train_previous_year_income_model function + - Removed unused dataset exports from __init__.py + added: + - Added CPS_2024_Full class for full-sample 2024 CPS generation + - Added raw_cache utility for Census data caching diff --git a/docs/local_area_calibration_setup.ipynb b/docs/local_area_calibration_setup.ipynb index 9060a3df2..c44c62f0f 100644 --- a/docs/local_area_calibration_setup.ipynb +++ b/docs/local_area_calibration_setup.ipynb @@ -61,17 +61,11 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "cell-3", "metadata": {}, "outputs": [], - "source": [ - "db_path = STORAGE_FOLDER / \"calibration\" / \"policy_data.db\"\n", - "db_uri = f\"sqlite:///{db_path}\"\n", - "dataset_path = str(STORAGE_FOLDER / \"stratified_extended_cps_2023.h5\")\n", - "\n", - "engine = create_engine(db_uri)" - ] + "source": "db_path = STORAGE_FOLDER / \"calibration\" / \"policy_data.db\"\ndb_uri = f\"sqlite:///{db_path}\"\ndataset_path = str(STORAGE_FOLDER / \"stratified_extended_cps_2024.h5\")\n\nengine = create_engine(db_uri)" }, { "cell_type": "markdown", @@ -148,42 +142,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "cell-7", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "X_sparse shape: (539, 256633)\n", - " Rows (targets): 539\n", - " Columns (household × CD pairs): 256633\n", - " Non-zero entries: 67,756\n", - " Sparsity: 99.95%\n" - ] - } - ], - "source": [ - "sim = Microsimulation(dataset=dataset_path)\n", - "\n", - "builder = SparseMatrixBuilder(\n", - " db_uri,\n", - " time_period=2023,\n", - " cds_to_calibrate=test_cds,\n", - " dataset_path=dataset_path,\n", - ")\n", - "\n", - "targets_df, X_sparse, household_id_mapping = builder.build_matrix(\n", - " sim, target_filter={\"stratum_group_ids\": [4], \"variables\": [\"snap\"]}\n", - ")\n", - "\n", - "print(f\"X_sparse shape: {X_sparse.shape}\")\n", - "print(f\" Rows (targets): {X_sparse.shape[0]}\")\n", - "print(f\" Columns (household × CD pairs): {X_sparse.shape[1]}\")\n", - "print(f\" 
Non-zero entries: {X_sparse.nnz:,}\")\n", - "print(f\" Sparsity: {1 - X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.2%}\")" - ] + "outputs": [], + "source": "sim = Microsimulation(dataset=dataset_path)\n\nbuilder = SparseMatrixBuilder(\n db_uri,\n time_period=2024,\n cds_to_calibrate=test_cds,\n dataset_path=dataset_path,\n)\n\ntargets_df, X_sparse, household_id_mapping = builder.build_matrix(\n sim, target_filter={\"stratum_group_ids\": [4], \"variables\": [\"snap\"]}\n)\n\nprint(f\"X_sparse shape: {X_sparse.shape}\")\nprint(f\" Rows (targets): {X_sparse.shape[0]}\")\nprint(f\" Columns (household × CD pairs): {X_sparse.shape[1]}\")\nprint(f\" Non-zero entries: {X_sparse.nnz:,}\")\nprint(f\" Sparsity: {1 - X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.2%}\")" }, { "cell_type": "markdown", @@ -428,43 +391,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "e05aaeab-3786-4ff0-a50b-34577065d2e0", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Remember, this is a North Carolina target:\n", - "\n", - "target_id 9372\n", - "stratum_id 9799\n", - "variable snap\n", - "value 4041086120.0\n", - "period 2023\n", - "stratum_group_id 4\n", - "geographic_id 37\n", - "Name: 80, dtype: object\n", - "\n", - "Household donated to NC's 2nd district, 2023 SNAP dollars:\n", - "789.19995\n", - "\n", - "Household donated to NC's 2nd district, 2023 SNAP dollars:\n", - "0.0\n" - ] - } - ], - "source": [ - "print(\"Remember, this is a North Carolina target:\\n\")\n", - "print(targets_df.iloc[row_loc])\n", - "\n", - "print(\"\\nNC State target. Household donated to NC's 2nd district, 2023 SNAP dollars:\")\n", - "print(X_sparse[row_loc, positions['3702']]) # Household donated to NC's 2nd district\n", - "\n", - "print(\"\\nSame target, same household, donated to AK's at Large district, 2023 SNAP dollars:\")\n", - "print(X_sparse[row_loc, positions['201']]) # Household donated to AK's at Large District" - ] + "outputs": [], + "source": "print(\"Remember, this is a North Carolina target:\\n\")\nprint(targets_df.iloc[row_loc])\n\nprint(\"\\nNC State target. 
Household donated to NC's 2nd district, 2024 SNAP dollars:\")\nprint(X_sparse[row_loc, positions['3702']]) # Household donated to NC's 2nd district\n\nprint(\"\\nSame target, same household, donated to AK's at Large district, 2024 SNAP dollars:\")\nprint(X_sparse[row_loc, positions['201']]) # Household donated to AK's at Large District" }, { "cell_type": "markdown", @@ -507,24 +438,11 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "ac59b6f1-859f-4246-8a05-8cb26384c882", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Household donated to AK's 1st district, 2023 SNAP dollars:\n", - "342.48004\n" - ] - } - ], - "source": [ - "print(\"\\nHousehold donated to AK's 1st district, 2023 SNAP dollars:\")\n", - "print(X_sparse[new_row_loc, positions['201']]) # Household donated to AK's at Large District" - ] + "outputs": [], + "source": "print(\"\\nHousehold donated to AK's 1st district, 2024 SNAP dollars:\")\nprint(X_sparse[new_row_loc, positions['201']]) # Household donated to AK's at Large District" }, { "cell_type": "markdown", @@ -538,44 +456,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "cell-19", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SNAP values for first 5 households under different state rules:\n", - " NC rules: [789.19995117 0. 0. 0. 0. ]\n", - " AK rules: [342.4800415 0. 0. 0. 0. ]\n", - " Difference: [-446.71990967 0. 0. 0. 0. ]\n" - ] - } - ], - "source": [ - "def create_state_simulation(state_fips):\n", - " \"\"\"Create a simulation with all households assigned to a specific state.\"\"\"\n", - " s = Microsimulation(dataset=dataset_path)\n", - " s.set_input(\n", - " \"state_fips\", 2023, np.full(hh_snap_df.shape[0], state_fips, dtype=np.int32)\n", - " )\n", - " for var in get_calculated_variables(s):\n", - " s.delete_arrays(var)\n", - " return s\n", - "\n", - "# Compare SNAP for first 5 households under NC vs AK rules\n", - "nc_sim = create_state_simulation(37) # NC\n", - "ak_sim = create_state_simulation(2) # AK\n", - "\n", - "nc_snap = nc_sim.calculate(\"snap\", map_to=\"household\").values[:5]\n", - "ak_snap = ak_sim.calculate(\"snap\", map_to=\"household\").values[:5]\n", - "\n", - "print(\"SNAP values for first 5 households under different state rules:\")\n", - "print(f\" NC rules: {nc_snap}\")\n", - "print(f\" AK rules: {ak_snap}\")\n", - "print(f\" Difference: {ak_snap - nc_snap}\")" - ] + "outputs": [], + "source": "def create_state_simulation(state_fips):\n \"\"\"Create a simulation with all households assigned to a specific state.\"\"\"\n s = Microsimulation(dataset=dataset_path)\n s.set_input(\n \"state_fips\", 2024, np.full(hh_snap_df.shape[0], state_fips, dtype=np.int32)\n )\n for var in get_calculated_variables(s):\n s.delete_arrays(var)\n return s\n\n# Compare SNAP for first 5 households under NC vs AK rules\nnc_sim = create_state_simulation(37) # NC\nak_sim = create_state_simulation(2) # AK\n\nnc_snap = nc_sim.calculate(\"snap\", map_to=\"household\").values[:5]\nak_snap = ak_sim.calculate(\"snap\", map_to=\"household\").values[:5]\n\nprint(\"SNAP values for first 5 households under different state rules:\")\nprint(f\" NC rules: {nc_snap}\")\nprint(f\" AK rules: {ak_snap}\")\nprint(f\" Difference: {ak_snap - nc_snap}\")" }, { "cell_type": "markdown", @@ -1015,4 +900,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git 
a/policyengine_us_data/datasets/__init__.py b/policyengine_us_data/datasets/__init__.py index 87461837e..8fb268cda 100644 --- a/policyengine_us_data/datasets/__init__.py +++ b/policyengine_us_data/datasets/__init__.py @@ -1,28 +1,3 @@ from .cps import ( - CPS_2019, - CPS_2020, - CPS_2021, - CPS_2022, - CPS_2023, - CPS_2024, - Pooled_3_Year_CPS_2023, - CensusCPS_2018, - CensusCPS_2019, - CensusCPS_2020, - CensusCPS_2021, - CensusCPS_2022, - CensusCPS_2023, EnhancedCPS_2024, - ReweightedCPS_2024, ) -from .puf import PUF_2015, PUF_2021, PUF_2024, IRS_PUF_2015 -from .acs import ACS_2022 - -DATASETS = [ - CPS_2022, - PUF_2021, - CPS_2024, - EnhancedCPS_2024, - ACS_2022, - Pooled_3_Year_CPS_2023, -] diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index 249e40e5d..4af95f9f1 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -2049,102 +2049,14 @@ class CPS_2025(CPS): frac = 1 -# The below datasets are a very naïve way of preventing downsampling in the -# Pooled 3-Year CPS. They should be replaced by a more sustainable approach. -# If these are still here on July 1, 2025, please open an issue and raise at standup. -class CPS_2021_Full(CPS): - name = "cps_2021_full" - label = "CPS 2021 (full)" - raw_cps = CensusCPS_2021 - previous_year_raw_cps = CensusCPS_2020 - file_path = STORAGE_FOLDER / "cps_2021_full.h5" - time_period = 2021 - - -class CPS_2022_Full(CPS): - name = "cps_2022_full" - label = "CPS 2022 (full)" - raw_cps = CensusCPS_2022 - previous_year_raw_cps = CensusCPS_2021 - file_path = STORAGE_FOLDER / "cps_2022_full.h5" - time_period = 2022 - - -class CPS_2023_Full(CPS): - name = "cps_2023_full" - label = "CPS 2023 (full)" - raw_cps = CensusCPS_2023 - previous_year_raw_cps = CensusCPS_2022 - file_path = STORAGE_FOLDER / "cps_2023_full.h5" - time_period = 2023 - - -class PooledCPS(Dataset): - data_format = Dataset.ARRAYS - input_datasets: list - time_period: int - - def generate(self): - data = [ - input_dataset(require=True).load_dataset() - for input_dataset in self.input_datasets - ] - time_periods = [dataset.time_period for dataset in self.input_datasets] - data = [ - uprate_cps_data(data, time_period, self.time_period) - for data, time_period in zip(data, time_periods) - ] - - new_data = {} - - for i in range(len(data)): - for variable in data[i]: - data_values = data[i][variable] - if variable not in new_data: - new_data[variable] = data_values - elif "_id" in variable: - previous_max = new_data[variable].max() - new_data[variable] = np.concatenate( - [ - new_data[variable], - data_values + previous_max, - ] - ) - else: - new_data[variable] = np.concatenate( - [ - new_data[variable], - data_values, - ] - ) - - new_data["household_weight"] = new_data["household_weight"] / len( - self.input_datasets - ) - - self.save_dataset(new_data) - - -class Pooled_3_Year_CPS_2023(PooledCPS): - label = "CPS 2023 (3-year pooled)" - name = "pooled_3_year_cps_2023" - file_path = STORAGE_FOLDER / "pooled_3_year_cps_2023.h5" - input_datasets = [ - CPS_2021_Full, - CPS_2022_Full, - CPS_2023_Full, - ] - time_period = 2023 - url = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" +class CPS_2024_Full(CPS): + name = "cps_2024_full" + label = "CPS 2024 (full)" + raw_cps = CensusCPS_2024 + previous_year_raw_cps = CensusCPS_2023 + file_path = STORAGE_FOLDER / "cps_2024_full.h5" + time_period = 2024 if __name__ == "__main__": - CPS_2021().generate() - CPS_2022().generate() - CPS_2023().generate() 
- CPS_2024().generate() - CPS_2025().generate() - CPS_2021_Full().generate() - CPS_2022_Full().generate() - CPS_2023_Full().generate() - Pooled_3_Year_CPS_2023().generate() + CPS_2024_Full().generate() diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index dc8f50402..9799e99ac 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -15,7 +15,6 @@ from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.datasets.cps.extended_cps import ( ExtendedCPS_2024, - CPS_2019, CPS_2024, ) import logging @@ -140,50 +139,6 @@ def loss(weights): return final_weights_sparse -def train_previous_year_income_model(): - from policyengine_us import Microsimulation - - sim = Microsimulation(dataset=CPS_2019) - - sim.subsample(10_000) - - VARIABLES = [ - "previous_year_income_available", - "employment_income", - "self_employment_income", - "age", - "is_male", - "spm_unit_state_fips", - "dividend_income", - "interest_income", - "social_security", - "capital_gains", - "is_disabled", - "is_blind", - "is_married", - "tax_unit_children", - "pension_income", - ] - - OUTPUTS = [ - "employment_income_last_year", - "self_employment_income_last_year", - ] - - df = sim.calculate_dataframe(VARIABLES + OUTPUTS, 2019, map_to="person") - df_train = df[df.previous_year_income_available] - - from policyengine_us_data.utils import QRF - - income_last_year = QRF() - X = df_train[VARIABLES[1:]] - y = df_train[OUTPUTS] - - income_last_year.fit(X, y) - - return income_last_year - - class EnhancedCPS(Dataset): data_format = Dataset.TIME_PERIOD_ARRAYS input_dataset: Type[Dataset] diff --git a/policyengine_us_data/datasets/cps/extended_cps.py b/policyengine_us_data/datasets/cps/extended_cps.py index b5b4fa242..b9f2c81aa 100644 --- a/policyengine_us_data/datasets/cps/extended_cps.py +++ b/policyengine_us_data/datasets/cps/extended_cps.py @@ -320,17 +320,8 @@ def impute_income_variables( return result -class ExtendedCPS_2023(ExtendedCPS): - cps = CPS_2023_Full - puf = PUF_2023 - name = "extended_cps_2023" - label = "Extended CPS (2023)" - file_path = STORAGE_FOLDER / "extended_cps_2023.h5" - time_period = 2023 - - class ExtendedCPS_2024(ExtendedCPS): - cps = CPS_2024 + cps = CPS_2024_Full puf = PUF_2024 name = "extended_cps_2024" label = "Extended CPS (2024)" @@ -339,5 +330,4 @@ class ExtendedCPS_2024(ExtendedCPS): if __name__ == "__main__": - ExtendedCPS_2023().generate() ExtendedCPS_2024().generate() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py b/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py index da3dffc05..ba1011016 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py @@ -1,5 +1,5 @@ """ -Create a stratified sample of extended_cps_2023.h5 that preserves high-income households +Create a stratified sample of extended_cps_2024.h5 that preserves high-income households while maintaining diversity in lower income strata for poverty analysis. 
Strategy: @@ -35,7 +35,7 @@ def create_stratified_cps_dataset( high_income_percentile: Keep ALL households above this AGI percentile (e.g., 99 or 99.5) oversample_poor: If True, boost sampling rate for bottom 25% by 1.5x seed: Random seed for reproducibility (default: None for random) - base_dataset: Path to source h5 file (default: extended_cps_2023.h5) + base_dataset: Path to source h5 file (default: extended_cps_2024.h5) output_path: Where to save the stratified h5 file """ print("\n" + "=" * 70) @@ -46,7 +46,7 @@ def create_stratified_cps_dataset( if base_dataset is None: from policyengine_us_data.storage import STORAGE_FOLDER - base_dataset = str(STORAGE_FOLDER / "extended_cps_2023.h5") + base_dataset = str(STORAGE_FOLDER / "extended_cps_2024.h5") # Load the original simulation print("Loading original dataset...") @@ -217,7 +217,7 @@ def create_stratified_cps_dataset( if output_path is None: from policyengine_us_data.storage import STORAGE_FOLDER - output_path = str(STORAGE_FOLDER / "stratified_extended_cps_2023.h5") + output_path = str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") # Save to h5 file print(f"\nSaving to {output_path}...") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py index ee3d38475..2cb153c2c 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py @@ -76,11 +76,11 @@ if args.dataset_path: dataset_path = Path(args.dataset_path) else: - dataset_path = STORAGE_FOLDER / "stratified_extended_cps_2023.h5" + dataset_path = STORAGE_FOLDER / "stratified_extended_cps_2024.h5" output_dir = STORAGE_FOLDER / "calibration" output_dir.mkdir(parents=True, exist_ok=True) -time_period = 2023 +time_period = 2024 # Get all CDs from database cds_to_calibrate = get_all_cds_from_database(db_uri) diff --git a/policyengine_us_data/db/DATABASE_GUIDE.md b/policyengine_us_data/db/DATABASE_GUIDE.md index 93657ef5f..f751e1d7d 100644 --- a/policyengine_us_data/db/DATABASE_GUIDE.md +++ b/policyengine_us_data/db/DATABASE_GUIDE.md @@ -1,310 +1,197 @@ -# PolicyEngine US Data - Database Getting Started Guide +# PolicyEngine US Data - Database Guide -## Current Task: Matrix Generation for Calibration Targets +## Overview -### Objective -Create a comprehensive matrix of calibration targets with the following requirements: -1. **Rows grouped by target type** - All age targets together, all income targets together, etc. -2. **Known counts per group** - Each group has a predictable number of entries (e.g., 18 age groups, 9 income brackets) -3. **Source selection** - Ability to specify which data source to use when multiple exist -4. **Geographic filtering** - Ability to select specific geographic levels (national, state, or congressional district) +This database uses a hierarchical stratum-based model to organize US demographic and economic data for PolicyEngine calibration. Data is organized into "strata" - population subgroups defined by constraints - with calibration targets attached to each stratum. 
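To make the data model concrete, here is a schematic of a single stratum (field names follow the schema described below; the age bracket and year are illustrative):

```
Stratum        stratum_group_id = 2 (Age), parent = a geographic stratum
  constraints  age >= 5, age < 10          # one of the 18 ACS age brackets
  target       person_count, period 2023, source = Census ACS Table S0101
```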
-### Implementation Strategy -The `stratum_group_id` field now categorizes strata by conceptual type, making matrix generation straightforward: -- Query by `stratum_group_id` to get all related targets together -- Each demographic group appears consistently across all 488 geographic areas -- Join with `sources` table to filter/identify data provenance -- Use parent-child relationships to navigate geographic hierarchy +The database is a **compiled artifact**: built locally from government data sources, validated, and promoted to HuggingFace for consumption by downstream pipelines. -### Example Matrix Query -```sql --- Generate matrix for a specific geography (e.g., national level) -SELECT - CASE s.stratum_group_id - WHEN 2 THEN 'Age' - WHEN 3 THEN 'Income' - WHEN 4 THEN 'SNAP' - WHEN 5 THEN 'Medicaid' - WHEN 6 THEN 'EITC' - END AS group_name, - s.notes AS stratum_description, - t.variable, - t.value, - src.name AS source -FROM strata s -JOIN targets t ON s.stratum_id = t.stratum_id -JOIN sources src ON t.source_id = src.source_id -WHERE s.parent_stratum_id = 1 -- National level (or any specific geography) - AND s.stratum_group_id > 1 -- Exclude geographic strata -ORDER BY s.stratum_group_id, s.stratum_id; +## Building the Database + +### Quick Start +```bash +source ~/envs/sep/bin/activate +cd ~/devl/sep/policyengine-us-data + +make database # Build (uses cached downloads if available) +make database-refresh # Force re-download all sources and rebuild +make promote-database # Copy DB + raw inputs to HuggingFace clone ``` -## Overview -This database uses a hierarchical stratum-based model to organize US demographic and economic data for PolicyEngine calibration. The core concept is that data is organized into "strata" - population subgroups defined by constraints. +### Pipeline Stages -## Key Concepts +`make database` runs these scripts sequentially: -### Strata Hierarchy -The database uses a parent-child hierarchy: -``` -United States (national) -├── States (51 including DC) -│ ├── Congressional Districts (436 total) -│ │ ├── Age groups (18 brackets per geographic area) -│ │ ├── Income groups (AGI stubs) -│ │ └── Other demographic strata (EITC recipients, SNAP, Medicaid, etc.) +| # | Script | Network? | What it does | +|---|--------|----------|--------------| +| 1 | `create_database_tables.py` | No | Creates empty SQLite schema (7 tables) | +| 2 | `create_initial_strata.py` | Census ACS 5-year | Builds geographic hierarchy: US > 51 states > 436 CDs | +| 3 | `etl_national_targets.py` | No | Loads ~40 hardcoded national targets (CBO, Treasury, CMS) | +| 4 | `etl_age.py` | Census ACS 1-year | Age distribution: 18 bins x 488 geographies | +| 5 | `etl_medicaid.py` | Census ACS + CMS | Medicaid enrollment (admin state-level, survey district-level) | +| 6 | `etl_snap.py` | USDA FNS + Census ACS | SNAP participation (admin state-level, survey district-level) | +| 7 | `etl_irs_soi.py` | IRS | Tax variables, EITC by child count, AGI brackets | +| 8 | `validate_database.py` | No | Checks all target variables exist in policyengine-us | + +### Raw Input Caching + +All network downloads are cached in `storage/calibration/raw_inputs/`. On subsequent runs, cached files are used instead of hitting external APIs. This decouples extraction from transformation so you can iterate on ETL logic without network access. 
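Each extract function follows the same check-the-cache-then-download pattern. A minimal sketch using the `raw_cache` helpers (the function name and its arguments are illustrative, not part of the codebase):

```python
import requests

from policyengine_us_data.utils.raw_cache import is_cached, load_json, save_json


def extract_table(url: str, cache_file: str):
    """Return the cached JSON payload if present; otherwise download and cache it."""
    if is_cached(cache_file):  # returns False when PE_REFRESH_RAW=1 is set
        return load_json(cache_file)
    response = requests.get(url)
    response.raise_for_status()
    data = response.json()
    save_json(cache_file, data)
    return data
```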
+ +Set `PE_REFRESH_RAW=1` to force re-download: +```bash +PE_REFRESH_RAW=1 make database ``` -### Stratum Groups -The `stratum_group_id` field categorizes strata by their conceptual type: -- `1`: Geographic boundaries (US, states, congressional districts) -- `2`: Age-based strata (18 age groups per geography) -- `3`: Income/AGI-based strata (9 income brackets per geography) -- `4`: SNAP recipient strata (1 per geography) -- `5`: Medicaid enrollment strata (1 per geography) -- `6`: EITC recipient strata (4 groups by qualifying children per geography) +### Promotion to HuggingFace -### UCGID Translation -The Census Bureau uses UCGIDs (Universal Census Geographic IDs) in their API responses: -- `0100000US`: National level -- `0400000USXX`: State (XX = state FIPS code) -- `5001800USXXDD`: Congressional district (XX = state FIPS, DD = district number) +After building and validating: +```bash +make promote-database +cd ~/devl/huggingface/policyengine-us-data +git add calibration/policy_data.db calibration/raw_inputs/ +git commit -m "Update policy_data.db - " +git push +``` -We parse these into our internal model using `state_fips` and `congressional_district_geoid`. +This copies both the database and the raw inputs that built it, preserving provenance in the HF repo's git history. -### Constraint Operations -All constraints use standardized operators: -- `==`: Equals -- `!=`: Not equals -- `>`: Greater than -- `>=`: Greater than or equal -- `<`: Less than -- `<=`: Less than or equal +### Recovery -## Database Structure +If a step fails mid-pipeline, delete the database and re-run. With cached downloads this takes ~10-15 minutes: +```bash +rm -f policyengine_us_data/storage/calibration/policy_data.db +make database +``` + +## Database Schema ### Core Tables -1. **strata**: Main table for population subgroups - - `stratum_id`: Primary key - - `parent_stratum_id`: Links to parent in hierarchy - - `stratum_group_id`: Conceptual category (1=Geographic, 2=Age, 3=Income, 4=SNAP, 5=Medicaid, 6=EITC) - - `definition_hash`: Unique hash of constraints for deduplication - -2. **stratum_constraints**: Defines rules for each stratum - - `constraint_variable`: Variable name (e.g., "age", "state_fips") - - `operation`: Comparison operator (==, >, <, etc.) - - `value`: Constraint value - -3. **targets**: Stores actual data values - - `variable`: PolicyEngine US variable name - - `period`: Year - - `value`: Numerical value - - `source_id`: Foreign key to sources table - - `active`: Boolean flag for active/inactive targets - - `tolerance`: Allowed relative error percentage + +**strata** - Population subgroups +- `stratum_id`: Auto-generated primary key +- `parent_stratum_id`: Links to parent in hierarchy +- `stratum_group_id`: Conceptual category (see below) +- `definition_hash`: SHA-256 of constraints for deduplication + +**stratum_constraints** - Rules defining each stratum +- `constraint_variable`: Variable name (e.g., `age`, `state_fips`) +- `operation`: Comparison operator (`==`, `!=`, `>`, `>=`, `<`, `<=`) +- `value`: String-encoded value + +**targets** - Calibration data values +- `variable`: PolicyEngine US variable name (e.g., `eitc`, `income_tax`) +- `period`: Year +- `value`: Numerical value +- `source_id`: Foreign key to sources table +- `active`: Boolean flag ### Metadata Tables -4. 
**sources**: Data source metadata - - `source_id`: Primary key (auto-generated) - - `name`: Source name (e.g., "IRS Statistics of Income") - - `type`: SourceType enum (administrative, survey, hardcoded) - - `vintage`: Year or version of data - - `description`: Detailed description - - `url`: Reference URL - - `notes`: Additional notes - -5. **variable_groups**: Logical groupings of related variables - - `group_id`: Primary key (auto-generated) - - `name`: Unique group name (e.g., "age_distribution", "snap_recipients") - - `category`: High-level category (demographic, benefit, tax, income, expense) - - `is_histogram`: Whether this represents a distribution - - `is_exclusive`: Whether variables are mutually exclusive - - `aggregation_method`: How to aggregate (sum, weighted_avg, etc.) - - `display_order`: Order for display in matrices/reports - - `description`: What this group represents - -6. **variable_metadata**: Display information for variables - - `metadata_id`: Primary key - - `variable`: PolicyEngine variable name - - `group_id`: Foreign key to variable_groups - - `display_name`: Human-readable name - - `display_order`: Order within group - - `units`: Units of measurement (dollars, count, percent) - - `is_primary`: Whether this is a primary vs derived variable - - `notes`: Additional notes -## Building the Database +**sources** - Data provenance (e.g., "Census ACS", "IRS SOI", "CBO") -### Step 1: Create Tables -```bash -source ~/envs/sep/bin/activate -cd policyengine_us_data/db -python create_database_tables.py -``` +**variable_groups** - Logical groupings (e.g., "age_distribution", "snap_recipients") -### Step 2: Create Geographic Hierarchy -```bash -python create_initial_strata.py -``` -Creates: 1 national + 51 state + 436 congressional district strata +**variable_metadata** - Display info for variables (display name, units, ordering) -### Step 3: Load Data (in order) -```bash -# National hardcoded targets -python etl_national_targets.py +## Key Concepts -# Age demographics (Census ACS) -python etl_age.py +### Stratum Groups -# Economic data (IRS SOI) -python etl_irs_soi.py +The `stratum_group_id` field categorizes strata: -# Benefits data -python etl_medicaid.py -python etl_snap.py -``` +| ID | Category | Description | +|----|----------|-------------| +| 0 | Uncategorized | Legacy strata not yet assigned a group | +| 1 | Geographic | US, states, congressional districts | +| 2 | Age | 18 age brackets per geography | +| 3 | Income/AGI | 9 income brackets per geography | +| 4 | SNAP | SNAP recipient strata | +| 5 | Medicaid | Medicaid enrollment strata | +| 6 | EITC | EITC recipients by qualifying children | -### Step 4: Validate -```bash -python validate_database.py -``` +### Geographic Hierarchy -Expected output: -- 488 geographic strata -- 8,784 age strata (18 age groups × 488 areas) -- All strata have unique definition hashes - -## Common Utility Functions - -Located in `policyengine_us_data/utils/db.py`: - -- `get_stratum_by_id(session, id)`: Retrieve stratum by ID -- `get_simple_stratum_by_ucgid(session, ucgid)`: Get stratum by UCGID -- `get_root_strata(session)`: Get root strata -- `get_stratum_children(session, id)`: Get child strata -- `get_stratum_parent(session, id)`: Get parent stratum - -Located in `policyengine_us_data/utils/db_metadata.py`: - -- `get_or_create_source(session, ...)`: Get or create a data source -- `get_or_create_variable_group(session, ...)`: Get or create a variable group -- `get_or_create_variable_metadata(session, ...)`: Get or create variable 
metadata - -## ETL Script Pattern - -Each ETL script follows this pattern: - -1. **Extract**: Pull data from source (Census API, IRS files, etc.) -2. **Transform**: - - Parse UCGIDs to get geographic info - - Map to existing geographic strata - - Create demographic strata as children -3. **Load**: - - Check for existing strata to avoid duplicates - - Add constraints and targets - - Commit to database - -## Important Notes - -### Avoiding Duplicates -Always check if a stratum exists before creating: -```python -existing_stratum = session.exec( - select(Stratum).where( - Stratum.parent_stratum_id == parent_id, - Stratum.stratum_group_id == group_id, - Stratum.notes == note - ) -).first() +``` +United States (no constraints) + ├── Alabama (state_fips == 1) + │ ├── AL-01 (congressional_district_geoid == 101) + │ ├── AL-02 (congressional_district_geoid == 102) + │ └── ... + ├── Alaska (state_fips == 2) + │ └── AK-01 (congressional_district_geoid == 201) + └── ... ``` -### Geographic Constraints -- National strata: No geographic constraints needed -- State strata: `state_fips` constraint -- District strata: `congressional_district_geoid` constraint +Geographic strata use `state_fips` and `congressional_district_geoid` constraints (not UCGIDs). The `parse_ucgid()` and `get_geographic_strata()` functions in `utils/db.py` bridge between Census UCGID strings and these internal identifiers. -### Congressional District Normalization -- District 00 → 01 (at-large districts) -- DC district 98 → 01 (delegate district) +### UCGID Translation -### IRS AGI Ranges -AGI stubs use >= for lower bound, < for upper bound: -- Stub 3: $10,000 <= AGI < $25,000 -- Stub 4: $25,000 <= AGI < $50,000 -- etc. +Census Bureau API responses use UCGIDs (Universal Census Geographic IDs): +- `0100000US` = National +- `0400000USXX` = State (XX = state FIPS) +- `5001800USXXDD` = Congressional district (XX = state FIPS, DD = district number) -## Troubleshooting +ETL scripts that pull Census data receive UCGIDs and create their own domain-specific strata with `ucgid_str` constraints. The geographic hierarchy strata (stratum_group_id=1) use `state_fips`/`congressional_district_geoid` instead. -### "WARNING: Expected 8784 age strata, found 16104" -**Status: RESOLVED** +### Constraint Operations -The validation script was incorrectly counting all demographic strata (stratum_group_id = 0) as age strata. After implementing the new stratum_group_id scheme (1=Geographic, 2=Age, 3=Income, etc.), the validation correctly identifies 8,784 age strata. +All constraints use standardized operators validated by the `ConstraintOperation` enum: +`==`, `!=`, `>`, `>=`, `<`, `<=` -### Fixed: Synthetic Variable Names -Previously, the IRS SOI ETL was creating invalid variable names like `eitc_tax_unit_count` that don't exist in PolicyEngine. Now correctly uses `tax_unit_count` with appropriate stratum constraints to indicate what's being counted. +Note: Some legacy ETL scripts use string operations like `"in"`, `"equals"`, `"greater_than"`. These coexist in the database but new code should use the standardized operators. -### UCGID strings in notes -Legacy UCGID references have been replaced with human-readable identifiers: -- "US" for national -- "State FIPS X" for states -- "CD XXXX" for congressional districts +### Constraint Value Types -### Mixed operation types -All operations now use standardized symbols (==, >, <, etc.) validated by ConstraintOperation enum. +The `value` column stores all values as strings. 
Downstream code deserializes: +- Numeric strings -> int/float (age, income) +- `"True"`/`"False"` -> booleans (medicaid_enrolled) +- Other strings stay as strings (state_fips with leading zeros) -## Database Location -`policyengine_us_data/storage/calibration/policy_data.db` +## Important Warnings -## Example SQLite Queries with Metadata Features +### stratum_id != FIPS Code -### Compare Administrative vs Survey Data for SNAP -```sql -SELECT - s.type AS source_type, - s.name AS source_name, - st.notes AS location, - t.value AS household_count -FROM targets t -JOIN sources s ON t.source_id = s.source_id -JOIN strata st ON t.stratum_id = st.stratum_id -WHERE t.variable = 'household_count' - AND st.notes LIKE '%SNAP%' -ORDER BY s.type, st.notes; -``` +The `stratum_id` is auto-generated and has **no relationship** to FIPS codes: +- California: stratum_id=6, state_fips="06" (coincidental!) +- North Carolina: stratum_id=35, state_fips="37" (no match) -### Get All Variables in a Group with Their Metadata +Always look up strata by constraint values: ```sql -SELECT - vm.display_name, - vm.variable, - vm.units, - vm.display_order, - vg.description AS group_description -FROM variable_metadata vm -JOIN variable_groups vg ON vm.group_id = vg.group_id -WHERE vg.name = 'eitc_recipients' -ORDER BY vm.display_order; +SELECT s.stratum_id, s.notes +FROM strata s +JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id +WHERE sc.constraint_variable = 'state_fips' + AND sc.value = '37'; ``` -### Query by Stratum Group -```sql --- Get all age-related strata and their targets -SELECT - s.stratum_id, - s.notes, - t.variable, - t.value, - src.name AS source -FROM strata s -JOIN targets t ON s.stratum_id = t.stratum_id -JOIN sources src ON t.source_id = src.source_id -WHERE s.stratum_group_id = 2 -- Age strata -LIMIT 20; +## Utility Functions --- Count strata by group +**`policyengine_us_data/utils/db.py`**: +- `get_stratum_by_id(session, id)` - Retrieve stratum by primary key +- `get_simple_stratum_by_ucgid(session, ucgid)` - Find stratum with single ucgid_str constraint +- `get_root_strata(session)` - Get strata with no parent +- `get_stratum_children(session, id)` / `get_stratum_parent(session, id)` - Navigate hierarchy +- `parse_ucgid(ucgid_str)` - Parse UCGID to type/state_fips/district info +- `get_geographic_strata(session)` - Map of all geographic strata by type + +**`policyengine_us_data/utils/db_metadata.py`**: +- `get_or_create_source(session, ...)` - Upsert data source metadata +- `get_or_create_variable_group(session, ...)` - Upsert variable group +- `get_or_create_variable_metadata(session, ...)` - Upsert variable display info + +**`policyengine_us_data/utils/raw_cache.py`**: +- `is_cached(filename)` - Check if a raw input is cached +- `save_json(filename, data)` / `load_json(filename)` - Cache JSON data +- `save_bytes(filename, data)` / `load_bytes(filename)` - Cache binary data + +## Example Queries + +### Count strata by group +```sql SELECT stratum_group_id, CASE stratum_group_id + WHEN 0 THEN 'Uncategorized' WHEN 1 THEN 'Geographic' WHEN 2 THEN 'Age' WHEN 3 THEN 'Income/AGI' @@ -318,47 +205,34 @@ GROUP BY stratum_group_id ORDER BY stratum_group_id; ``` -## Key Improvements -1. Removed UCGID as a constraint variable (legacy Census concept) -2. Standardized constraint operations with validation -3. Consolidated duplicate code (parse_ucgid, get_geographic_strata) -4. Fixed epsilon hack in IRS AGI ranges -5. Added proper duplicate checking in age ETL -6. 
Improved human-readable notes without UCGID strings -7. Added metadata tables for sources, variable groups, and variable metadata -8. Fixed synthetic variable name bug (e.g., eitc_tax_unit_count → tax_unit_count) -9. Auto-generated source IDs instead of hardcoding -10. Proper categorization of admin vs survey data for same concepts -11. Implemented conceptual stratum_group_id scheme for better organization and querying - -## Known Issues / TODOs - -### IMPORTANT: stratum_id vs state_fips Codes -**WARNING**: The `stratum_id` is an auto-generated sequential ID and has NO relationship to FIPS codes, despite some confusing coincidences: -- California: stratum_id = 6, state_fips = "06" (coincidental match!) -- North Carolina: stratum_id = 35, state_fips = "37" (no match) -- Ohio: stratum_id = 37, state_fips = "39" (no match) - -When querying for states, ALWAYS use the `state_fips` constraint value, never assume stratum_id matches FIPS. - -Example of correct lookup: +### Get targets for a specific state ```sql --- Find North Carolina's stratum_id by FIPS code -SELECT s.stratum_id, s.notes -FROM strata s +SELECT t.variable, t.value, t.period, s.notes +FROM targets t +JOIN strata s ON t.stratum_id = s.stratum_id JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id WHERE sc.constraint_variable = 'state_fips' - AND sc.value = '37'; -- Returns stratum_id = 35 + AND sc.value = '37' +ORDER BY t.variable; +``` + +### Compare admin vs survey data sources +```sql +SELECT + src.type AS source_type, + src.name AS source_name, + st.notes AS location, + t.value +FROM targets t +JOIN sources src ON t.source_id = src.source_id +JOIN strata st ON t.stratum_id = st.stratum_id +WHERE t.variable = 'household_count' + AND st.notes LIKE '%SNAP%' +ORDER BY src.type, st.notes; ``` -### Type Conversion for Constraint Values -**DESIGN DECISION**: The `value` column in `stratum_constraints` must store heterogeneous data types as strings. The calibration code deserializes these: -- Numeric strings → int/float (for age, income constraints) -- "True"/"False" → Python booleans (for medicaid_enrolled, snap_enrolled) -- Other strings remain strings (for state_fips, which may have leading zeros) - -### Medicaid Data Structure -- Medicaid uses `person_count` variable (not `medicaid`) because it's structured as a histogram with constraints -- State-level targets use administrative data (T-MSIS source) -- Congressional district level uses survey data (ACS source) -- No national Medicaid target exists (intentionally, to avoid double-counting when using state-level data) +## Database Location + +`policyengine_us_data/storage/calibration/policy_data.db` + +Downloaded from HuggingFace by `download_private_prerequisites.py` and `download_calibration_inputs()` in `utils/huggingface.py`. 
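For ad-hoc inspection, the file can be opened directly from Python. A minimal sketch mirroring the setup in the calibration notebook (the query is the strata-by-group count shown above):

```python
import pandas as pd
from sqlalchemy import create_engine

from policyengine_us_data.storage import STORAGE_FOLDER

db_path = STORAGE_FOLDER / "calibration" / "policy_data.db"
engine = create_engine(f"sqlite:///{db_path}")

counts = pd.read_sql(
    "SELECT stratum_group_id, COUNT(*) AS n_strata "
    "FROM strata GROUP BY stratum_group_id ORDER BY stratum_group_id",
    engine,
)
print(counts)
```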
diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 1d7d3b4b2..85144b22b 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -1,74 +1,188 @@ +import logging from typing import Dict +import requests import pandas as pd from sqlmodel import Session, create_engine from policyengine_us_data.storage import STORAGE_FOLDER - - -from policyengine_us.variables.household.demographic.geographic.ucgid.ucgid_enum import ( - UCGID, -) from policyengine_us_data.db.create_database_tables import ( Stratum, StratumConstraint, ) +from policyengine_us_data.utils.raw_cache import ( + is_cached, + save_json, + load_json, +) +logger = logging.getLogger(__name__) + +STATE_NAMES = { + 1: "Alabama (AL)", + 2: "Alaska (AK)", + 4: "Arizona (AZ)", + 5: "Arkansas (AR)", + 6: "California (CA)", + 8: "Colorado (CO)", + 9: "Connecticut (CT)", + 10: "Delaware (DE)", + 11: "District of Columbia (DC)", + 12: "Florida (FL)", + 13: "Georgia (GA)", + 15: "Hawaii (HI)", + 16: "Idaho (ID)", + 17: "Illinois (IL)", + 18: "Indiana (IN)", + 19: "Iowa (IA)", + 20: "Kansas (KS)", + 21: "Kentucky (KY)", + 22: "Louisiana (LA)", + 23: "Maine (ME)", + 24: "Maryland (MD)", + 25: "Massachusetts (MA)", + 26: "Michigan (MI)", + 27: "Minnesota (MN)", + 28: "Mississippi (MS)", + 29: "Missouri (MO)", + 30: "Montana (MT)", + 31: "Nebraska (NE)", + 32: "Nevada (NV)", + 33: "New Hampshire (NH)", + 34: "New Jersey (NJ)", + 35: "New Mexico (NM)", + 36: "New York (NY)", + 37: "North Carolina (NC)", + 38: "North Dakota (ND)", + 39: "Ohio (OH)", + 40: "Oklahoma (OK)", + 41: "Oregon (OR)", + 42: "Pennsylvania (PA)", + 44: "Rhode Island (RI)", + 45: "South Carolina (SC)", + 46: "South Dakota (SD)", + 47: "Tennessee (TN)", + 48: "Texas (TX)", + 49: "Utah (UT)", + 50: "Vermont (VT)", + 51: "Virginia (VA)", + 53: "Washington (WA)", + 54: "West Virginia (WV)", + 55: "Wisconsin (WI)", + 56: "Wyoming (WY)", +} + + +def fetch_congressional_districts(year): + cache_file = f"acs5_congressional_districts_{year}.json" + if is_cached(cache_file): + logger.info(f"Using cached {cache_file}") + data = load_json(cache_file) + else: + base_url = f"https://api.census.gov/data/{year}/acs/acs5" + params = { + "get": "NAME", + "for": "congressional district:*", + "in": "state:*", + } + logger.info(f"Downloading congressional districts for {year}") + response = requests.get(base_url, params=params) + response.raise_for_status() + data = response.json() + save_json(cache_file, data) + + df = pd.DataFrame(data[1:], columns=data[0]) + df["state_fips"] = df["state"].astype(int) + df = df[df["state_fips"] <= 56].copy() + df["district_number"] = df["congressional district"].apply( + lambda x: 0 if x in ["ZZ", "98"] else int(x) + ) -def main(): - # Get the implied hierarchy by the UCGID enum -------- - rows = [] - for node in UCGID: - codes = node.get_hierarchical_codes() - rows.append( - { - "name": node.name, - "code": codes[0], - "parent": codes[1] if len(codes) > 1 else None, - } - ) + df["n_districts"] = df.groupby("state_fips")["state_fips"].transform( + "count" + ) + df = df[(df["n_districts"] == 1) | (df["district_number"] > 0)].copy() + df = df.drop(columns=["n_districts"]) - hierarchy_df = ( - pd.DataFrame(rows) - .sort_values(["parent", "code"], na_position="first") - .reset_index(drop=True) + df.loc[df["district_number"] == 0, "district_number"] = 1 + df["congressional_district_geoid"] = ( + df["state_fips"] * 100 + df["district_number"] ) 
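+    # The GEOID encodes state and district: e.g. FIPS 37, district 2 -> 3702 (NC-02);
+    # at-large and delegate districts were normalized to district 1 above.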
+ df = df[ + [ + "state_fips", + "district_number", + "congressional_district_geoid", + "NAME", + ] + ] + df = df.sort_values("congressional_district_geoid") + + return df + + +def main(): + year = 2023 + cd_df = fetch_congressional_districts(year) + DATABASE_URL = ( f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" ) engine = create_engine(DATABASE_URL) - # map the ucgid_str 'code' to auto-generated 'stratum_id' - code_to_stratum_id: Dict[str, int] = {} - with Session(engine) as session: - for _, row in hierarchy_df.iterrows(): - parent_code = row["parent"] + us_stratum = Stratum( + parent_stratum_id=None, + notes="United States", + stratum_group_id=1, + ) + us_stratum.constraints_rel = [] + session.add(us_stratum) + session.flush() + us_stratum_id = us_stratum.stratum_id - parent_id = ( - code_to_stratum_id.get(parent_code) if parent_code else None - ) + state_stratum_ids = {} - new_stratum = Stratum( - parent_stratum_id=parent_id, - notes=f'{row["name"]} (ucgid {row["code"]})', + unique_states = cd_df["state_fips"].unique() + for state_fips in sorted(unique_states): + state_name = STATE_NAMES.get( + state_fips, f"State FIPS {state_fips}" + ) + state_stratum = Stratum( + parent_stratum_id=us_stratum_id, + notes=state_name, stratum_group_id=1, ) - - new_stratum.constraints_rel = [ + state_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["code"], + constraint_variable="state_fips", + operation="==", + value=str(state_fips), ) ] - - session.add(new_stratum) - + session.add(state_stratum) session.flush() + state_stratum_ids[state_fips] = state_stratum.stratum_id - code_to_stratum_id[row["code"]] = new_stratum.stratum_id + for _, row in cd_df.iterrows(): + state_fips = row["state_fips"] + cd_geoid = row["congressional_district_geoid"] + name = row["NAME"] + + cd_stratum = Stratum( + parent_stratum_id=state_stratum_ids[state_fips], + notes=f"{name} (CD GEOID {cd_geoid})", + stratum_group_id=1, + ) + cd_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), + ) + ] + session.add(cd_stratum) session.commit() diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index 9ce8f8a17..01cbbe308 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -5,11 +5,13 @@ from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( + SourceType, Stratum, StratumConstraint, Target, ) from policyengine_us_data.utils.census import get_census_docs, pull_acs_table +from policyengine_us_data.utils.db_metadata import get_or_create_source LABEL_TO_SHORT = { "Estimate!!Total!!Total population!!AGE!!Under 5 years": "0-4", @@ -81,7 +83,6 @@ def transform_age_data(age_data, docs): df_long["age_less_than"] = age_bounds[["lt"]] df_long["variable"] = "person_count" df_long["reform_id"] = 0 - df_long["source_id"] = 1 df_long["active"] = True return df_long @@ -117,6 +118,16 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): stratum_lookup[geo] = {} with Session(engine) as session: + census_source = get_or_create_source( + session, + name="Census ACS Table S0101", + source_type=SourceType.SURVEY, + vintage=f"{year} ACS 5-year estimates", + description="American Community Survey Age and Sex demographics", + url="https://data.census.gov/", + notes="Age distribution in 18 brackets across all geographic levels", + ) + for _, row in 
df_long.iterrows(): # Create the parent Stratum object. # We will attach children to it before adding it to the session. @@ -164,7 +175,7 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): variable=row["variable"], period=year, value=row["value"], - source_id=row["source_id"], + source_id=census_source.source_id, active=row["active"], ) ) diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index 879bd9a20..a280b006a 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -1,3 +1,4 @@ +import logging from typing import Optional import numpy as np @@ -6,18 +7,29 @@ from sqlmodel import Session, create_engine from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.utils.raw_cache import ( + is_cached, + cache_path, + save_bytes, +) + +logger = logging.getLogger(__name__) from policyengine_us_data.db.create_database_tables import ( + SourceType, Stratum, StratumConstraint, Target, ) +from policyengine_us_data.utils.db_metadata import get_or_create_source from policyengine_us_data.utils.db import ( get_stratum_by_id, get_simple_stratum_by_ucgid, get_root_strata, get_stratum_children, get_stratum_parent, + parse_ucgid, + get_geographic_strata, ) from policyengine_us_data.utils.census import TERRITORY_UCGIDS from policyengine_us_data.storage.calibration_targets.make_district_mapping import ( @@ -149,7 +161,19 @@ def extract_soi_data() -> pd.DataFrame: In the file below, "22" is 2022, "in" is individual returns, "cd" is congressional districts """ - return pd.read_csv("https://www.irs.gov/pub/irs-soi/22incd.csv") + import requests + + cache_file = "irs_soi_22incd.csv" + if is_cached(cache_file): + logger.info(f"Using cached {cache_file}") + return pd.read_csv(cache_path(cache_file)) + + url = "https://www.irs.gov/pub/irs-soi/22incd.csv" + logger.info(f"Downloading IRS SOI data from {url}") + response = requests.get(url) + response.raise_for_status() + save_bytes(cache_file, response.content) + return pd.read_csv(cache_path(cache_file)) def transform_soi_data(raw_df): @@ -282,6 +306,20 @@ def transform_soi_data(raw_df): return converted +def _lookup_geo_stratum(session, ucgid_str, geo_map): + """Look up a geographic stratum by ucgid string.""" + info = parse_ucgid(ucgid_str) + if info["type"] == "national": + sid = geo_map["national"] + elif info["type"] == "state": + sid = geo_map["state"].get(info["state_fips"]) + elif info["type"] == "district": + sid = geo_map["district"].get(info["congressional_district_geoid"]) + else: + return None + return get_stratum_by_id(session, sid) if sid else None + + def load_soi_data(long_dfs, year): """Load a list of databases into the db, critically dependent on order""" @@ -292,6 +330,18 @@ def load_soi_data(long_dfs, year): session = Session(engine) + irs_source = get_or_create_source( + session, + name="IRS Statistics of Income", + source_type=SourceType.ADMINISTRATIVE, + vintage=f"{year} Tax Year", + description="IRS Statistics of Income administrative tax data", + url="https://www.irs.gov/statistics", + notes="Tax return data by congressional district, state, and national levels", + ) + + geo_map = get_geographic_strata(session) + # Load EITC data -------------------------------------------------------- eitc_data = { "0": (long_dfs[0], long_dfs[1]), @@ -355,7 +405,7 @@ def load_soi_data(long_dfs, year): variable="eitc", period=year, value=eitc_amount_i.iloc[i][["target_value"]].values[0], - source_id=5, + source_id=irs_source.source_id, 
active=True, ) ] @@ -387,8 +437,8 @@ def load_soi_data(long_dfs, year): for i in range(count_j.shape[0]): ucgid_i = count_j[["ucgid_str"]].iloc[i].values[0] - # Reusing an existing stratum this time, since there is no breakdown - stratum = get_simple_stratum_by_ucgid(session, ucgid_i) + # Reusing an existing geographic stratum + stratum = _lookup_geo_stratum(session, ucgid_i, geo_map) amount_value = amount_j.iloc[i][["target_value"]].values[0] stratum.targets_rel.append( @@ -396,7 +446,7 @@ def load_soi_data(long_dfs, year): variable=amount_variable_name, period=year, value=amount_value, - source_id=5, + source_id=irs_source.source_id, active=True, ) ) @@ -412,7 +462,7 @@ def load_soi_data(long_dfs, year): for i in range(agi_values.shape[0]): ucgid_i = agi_values[["ucgid_str"]].iloc[i].values[0] - stratum = get_simple_stratum_by_ucgid(session, ucgid_i) + stratum = _lookup_geo_stratum(session, ucgid_i, geo_map) stratum.targets_rel.append( Target( variable="adjusted_gross_income", @@ -513,7 +563,7 @@ def load_soi_data(long_dfs, year): variable="person_count", period=year, value=person_count, - source_id=5, + source_id=irs_source.source_id, active=True, ) ) diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index d420edd0d..67bf8db56 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -1,3 +1,4 @@ +import logging import requests import pandas as pd @@ -6,35 +7,58 @@ from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( + SourceType, Stratum, StratumConstraint, Target, ) from policyengine_us_data.utils.census import STATE_ABBREV_TO_FIPS +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.utils.raw_cache import ( + is_cached, + cache_path, + save_json, + load_json, + save_bytes, +) +logger = logging.getLogger(__name__) -def extract_medicaid_data(year): - base_url = ( - f"https://api.census.gov/data/{year}/acs/acs1/subject?get=group(S2704)" - ) - url = f"{base_url}&for=congressional+district:*" - response = requests.get(url) - response.raise_for_status() - data = response.json() +def extract_medicaid_data(year): + # Census ACS survey data + census_cache = f"acs_S2704_district_{year}.json" + if is_cached(census_cache): + logger.info(f"Using cached {census_cache}") + data = load_json(census_cache) + else: + base_url = f"https://api.census.gov/data/{year}/acs/acs1/subject?get=group(S2704)" + url = f"{base_url}&for=congressional+district:*" + logger.info(f"Downloading ACS S2704 for {year}") + response = requests.get(url) + response.raise_for_status() + data = response.json() + save_json(census_cache, data) headers = data[0] data_rows = data[1:] cd_survey_df = pd.DataFrame(data_rows, columns=headers) - item = "6165f45b-ca93-5bb5-9d06-db29c692a360" - response = requests.get( - f"https://data.medicaid.gov/api/1/metastore/schemas/dataset/items/{item}?show-reference-ids=false" - ) - metadata = response.json() - - data_url = metadata["distribution"][0]["data"]["downloadURL"] - state_admin_df = pd.read_csv(data_url) + # CMS Medicaid administrative data + cms_cache = f"medicaid_enrollment_{year}.csv" + if is_cached(cms_cache): + logger.info(f"Using cached {cms_cache}") + state_admin_df = pd.read_csv(cache_path(cms_cache)) + else: + item = "6165f45b-ca93-5bb5-9d06-db29c692a360" + logger.info("Downloading Medicaid enrollment from CMS") + response = requests.get( + 
f"https://data.medicaid.gov/api/1/metastore/schemas/dataset/items/{item}?show-reference-ids=false" + ) + metadata = response.json() + data_url = metadata["distribution"][0]["data"]["downloadURL"] + state_admin_df = pd.read_csv(data_url) + state_admin_df.to_csv(cache_path(cms_cache), index=False) return cd_survey_df, state_admin_df @@ -93,6 +117,25 @@ def load_medicaid_data(long_state, long_cd, year): stratum_lookup = {} with Session(engine) as session: + admin_source = get_or_create_source( + session, + name="Medicaid T-MSIS", + source_type=SourceType.ADMINISTRATIVE, + vintage=f"{year} Final Report", + description="Medicaid Transformed MSIS administrative enrollment data", + url="https://data.medicaid.gov/", + notes="State-level Medicaid enrollment from administrative records", + ) + survey_source = get_or_create_source( + session, + name="Census ACS Table S2704", + source_type=SourceType.SURVEY, + vintage=f"{year} ACS 1-year estimates", + description="American Community Survey health insurance coverage data", + url="https://data.census.gov/", + notes="Congressional district level Medicaid coverage from ACS", + ) + # National ---------------- nat_stratum = Stratum( parent_stratum_id=None, @@ -146,7 +189,7 @@ def load_medicaid_data(long_state, long_cd, year): variable="person_count", period=year, value=row["medicaid_enrollment"], - source_id=2, + source_id=admin_source.source_id, active=True, ) ) @@ -184,7 +227,7 @@ def load_medicaid_data(long_state, long_cd, year): variable="person_count", period=year, value=row["medicaid_enrollment"], - source_id=2, + source_id=survey_source.source_id, active=True, ) ) diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index 7f73c2b78..6f1a64767 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -1,3 +1,4 @@ +import logging import requests import zipfile import io @@ -10,6 +11,7 @@ from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( + SourceType, Stratum, StratumConstraint, Target, @@ -18,15 +20,28 @@ pull_acs_table, STATE_NAME_TO_FIPS, ) +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.utils.raw_cache import ( + is_cached, + cache_path, + save_bytes, + load_bytes, +) + +logger = logging.getLogger(__name__) def extract_administrative_snap_data(year=2023): """ Downloads and extracts annual state-level SNAP data from the USDA FNS zip file. 
""" + cache_file = "snap_fy69tocurrent.zip" + if is_cached(cache_file): + logger.info(f"Using cached {cache_file}") + return zipfile.ZipFile(io.BytesIO(load_bytes(cache_file))) + url = "https://www.fns.usda.gov/sites/default/files/resource-files/snap-zip-fy69tocurrent-6.zip" - # Note: extra complexity in request due to regional restrictions on downloads (e.g., Spain) headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", @@ -40,18 +55,17 @@ def extract_administrative_snap_data(year=2023): session = requests.Session() session.headers.update(headers) - # Try to visit the main page first to get any necessary cookies main_page = "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap" try: session.get(main_page, timeout=30) except: - pass # Ignore errors on the main page + pass + logger.info("Downloading SNAP data from USDA FNS") response = session.get(url, timeout=30, allow_redirects=True) response.raise_for_status() except requests.exceptions.RequestException as e: print(f"Error downloading file: {e}") - # Try alternative URL or method try: alt_url = "https://www.fns.usda.gov/sites/default/files/resource-files/snap-zip-fy69tocurrent-6.zip" response = session.get(alt_url, timeout=30, allow_redirects=True) @@ -60,6 +74,7 @@ def extract_administrative_snap_data(year=2023): print(f"Alternative URL also failed: {e2}") return None + save_bytes(cache_file, response.content) return zipfile.ZipFile(io.BytesIO(response.content)) @@ -154,6 +169,16 @@ def load_administrative_snap_data(df_states, year): stratum_lookup = {} with Session(engine) as session: + admin_source = get_or_create_source( + session, + name="USDA FNS SNAP Data", + source_type=SourceType.ADMINISTRATIVE, + vintage=f"FY {year}", + description="SNAP administrative data from USDA Food and Nutrition Service", + url="https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap", + notes="State-level administrative totals for households and costs", + ) + # National ---------------- nat_stratum = Stratum( parent_stratum_id=None, @@ -209,7 +234,7 @@ def load_administrative_snap_data(df_states, year): variable="household_count", period=year, value=row["Households"], - source_id=3, + source_id=admin_source.source_id, active=True, ) ) @@ -218,7 +243,7 @@ def load_administrative_snap_data(df_states, year): variable="snap", period=year, value=row["Cost"], - source_id=3, + source_id=admin_source.source_id, active=True, ) ) @@ -242,6 +267,16 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): engine = create_engine(DATABASE_URL) with Session(engine) as session: + survey_source = get_or_create_source( + session, + name="Census ACS Table S2201", + source_type=SourceType.SURVEY, + vintage=f"{year} ACS 5-year estimates", + description="American Community Survey SNAP/Food Stamps data", + url="https://data.census.gov/", + notes="Congressional district level SNAP household counts from ACS", + ) + # Create new strata for districts whose households recieve SNAP benefits district_df = survey_df.copy() for _, row in district_df.iterrows(): @@ -271,7 +306,7 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): variable="household_count", period=year, value=row["snap_household_ct"], - source_id=4, + source_id=survey_source.source_id, active=True, ) ) diff --git a/policyengine_us_data/db/validate_database.py 
b/policyengine_us_data/db/validate_database.py index 2fa819f29..3760706b8 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -20,6 +20,7 @@ if not var_name in system.variables.keys(): raise ValueError(f"{var_name} not a policyengine-us variable") -for var_name in set(stratum_constraints_df["constraint_variable"]): - if not var_name in system.variables.keys(): - raise ValueError(f"{var_name} not a policyengine-us variable") +constraint_vars = set(stratum_constraints_df["constraint_variable"]) +print(f"Constraint variables: {sorted(constraint_vars)}") +print(f"Target variables validated: {len(set(targets_df['variable']))}") +print("Validation passed.") diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index e99eed012..c8b903cf6 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -1,7 +1,5 @@ from policyengine_us_data.datasets import ( EnhancedCPS_2024, - Pooled_3_Year_CPS_2023, - CPS_2023, ) from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.utils.data_upload import upload_data_files @@ -12,10 +10,8 @@ def upload_datasets(): dataset_files = [ EnhancedCPS_2024.file_path, - Pooled_3_Year_CPS_2023.file_path, - CPS_2023.file_path, STORAGE_FOLDER / "small_enhanced_cps_2024.h5", - # STORAGE_FOLDER / "policy_data.db", + STORAGE_FOLDER / "calibration" / "policy_data.db", ] # Filter to only existing files diff --git a/policyengine_us_data/tests/test_local_area_calibration/conftest.py b/policyengine_us_data/tests/test_local_area_calibration/conftest.py index 7abcbafbf..20b7f05bd 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/conftest.py +++ b/policyengine_us_data/tests/test_local_area_calibration/conftest.py @@ -58,7 +58,7 @@ def db_uri(): @pytest.fixture(scope="module") def dataset_path(): - return str(STORAGE_FOLDER / "stratified_extended_cps_2023.h5") + return str(STORAGE_FOLDER / "stratified_extended_cps_2024.h5") @pytest.fixture(scope="module") diff --git a/policyengine_us_data/utils/census.py b/policyengine_us_data/utils/census.py index 8081b6162..cb9d0b5d8 100644 --- a/policyengine_us_data/utils/census.py +++ b/policyengine_us_data/utils/census.py @@ -1,9 +1,18 @@ +import logging import pathlib import requests import pandas as pd import numpy as np +from policyengine_us_data.utils.raw_cache import ( + is_cached, + save_json, + load_json, +) + +logger = logging.getLogger(__name__) + STATE_NAME_TO_FIPS = { "Alabama": "01", "Alaska": "02", @@ -126,13 +135,17 @@ def get_census_docs(year): docs_url = ( f"https://api.census.gov/data/{year}/acs/acs1/subject/variables.json" ) - # NOTE: The URL for detail tables, should we ever need it is: - # "https://api.census.gov/data/2023/acs/acs1/variables.json" + cache_file = f"census_docs_{year}.json" + if is_cached(cache_file): + logger.info(f"Using cached {cache_file}") + return load_json(cache_file) + logger.info(f"Downloading census docs for {year}") docs_response = requests.get(docs_url) docs_response.raise_for_status() - - return docs_response.json() + data = docs_response.json() + save_json(cache_file, data) + return data def pull_acs_table(group: str, geo: str, year: int) -> pd.DataFrame: @@ -141,6 +154,13 @@ def pull_acs_table(group: str, geo: str, year: int) -> pd.DataFrame: "geo": 'National' | 'State' | 'District' "year": e.g., 2023 """ + cache_file = 
f"acs_{group}_{geo.lower()}_{year}.json" + if is_cached(cache_file): + logger.info(f"Using cached {cache_file}") + data = load_json(cache_file) + headers, rows = data[0], data[1:] + return pd.DataFrame(rows, columns=headers) + base = f"https://api.census.gov/data/{year}/acs/acs1" if group[0] == "S": @@ -153,7 +173,9 @@ def pull_acs_table(group: str, geo: str, year: int) -> pd.DataFrame: url = f"{base}?get=group({group})&for={geo_q}" + logger.info(f"Downloading ACS table {group} ({geo}) for {year}") data = requests.get(url).json() + save_json(cache_file, data) headers, rows = data[0], data[1:] df = pd.DataFrame(rows, columns=headers) return df diff --git a/policyengine_us_data/utils/raw_cache.py b/policyengine_us_data/utils/raw_cache.py new file mode 100644 index 000000000..e9ccb1290 --- /dev/null +++ b/policyengine_us_data/utils/raw_cache.py @@ -0,0 +1,37 @@ +import json +import logging +import os +from pathlib import Path + +from policyengine_us_data.storage import STORAGE_FOLDER + +logger = logging.getLogger(__name__) + +RAW_INPUTS_DIR = STORAGE_FOLDER / "calibration" / "raw_inputs" +RAW_INPUTS_DIR.mkdir(parents=True, exist_ok=True) + +REFRESH = os.environ.get("PE_REFRESH_RAW", "0") == "1" + + +def cache_path(filename: str) -> Path: + return RAW_INPUTS_DIR / filename + + +def is_cached(filename: str) -> bool: + return cache_path(filename).exists() and not REFRESH + + +def save_json(filename: str, data): + cache_path(filename).write_text(json.dumps(data, ensure_ascii=False)) + + +def load_json(filename: str): + return json.loads(cache_path(filename).read_text()) + + +def save_bytes(filename: str, data: bytes): + cache_path(filename).write_bytes(data) + + +def load_bytes(filename: str) -> bytes: + return cache_path(filename).read_bytes() From cca5efcd35ebda06607f7e2107d73c2d1d291014 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Wed, 28 Jan 2026 09:40:39 -0500 Subject: [PATCH 4/8] Port complete DB/ETL logic with raw_cache integration and conditional strata Replace simplified DB pipeline with full implementation: - IRS SOI: 19 conditional strata groups (100-118) with filer population layer - Variables: income_tax_before_credits, rental_income, self_employment_income, net_capital_gains, and complete AGI distribution with tax_unit_count - Medicaid: 2024 admin data (CD survey disabled pending 119th Congress remap) - All ETL extract functions now use raw_cache for offline iteration New files: validate_hierarchy.py, migrate_stratum_group_ids.py, IRS_SOI_DATA_ISSUE.md Verified: 53 target groups, 32,781 targets, X_sparse (32781, 4577564) Co-Authored-By: Claude Opus 4.5 --- policyengine_us_data/db/DATABASE_GUIDE.md | 28 +- policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md | 109 ++ .../db/create_database_tables.py | 124 +-- .../db/create_initial_strata.py | 124 ++- policyengine_us_data/db/etl_age.py | 179 +++- policyengine_us_data/db/etl_irs_soi.py | 985 ++++++++++++++---- policyengine_us_data/db/etl_medicaid.py | 249 +++-- .../db/etl_national_targets.py | 159 +-- policyengine_us_data/db/etl_snap.py | 168 +-- .../db/migrate_stratum_group_ids.py | 137 +++ policyengine_us_data/db/validate_database.py | 7 +- policyengine_us_data/db/validate_hierarchy.py | 326 ++++++ 12 files changed, 2003 insertions(+), 592 deletions(-) create mode 100644 policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md create mode 100644 policyengine_us_data/db/migrate_stratum_group_ids.py create mode 100644 policyengine_us_data/db/validate_hierarchy.py diff --git a/policyengine_us_data/db/DATABASE_GUIDE.md 
b/policyengine_us_data/db/DATABASE_GUIDE.md index f751e1d7d..8cd0002af 100644 --- a/policyengine_us_data/db/DATABASE_GUIDE.md +++ b/policyengine_us_data/db/DATABASE_GUIDE.md @@ -30,7 +30,7 @@ make promote-database # Copy DB + raw inputs to HuggingFace clone | 4 | `etl_age.py` | Census ACS 1-year | Age distribution: 18 bins x 488 geographies | | 5 | `etl_medicaid.py` | Census ACS + CMS | Medicaid enrollment (admin state-level, survey district-level) | | 6 | `etl_snap.py` | USDA FNS + Census ACS | SNAP participation (admin state-level, survey district-level) | -| 7 | `etl_irs_soi.py` | IRS | Tax variables, EITC by child count, AGI brackets | +| 7 | `etl_irs_soi.py` | IRS | Tax variables, EITC by child count, AGI brackets, conditional strata | | 8 | `validate_database.py` | No | Checks all target variables exist in policyengine-us | ### Raw Input Caching @@ -103,11 +103,27 @@ The `stratum_group_id` field categorizes strata: |----|----------|-------------| | 0 | Uncategorized | Legacy strata not yet assigned a group | | 1 | Geographic | US, states, congressional districts | -| 2 | Age | 18 age brackets per geography | +| 2 | Age/Filer population | Age brackets, tax filer intermediate strata | | 3 | Income/AGI | 9 income brackets per geography | | 4 | SNAP | SNAP recipient strata | | 5 | Medicaid | Medicaid enrollment strata | | 6 | EITC | EITC recipients by qualifying children | +| 100-118 | IRS Conditional | Each IRS variable paired with conditional count constraints | + +### Conditional Strata (IRS SOI) + +IRS variables use a "filer population" intermediate layer and conditional strata: + +``` +Geographic stratum (group_id=1) + └── Tax Filer stratum (group_id=2, constraint: tax_unit_is_filer==1) + ├── AGI bracket strata (group_id=3, constraint: AGI range) + ├── EITC by child count (group_id=6, constraint: eitc_child_count) + └── IRS variable strata (group_id=100+, constraint: variable > 0) + - Targets: tax_unit_count + variable amount +``` + +Each IRS variable (e.g., `rental_income`, `self_employment_income`) gets its own stratum_group_id (100+) with a constraint requiring that variable > 0. This captures both the count of filers with that income type and the total amount. ### Geographic Hierarchy @@ -138,8 +154,6 @@ ETL scripts that pull Census data receive UCGIDs and create their own domain-spe All constraints use standardized operators validated by the `ConstraintOperation` enum: `==`, `!=`, `>`, `>=`, `<`, `<=` -Note: Some legacy ETL scripts use string operations like `"in"`, `"equals"`, `"greater_than"`. These coexist in the database but new code should use the standardized operators. - ### Constraint Value Types The `value` column stores all values as strings. Downstream code deserializes: @@ -164,6 +178,10 @@ WHERE sc.constraint_variable = 'state_fips' AND sc.value = '37'; ``` +### IRS SOI A59664 Data Issue + +The IRS SOI column A59664 (EITC with 3+ children amount) is reported in dollars, not thousands like all other monetary columns. The ETL code detects and compensates for this. See `IRS_SOI_DATA_ISSUE.md` for details. 
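+
+As a rough sketch (not something the ETL scripts emit), a query along these lines should list the count and amount targets for one conditional IRS stratum group at the national level. Which IRS variable ends up as group 100 depends on the ETL's processing order, so treat the specific group id here as illustrative:
+
+```sql
+-- Hypothetical query: targets for one IRS conditional group under the national filer stratum
+SELECT s.notes AS stratum, t.variable, t.value
+FROM strata s
+JOIN targets t ON t.stratum_id = s.stratum_id
+WHERE s.stratum_group_id = 100          -- one IRS variable group (range 100-118)
+  AND s.parent_stratum_id = (
+        SELECT stratum_id FROM strata
+        WHERE notes = 'United States - Tax Filers'
+  )
+ORDER BY t.variable;
+```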
+ ## Utility Functions **`policyengine_us_data/utils/db.py`**: @@ -193,7 +211,7 @@ SELECT CASE stratum_group_id WHEN 0 THEN 'Uncategorized' WHEN 1 THEN 'Geographic' - WHEN 2 THEN 'Age' + WHEN 2 THEN 'Age/Filer' WHEN 3 THEN 'Income/AGI' WHEN 4 THEN 'SNAP' WHEN 5 THEN 'Medicaid' diff --git a/policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md b/policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md new file mode 100644 index 000000000..3d7225166 --- /dev/null +++ b/policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md @@ -0,0 +1,109 @@ +# IRS SOI Data Inconsistency: A59664 Units Issue + +## Summary +The IRS Statistics of Income (SOI) Congressional District data file has an undocumented data inconsistency where column A59664 (EITC amount for 3+ children) is reported in **dollars** instead of **thousands of dollars** like all other monetary columns. + +## Discovery Date +December 2024 + +## Affected Data +- **File**: https://www.irs.gov/pub/irs-soi/22incd.csv (and likely other years) +- **Column**: A59664 - "Earned income credit with three qualifying children amount" +- **Issue**: Value is in dollars, not thousands of dollars + +## Evidence + +### 1. Documentation States All Money in Thousands +From the IRS SOI documentation: "For all the files, the money amounts are reported in thousands of dollars." + +### 2. Data Analysis Shows Inconsistency +California example from 2022 data: +``` +A59661 (EITC 0 children): 284,115 (thousands) = $284M ✓ +A59662 (EITC 1 child): 2,086,260 (thousands) = $2.1B ✓ +A59663 (EITC 2 children): 2,067,922 (thousands) = $2.1B ✓ +A59664 (EITC 3+ children): 1,248,669,042 (if thousands) = $1.25 TRILLION ✗ +``` + +### 3. Total EITC Confirms the Issue +``` +A59660 (Total EITC): 5,687,167 (thousands) = $5.69B + +Sum with A59664 as dollars: $5.69B ✓ (matches!) +Sum with A59664 as thousands: $1.25T ✗ (way off!) +``` + +### 4. Pattern Across All States +The ratio of A59664 to A59663 is consistently ~600x across all states: +- California: 603.8x +- North Carolina: 598.9x +- New York: 594.2x +- Texas: 691.5x + +If both were in the same units, this ratio should be 0.5-2x. + +## Additional Finding: "Three" Means "Three or More" + +The documentation says "three qualifying children" but the data shows this represents "three or more": +- Sum of N59661 + N59662 + N59663 + N59664 = 23,261,270 +- N59660 (Total EITC recipients) = 23,266,630 +- Difference: 5,360 (0.02% - essentially equal) + +This confirms that category 4 represents families with 3+ children, not exactly 3. + +## Fix Applied + +In `etl_irs_soi.py`, we now divide A59664 by 1000 before applying the standard multiplier: + +```python +if amount_col == 'A59664': + # Convert from dollars to thousands to match other columns + rec_amounts["target_value"] /= 1_000 +``` + +## Impact Before Fix +- EITC calibration targets for 3+ children were 1000x too high +- California target: $1.25 trillion instead of $1.25 billion +- Made calibration impossible to converge for EITC + +## Verification Steps +1. Download IRS SOI data for any year +2. Check A59660 (total EITC) value +3. Sum A59661-A59664 with A59664 divided by 1000 +4. Confirm sum matches A59660 + +## Recommendation for IRS +The IRS should either: +1. Fix the data to report A59664 in thousands like other columns +2. 
Document this exception clearly in their documentation + +## Verification Code + +To verify this issue or check if the IRS has fixed it: + +```python +import pandas as pd + +# Load IRS data +df = pd.read_csv('https://www.irs.gov/pub/irs-soi/22incd.csv') +us_data = df[(df['STATE'] == 'US') & (df['agi_stub'] == 0)] + +# Get EITC values +a61 = us_data['A59661'].values[0] * 1000 # 0 children (convert from thousands) +a62 = us_data['A59662'].values[0] * 1000 # 1 child +a63 = us_data['A59663'].values[0] * 1000 # 2 children +a64 = us_data['A59664'].values[0] # 3+ children (already in dollars!) +total = us_data['A59660'].values[0] * 1000 # Total EITC + +print(f'Sum with A59664 as dollars: ${(a61 + a62 + a63 + a64):,.0f}') +print(f'Total EITC (A59660): ${total:,.0f}') +print(f'Match: {abs(total - (a61 + a62 + a63 + a64)) < 1e6}') + +# Check ratio to confirm inconsistency +ratio = us_data['A59664'].values[0] / us_data['A59663'].values[0] +print(f'\nA59664/A59663 ratio: {ratio:.1f}x') +print('(Should be ~0.5-2x if same units, but is ~600x)') +``` + +## Related Files +- `/home/baogorek/devl/policyengine-us-data/policyengine_us_data/db/etl_irs_soi.py` - ETL script with fix and auto-detection \ No newline at end of file diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index 4b526f7ef..9485d02e1 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -64,15 +64,13 @@ class Stratum(SQLModel, table=True): default=None, foreign_key="strata.stratum_id", index=True, - description=("Identifier for a parent stratum, creating a hierarchy."), + description="Identifier for a parent stratum, creating a hierarchy.", ) stratum_group_id: Optional[int] = Field( - default=None, - description="Identifier for a group of related strata.", + default=None, description="Identifier for a group of related strata." ) notes: Optional[str] = Field( - default=None, - description="Descriptive notes about the stratum.", + default=None, description="Descriptive notes about the stratum." ) children_rel: List["Stratum"] = Relationship( @@ -104,18 +102,17 @@ class StratumConstraint(SQLModel, table=True): stratum_id: int = Field(foreign_key="strata.stratum_id", primary_key=True) constraint_variable: str = Field( primary_key=True, - description=("The variable the constraint applies to (e.g., 'age')."), + description="The variable the constraint applies to (e.g., 'age').", ) operation: str = Field( primary_key=True, - description=("The comparison operator (==, !=, >, >=, <, <=)."), + description="The comparison operator (==, !=, >, >=, <, <=).", ) value: str = Field( description="The value for the constraint rule (e.g., '25')." ) notes: Optional[str] = Field( - default=None, - description="Optional notes about the constraint.", + default=None, description="Optional notes about the constraint." ) strata_rel: Stratum = Relationship(back_populates="constraints_rel") @@ -126,8 +123,7 @@ def validate_operation(cls, v): allowed_ops = [op.value for op in ConstraintOperation] if v not in allowed_ops: raise ValueError( - f"Invalid operation '{v}'. " - f"Must be one of: {', '.join(allowed_ops)}" + f"Invalid operation '{v}'. Must be one of: {', '.join(allowed_ops)}" ) return v @@ -148,9 +144,7 @@ class Target(SQLModel, table=True): target_id: Optional[int] = Field(default=None, primary_key=True) variable: USVariable = Field( - description=( - "A variable defined in policyengine-us " "(e.g., 'income_tax')." 
- ), + description="A variable defined in policyengine-us (e.g., 'income_tax')." ) period: int = Field( description="The time period for the data, typically a year." @@ -158,13 +152,10 @@ class Target(SQLModel, table=True): stratum_id: int = Field(foreign_key="strata.stratum_id", index=True) reform_id: int = Field( default=0, - description=( - "Identifier for a policy reform scenario " "(0 for baseline)." - ), + description="Identifier for a policy reform scenario (0 for baseline).", ) value: Optional[float] = Field( - default=None, - description="The numerical value of the target variable.", + default=None, description="The numerical value of the target variable." ) source_id: Optional[int] = Field( default=None, @@ -173,13 +164,11 @@ class Target(SQLModel, table=True): ) active: bool = Field( default=True, - description=("Flag to indicate if the record is currently active."), + description="Flag to indicate if the record is currently active.", ) tolerance: Optional[float] = Field( default=None, - description=( - "Allowed relative error as a percent " "(e.g., 25 for 25%)." - ), + description="Allowed relative error as a percent (e.g., 25 for 25%).", ) notes: Optional[str] = Field( default=None, @@ -197,7 +186,9 @@ class SourceType(str, Enum): SURVEY = "survey" SYNTHETIC = "synthetic" DERIVED = "derived" - HARDCODED = "hardcoded" + HARDCODED = ( + "hardcoded" # Values from various sources, hardcoded into the system + ) class Source(SQLModel, table=True): @@ -214,29 +205,24 @@ class Source(SQLModel, table=True): description="Unique identifier for the data source.", ) name: str = Field( - description=( - "Name of the data source " "(e.g., 'IRS SOI', 'Census ACS')." - ), + description="Name of the data source (e.g., 'IRS SOI', 'Census ACS').", index=True, ) type: SourceType = Field( - description=("Type of data source (administrative, survey, etc.)."), + description="Type of data source (administrative, survey, etc.)." ) description: Optional[str] = Field( - default=None, - description="Detailed description of the data source.", + default=None, description="Detailed description of the data source." ) url: Optional[str] = Field( default=None, - description=("URL or reference to the original data source."), + description="URL or reference to the original data source.", ) vintage: Optional[str] = Field( - default=None, - description="Version or release date of the data source.", + default=None, description="Version or release date of the data source." ) notes: Optional[str] = Field( - default=None, - description="Additional notes about the source.", + default=None, description="Additional notes about the source." ) @@ -251,54 +237,37 @@ class VariableGroup(SQLModel, table=True): description="Unique identifier for the variable group.", ) name: str = Field( - description=( - "Name of the variable group " - "(e.g., 'age_distribution', 'snap_recipients')." - ), + description="Name of the variable group (e.g., 'age_distribution', 'snap_recipients').", index=True, unique=True, ) category: str = Field( - description=( - "High-level category " - "(e.g., 'demographic', 'benefit', 'tax', 'income')." - ), + description="High-level category (e.g., 'demographic', 'benefit', 'tax', 'income').", index=True, ) is_histogram: bool = Field( default=False, - description=( - "Whether this group represents a " "histogram/distribution." 
- ), + description="Whether this group represents a histogram/distribution.", ) is_exclusive: bool = Field( default=False, - description=( - "Whether variables in this group are " "mutually exclusive." - ), + description="Whether variables in this group are mutually exclusive.", ) aggregation_method: Optional[str] = Field( default=None, - description=( - "How to aggregate variables in this group " - "(sum, weighted_avg, etc.)." - ), + description="How to aggregate variables in this group (sum, weighted_avg, etc.).", ) display_order: Optional[int] = Field( default=None, - description=( - "Order for displaying this group in " "matrices/reports." - ), + description="Order for displaying this group in matrices/reports.", ) description: Optional[str] = Field( - default=None, - description="Description of what this group represents.", + default=None, description="Description of what this group represents." ) class VariableMetadata(SQLModel, table=True): - """Maps PolicyEngine variables to their groups and provides - metadata.""" + """Maps PolicyEngine variables to their groups and provides metadata.""" __tablename__ = "variable_metadata" __table_args__ = ( @@ -312,48 +281,46 @@ class VariableMetadata(SQLModel, table=True): group_id: Optional[int] = Field( default=None, foreign_key="variable_groups.group_id", - description=("ID of the variable group this belongs to."), + description="ID of the variable group this belongs to.", ) display_name: Optional[str] = Field( default=None, - description=("Human-readable name for display in matrices."), + description="Human-readable name for display in matrices.", ) display_order: Optional[int] = Field( default=None, - description=("Order within its group for display purposes."), + description="Order within its group for display purposes.", ) units: Optional[str] = Field( default=None, - description=( - "Units of measurement " "(dollars, count, percent, etc.)." - ), + description="Units of measurement (dollars, count, percent, etc.).", ) is_primary: bool = Field( default=True, - description=( - "Whether this is a primary variable vs " "derived/auxiliary." - ), + description="Whether this is a primary variable vs derived/auxiliary.", ) notes: Optional[str] = Field( - default=None, - description="Additional notes about the variable.", + default=None, description="Additional notes about the variable." ) group_rel: Optional[VariableGroup] = Relationship() +# This SQLAlchemy event listener works directly with the SQLModel class @event.listens_for(Stratum, "before_insert") @event.listens_for(Stratum, "before_update") def calculate_definition_hash(mapper, connection, target: Stratum): - """Calculate and set the definition_hash before saving a - Stratum instance.""" + """ + Calculate and set the definition_hash before saving a Stratum instance. 
+ """ constraints_history = get_history(target, "constraints_rel") if not ( constraints_history.has_changes() or target.definition_hash is None ): return - if not target.constraints_rel: + if not target.constraints_rel: # Handle cases with no constraints + # Include parent_stratum_id to make hash unique per parent parent_str = ( str(target.parent_stratum_id) if target.parent_stratum_id else "" ) @@ -368,6 +335,7 @@ def calculate_definition_hash(mapper, connection, target: Stratum): ] constraint_strings.sort() + # Include parent_stratum_id in the hash to ensure uniqueness per parent parent_str = ( str(target.parent_stratum_id) if target.parent_stratum_id else "" ) @@ -376,16 +344,14 @@ def calculate_definition_hash(mapper, connection, target: Stratum): target.definition_hash = h.hexdigest() -DB_PATH = STORAGE_FOLDER / "calibration" / "policy_data.db" - - def create_database( db_uri: str = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}", ): - """Creates a SQLite database and all the defined tables. + """ + Creates a SQLite database and all the defined tables. Args: - db_uri: The connection string for the database. + db_uri (str): The connection string for the database. Returns: An SQLAlchemy Engine instance connected to the database. diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 85144b22b..f3edb1b41 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -18,60 +18,6 @@ logger = logging.getLogger(__name__) -STATE_NAMES = { - 1: "Alabama (AL)", - 2: "Alaska (AK)", - 4: "Arizona (AZ)", - 5: "Arkansas (AR)", - 6: "California (CA)", - 8: "Colorado (CO)", - 9: "Connecticut (CT)", - 10: "Delaware (DE)", - 11: "District of Columbia (DC)", - 12: "Florida (FL)", - 13: "Georgia (GA)", - 15: "Hawaii (HI)", - 16: "Idaho (ID)", - 17: "Illinois (IL)", - 18: "Indiana (IN)", - 19: "Iowa (IA)", - 20: "Kansas (KS)", - 21: "Kentucky (KY)", - 22: "Louisiana (LA)", - 23: "Maine (ME)", - 24: "Maryland (MD)", - 25: "Massachusetts (MA)", - 26: "Michigan (MI)", - 27: "Minnesota (MN)", - 28: "Mississippi (MS)", - 29: "Missouri (MO)", - 30: "Montana (MT)", - 31: "Nebraska (NE)", - 32: "Nevada (NV)", - 33: "New Hampshire (NH)", - 34: "New Jersey (NJ)", - 35: "New Mexico (NM)", - 36: "New York (NY)", - 37: "North Carolina (NC)", - 38: "North Dakota (ND)", - 39: "Ohio (OH)", - 40: "Oklahoma (OK)", - 41: "Oregon (OR)", - 42: "Pennsylvania (PA)", - 44: "Rhode Island (RI)", - 45: "South Carolina (SC)", - 46: "South Dakota (SD)", - 47: "Tennessee (TN)", - 48: "Texas (TX)", - 49: "Utah (UT)", - 50: "Vermont (VT)", - 51: "Virginia (VA)", - 53: "Washington (WA)", - 54: "West Virginia (WV)", - 55: "Wisconsin (WI)", - 56: "Wyoming (WY)", -} - def fetch_congressional_districts(year): cache_file = f"acs5_congressional_districts_{year}.json" @@ -85,9 +31,7 @@ def fetch_congressional_districts(year): "for": "congressional district:*", "in": "state:*", } - logger.info(f"Downloading congressional districts for {year}") response = requests.get(base_url, params=params) - response.raise_for_status() data = response.json() save_json(cache_file, data) @@ -98,6 +42,7 @@ def fetch_congressional_districts(year): lambda x: 0 if x in ["ZZ", "98"] else int(x) ) + # Filter out statewide summary records for multi-district states df["n_districts"] = df.groupby("state_fips")["state_fips"].transform( "count" ) @@ -123,6 +68,62 @@ def fetch_congressional_districts(year): def main(): + # State FIPS to 
name/abbreviation mapping + STATE_NAMES = { + 1: "Alabama (AL)", + 2: "Alaska (AK)", + 4: "Arizona (AZ)", + 5: "Arkansas (AR)", + 6: "California (CA)", + 8: "Colorado (CO)", + 9: "Connecticut (CT)", + 10: "Delaware (DE)", + 11: "District of Columbia (DC)", + 12: "Florida (FL)", + 13: "Georgia (GA)", + 15: "Hawaii (HI)", + 16: "Idaho (ID)", + 17: "Illinois (IL)", + 18: "Indiana (IN)", + 19: "Iowa (IA)", + 20: "Kansas (KS)", + 21: "Kentucky (KY)", + 22: "Louisiana (LA)", + 23: "Maine (ME)", + 24: "Maryland (MD)", + 25: "Massachusetts (MA)", + 26: "Michigan (MI)", + 27: "Minnesota (MN)", + 28: "Mississippi (MS)", + 29: "Missouri (MO)", + 30: "Montana (MT)", + 31: "Nebraska (NE)", + 32: "Nevada (NV)", + 33: "New Hampshire (NH)", + 34: "New Jersey (NJ)", + 35: "New Mexico (NM)", + 36: "New York (NY)", + 37: "North Carolina (NC)", + 38: "North Dakota (ND)", + 39: "Ohio (OH)", + 40: "Oklahoma (OK)", + 41: "Oregon (OR)", + 42: "Pennsylvania (PA)", + 44: "Rhode Island (RI)", + 45: "South Carolina (SC)", + 46: "South Dakota (SD)", + 47: "Tennessee (TN)", + 48: "Texas (TX)", + 49: "Utah (UT)", + 50: "Vermont (VT)", + 51: "Virginia (VA)", + 53: "Washington (WA)", + 54: "West Virginia (WV)", + 55: "Wisconsin (WI)", + 56: "Wyoming (WY)", + } + + # Fetch congressional district data for year 2023 year = 2023 cd_df = fetch_congressional_districts(year) @@ -132,18 +133,26 @@ def main(): engine = create_engine(DATABASE_URL) with Session(engine) as session: + # Truncate existing tables + session.query(StratumConstraint).delete() + session.query(Stratum).delete() + session.commit() + + # Create national level stratum us_stratum = Stratum( parent_stratum_id=None, notes="United States", stratum_group_id=1, ) - us_stratum.constraints_rel = [] + us_stratum.constraints_rel = [] # No constraints for national level session.add(us_stratum) session.flush() us_stratum_id = us_stratum.stratum_id + # Track state strata for parent relationships state_stratum_ids = {} + # Create state-level strata unique_states = cd_df["state_fips"].unique() for state_fips in sorted(unique_states): state_name = STATE_NAMES.get( @@ -165,6 +174,7 @@ def main(): session.flush() state_stratum_ids[state_fips] = state_stratum.stratum_id + # Create congressional district strata for _, row in cd_df.iterrows(): state_fips = row["state_fips"] cd_geoid = row["congressional_district_geoid"] diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index 01cbbe308..39ffedf22 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -1,17 +1,22 @@ import pandas as pd import numpy as np -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( - SourceType, Stratum, StratumConstraint, Target, + SourceType, ) from policyengine_us_data.utils.census import get_census_docs, pull_acs_table -from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, +) LABEL_TO_SHORT = { "Estimate!!Total!!Total population!!AGE!!Under 5 years": "0-4", @@ -88,11 +93,7 @@ def transform_age_data(age_data, docs): return df_long -def get_parent_geo(geo): - return {"National": None, "State": "National", "District": "State"}[geo] - - -def 
load_age_data(df_long, geo, year, stratum_lookup=None): +def load_age_data(df_long, geo, year): # Quick data quality check before loading ---- if geo == "National": @@ -110,14 +111,8 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): ) engine = create_engine(DATABASE_URL) - if stratum_lookup is None: - if geo != "National": - raise ValueError("Include stratum_lookup unless National geo") - stratum_lookup = {"National": {}} - else: - stratum_lookup[geo] = {} - with Session(engine) as session: + # Get or create the Census ACS source census_source = get_or_create_source( session, name="Census ACS Table S0101", @@ -128,43 +123,139 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): notes="Age distribution in 18 brackets across all geographic levels", ) + # Get or create the age distribution variable group + age_group = get_or_create_variable_group( + session, + name="age_distribution", + category="demographic", + is_histogram=True, + is_exclusive=True, + aggregation_method="sum", + display_order=1, + description="Age distribution in 18 brackets (0-4, 5-9, ..., 85+)", + ) + + # Create variable metadata for person_count + get_or_create_variable_metadata( + session, + variable="person_count", + group=age_group, + display_name="Population Count", + display_order=1, + units="count", + notes="Number of people in age bracket", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + for _, row in df_long.iterrows(): - # Create the parent Stratum object. - # We will attach children to it before adding it to the session. - note = f"Age: {row['age_range']}, Geo: {row['ucgid_str']}" - parent_geo = get_parent_geo(geo) - parent_stratum_id = ( - stratum_lookup[parent_geo][row["age_range"]] - if parent_geo - else None - ) + # Parse the UCGID to determine geographic info + geo_info = parse_ucgid(row["ucgid_str"]) + + # Determine parent stratum based on geographic level + if geo_info["type"] == "national": + parent_stratum_id = geo_strata["national"] + elif geo_info["type"] == "state": + parent_stratum_id = geo_strata["state"][geo_info["state_fips"]] + elif geo_info["type"] == "district": + parent_stratum_id = geo_strata["district"][ + geo_info["congressional_district_geoid"] + ] + else: + raise ValueError(f"Unknown geography type: {geo_info['type']}") + + # Create the age stratum as a child of the geographic stratum + # Build a proper geographic identifier for the notes + if geo_info["type"] == "national": + geo_desc = "US" + elif geo_info["type"] == "state": + geo_desc = f"State FIPS {geo_info['state_fips']}" + elif geo_info["type"] == "district": + geo_desc = f"CD {geo_info['congressional_district_geoid']}" + else: + geo_desc = "Unknown" + + note = f"Age: {row['age_range']}, {geo_desc}" + + # Check if this age stratum already exists + existing_stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == 2, # Age strata group + Stratum.notes == note, + ) + ).first() + + if existing_stratum: + # Update the existing stratum's target instead of creating a duplicate + existing_target = session.exec( + select(Target).where( + Target.stratum_id == existing_stratum.stratum_id, + Target.variable == row["variable"], + Target.period == year, + ) + ).first() + + if existing_target: + # Update existing target + existing_target.value = row["value"] + else: + # Add new target to existing stratum + new_target = Target( + stratum_id=existing_stratum.stratum_id, + variable=row["variable"], + 
period=year, + value=row["value"], + source_id=census_source.source_id, + active=row["active"], + ) + session.add(new_target) + continue # Skip creating a new stratum new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=2, # Age strata group notes=note, ) - # Create constraints and link them to the parent's relationship attribute. - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], - ), + # Create constraints including both age and geographic for uniqueness + new_stratum.constraints_rel = [] + + # Add geographic constraints based on level + if geo_info["type"] == "state": + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ) + ) + elif geo_info["type"] == "district": + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(geo_info["congressional_district_geoid"]), + ) + ) + # For national level, no geographic constraint needed + + # Add age constraints + new_stratum.constraints_rel.append( StratumConstraint( constraint_variable="age", - operation="greater_than", + operation=">", value=str(row["age_greater_than"]), - ), - ] + ) + ) age_lt_value = row["age_less_than"] if not np.isinf(age_lt_value): new_stratum.constraints_rel.append( StratumConstraint( constraint_variable="age", - operation="less_than", + operation="<", value=str(row["age_less_than"]), ) ) @@ -184,15 +275,9 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): # The 'cascade' setting will handle the children automatically. session.add(new_stratum) - # Flush to get the id - session.flush() - stratum_lookup[geo][row["age_range"]] = new_stratum.stratum_id - # Commit all the new objects at once. 
session.commit() - return stratum_lookup - if __name__ == "__main__": @@ -211,8 +296,8 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): long_district_df = transform_age_data(district_df, docs) # --- Load -------- - national_strata_lku = load_age_data(long_national_df, "National", year) - state_strata_lku = load_age_data( - long_state_df, "State", year, national_strata_lku - ) - load_age_data(long_district_df, "District", year, state_strata_lku) + # Note: The geographic strata must already exist in the database + # (created by create_initial_strata.py) + load_age_data(long_national_df, "National", year) + load_age_data(long_state_df, "State", year) + load_age_data(long_district_df, "District", year) diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index a280b006a..ed4da4e5c 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.utils.raw_cache import ( @@ -16,40 +16,41 @@ logger = logging.getLogger(__name__) from policyengine_us_data.db.create_database_tables import ( - SourceType, Stratum, StratumConstraint, Target, + SourceType, ) -from policyengine_us_data.utils.db_metadata import get_or_create_source from policyengine_us_data.utils.db import ( get_stratum_by_id, - get_simple_stratum_by_ucgid, get_root_strata, get_stratum_children, get_stratum_parent, parse_ucgid, get_geographic_strata, ) +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, +) from policyengine_us_data.utils.census import TERRITORY_UCGIDS from policyengine_us_data.storage.calibration_targets.make_district_mapping import ( get_district_mapping, ) """See the 22incddocguide.docx manual from the IRS SOI""" -# Let's make this work with strict inequalities -# Language in the doc: '$10,000 under $25,000' -epsilon = 0.005 # i.e., half a penny +# Language in the doc: '$10,000 under $25,000' means >= $10,000 and < $25,000 AGI_STUB_TO_INCOME_RANGE = { - 1: (-np.inf, 1), - 2: (1 - epsilon, 10_000), - 3: (10_000 - epsilon, 25_000), - 4: (25_000 - epsilon, 50_000), - 5: (50_000 - epsilon, 75_000), - 6: (75_000 - epsilon, 100_000), - 7: (100_000 - epsilon, 200_000), - 8: (200_000 - epsilon, 500_000), - 9: (500_000 - epsilon, np.inf), + 1: (-np.inf, 1), # Under $1 (negative AGI allowed) + 2: (1, 10_000), # $1 under $10,000 + 3: (10_000, 25_000), # $10,000 under $25,000 + 4: (25_000, 50_000), # $25,000 under $50,000 + 5: (50_000, 75_000), # $50,000 under $75,000 + 6: (75_000, 100_000), # $75,000 under $100,000 + 7: (100_000, 200_000), # $100,000 under $200,000 + 8: (200_000, 500_000), # $200,000 under $500,000 + 9: (500_000, np.inf), # $500,000 or more } @@ -72,6 +73,15 @@ def make_records( breakdown_col: Optional[str] = None, multiplier: int = 1_000, ): + """ + Create standardized records from IRS SOI data. + + IMPORTANT DATA INCONSISTENCY (discovered 2024-12): + The IRS SOI documentation states "money amounts are reported in thousands of dollars." + This is true for almost all columns EXCEPT A59664 (EITC with 3+ children amount), + which is already in dollars, not thousands. This appears to be a data quality issue + in the IRS SOI file itself. We handle this special case below. 
+ """ df = df.rename( {count_col: "tax_unit_count", amount_col: amount_name}, axis=1 ).copy() @@ -82,8 +92,31 @@ def make_records( rec_counts = create_records(df, breakdown_col, "tax_unit_count") rec_amounts = create_records(df, breakdown_col, amount_name) - rec_amounts["target_value"] *= multiplier # Only the amounts get * 1000 - rec_counts["target_variable"] = f"{amount_name}_tax_unit_count" + + # SPECIAL CASE: A59664 (EITC with 3+ children) is already in dollars, not thousands! + # All other EITC amounts (A59661-A59663) are correctly in thousands. + # This was verified by checking that A59660 (total EITC) equals the sum only when + # A59664 is treated as already being in dollars. + if amount_col == "A59664": + # Check if IRS has fixed the data inconsistency + # If values are < 10 million, they're likely already in thousands (fixed) + max_value = rec_amounts["target_value"].max() + if max_value < 10_000_000: + print( + f"WARNING: A59664 values appear to be in thousands (max={max_value:,.0f})" + ) + print("The IRS may have fixed their data inconsistency.") + print( + "Please verify and remove the special case handling if confirmed." + ) + # Don't apply the fix - data appears to already be in thousands + else: + # Convert from dollars to thousands to match other columns + rec_amounts["target_value"] /= 1_000 + + rec_amounts["target_value"] *= multiplier # Apply standard multiplier + # Note: tax_unit_count is the correct variable - the stratum constraints + # indicate what is being counted (e.g., eitc > 0 for EITC recipients) return rec_counts, rec_amounts @@ -161,19 +194,53 @@ def extract_soi_data() -> pd.DataFrame: In the file below, "22" is 2022, "in" is individual returns, "cd" is congressional districts """ - import requests - cache_file = "irs_soi_22incd.csv" if is_cached(cache_file): logger.info(f"Using cached {cache_file}") - return pd.read_csv(cache_path(cache_file)) + df = pd.read_csv(cache_path(cache_file)) + else: + url = "https://www.irs.gov/pub/irs-soi/22incd.csv" + import requests + + response = requests.get(url) + response.raise_for_status() + save_bytes(cache_file, response.content) + df = pd.read_csv(cache_path(cache_file)) + + # Validate EITC data consistency (check if IRS fixed the A59664 issue) + us_data = df[(df["STATE"] == "US") & (df["agi_stub"] == 0)] + if not us_data.empty and all( + col in us_data.columns + for col in ["A59660", "A59661", "A59662", "A59663", "A59664"] + ): + total_eitc = us_data["A59660"].values[0] + sum_as_thousands = ( + us_data["A59661"].values[0] + + us_data["A59662"].values[0] + + us_data["A59663"].values[0] + + us_data["A59664"].values[0] + ) + sum_mixed = ( + us_data["A59661"].values[0] + + us_data["A59662"].values[0] + + us_data["A59663"].values[0] + + us_data["A59664"].values[0] / 1000 + ) - url = "https://www.irs.gov/pub/irs-soi/22incd.csv" - logger.info(f"Downloading IRS SOI data from {url}") - response = requests.get(url) - response.raise_for_status() - save_bytes(cache_file, response.content) - return pd.read_csv(cache_path(cache_file)) + # Check which interpretation matches the total + if abs(total_eitc - sum_as_thousands) < 100: # Within 100K (thousands) + print("=" * 60) + print("ALERT: IRS may have fixed the A59664 data inconsistency!") + print(f"Total EITC (A59660): {total_eitc:,.0f}") + print(f"Sum treating A59664 as thousands: {sum_as_thousands:,.0f}") + print("These now match! 
Please verify and update the code.") + print("=" * 60) + elif abs(total_eitc - sum_mixed) < 100: + print( + "Note: A59664 still has the units inconsistency (in dollars, not thousands)" + ) + + return df def transform_soi_data(raw_df): @@ -182,14 +249,20 @@ def transform_soi_data(raw_df): dict(code="59661", name="eitc", breakdown=("eitc_child_count", 0)), dict(code="59662", name="eitc", breakdown=("eitc_child_count", 1)), dict(code="59663", name="eitc", breakdown=("eitc_child_count", 2)), - dict(code="59664", name="eitc", breakdown=("eitc_child_count", "3+")), + dict( + code="59664", name="eitc", breakdown=("eitc_child_count", "3+") + ), # Doc says "three" but data shows this is 3+ dict( code="04475", name="qualified_business_income_deduction", breakdown=None, ), + dict(code="00900", name="self_employment_income", breakdown=None), + dict( + code="01000", name="net_capital_gains", breakdown=None + ), # Not to be confused with the always positive net_capital_gain dict(code="18500", name="real_estate_taxes", breakdown=None), - dict(code="01000", name="net_capital_gain", breakdown=None), + dict(code="25870", name="rental_income", breakdown=None), dict(code="01400", name="taxable_ira_distributions", breakdown=None), dict(code="00300", name="taxable_interest_income", breakdown=None), dict(code="00400", name="tax_exempt_interest_income", breakdown=None), @@ -207,6 +280,7 @@ def transform_soi_data(raw_df): dict(code="11070", name="refundable_ctc", breakdown=None), dict(code="18425", name="salt", breakdown=None), dict(code="06500", name="income_tax", breakdown=None), + dict(code="05800", name="income_tax_before_credits", breakdown=None), ] # National --------------- @@ -306,20 +380,6 @@ def transform_soi_data(raw_df): return converted -def _lookup_geo_stratum(session, ucgid_str, geo_map): - """Look up a geographic stratum by ucgid string.""" - info = parse_ucgid(ucgid_str) - if info["type"] == "national": - sid = geo_map["national"] - elif info["type"] == "state": - sid = geo_map["state"].get(info["state_fips"]) - elif info["type"] == "district": - sid = geo_map["district"].get(info["congressional_district_geoid"]) - else: - return None - return get_stratum_by_id(session, sid) if sid else None - - def load_soi_data(long_dfs, year): """Load a list of databases into the db, critically dependent on order""" @@ -330,6 +390,7 @@ def load_soi_data(long_dfs, year): session = Session(engine) + # Get or create the IRS SOI source irs_source = get_or_create_source( session, name="IRS Statistics of Income", @@ -340,7 +401,282 @@ def load_soi_data(long_dfs, year): notes="Tax return data by congressional district, state, and national levels", ) - geo_map = get_geographic_strata(session) + # Create variable groups + agi_group = get_or_create_variable_group( + session, + name="agi_distribution", + category="income", + is_histogram=True, + is_exclusive=True, + aggregation_method="sum", + display_order=4, + description="Adjusted Gross Income distribution by IRS income stubs", + ) + + eitc_group = get_or_create_variable_group( + session, + name="eitc_recipients", + category="tax", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=5, + description="Earned Income Tax Credit by number of qualifying children", + ) + + ctc_group = get_or_create_variable_group( + session, + name="ctc_recipients", + category="tax", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=6, + description="Child Tax Credit recipients and amounts", + ) + + 
income_components_group = get_or_create_variable_group( + session, + name="income_components", + category="income", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=7, + description="Components of income (interest, dividends, capital gains, etc.)", + ) + + deductions_group = get_or_create_variable_group( + session, + name="tax_deductions", + category="tax", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=8, + description="Tax deductions (SALT, medical, real estate, etc.)", + ) + + # Create variable metadata + # EITC - both amount and count use same variable with different constraints + get_or_create_variable_metadata( + session, + variable="eitc", + group=eitc_group, + display_name="EITC Amount", + display_order=1, + units="dollars", + notes="EITC amounts by number of qualifying children", + ) + + # For counts, tax_unit_count is used with appropriate constraints + get_or_create_variable_metadata( + session, + variable="tax_unit_count", + group=None, # This spans multiple groups based on constraints + display_name="Tax Unit Count", + display_order=100, + units="count", + notes="Number of tax units - meaning depends on stratum constraints", + ) + + # CTC + get_or_create_variable_metadata( + session, + variable="refundable_ctc", + group=ctc_group, + display_name="Refundable CTC", + display_order=1, + units="dollars", + ) + + # AGI and related + get_or_create_variable_metadata( + session, + variable="adjusted_gross_income", + group=agi_group, + display_name="Adjusted Gross Income", + display_order=1, + units="dollars", + ) + + get_or_create_variable_metadata( + session, + variable="person_count", + group=agi_group, + display_name="Person Count", + display_order=3, + units="count", + notes="Number of people in tax units by AGI bracket", + ) + + # Income components + income_vars = [ + ("taxable_interest_income", "Taxable Interest", 1), + ("tax_exempt_interest_income", "Tax-Exempt Interest", 2), + ("dividend_income", "Ordinary Dividends", 3), + ("qualified_dividend_income", "Qualified Dividends", 4), + ("net_capital_gain", "Net Capital Gain", 5), + ("taxable_ira_distributions", "Taxable IRA Distributions", 6), + ("taxable_pension_income", "Taxable Pensions", 7), + ("taxable_social_security", "Taxable Social Security", 8), + ("unemployment_compensation", "Unemployment Compensation", 9), + ( + "tax_unit_partnership_s_corp_income", + "Partnership/S-Corp Income", + 10, + ), + ] + + for var_name, display_name, order in income_vars: + get_or_create_variable_metadata( + session, + variable=var_name, + group=income_components_group, + display_name=display_name, + display_order=order, + units="dollars", + ) + + # Deductions + deduction_vars = [ + ("salt", "State and Local Taxes", 1), + ("real_estate_taxes", "Real Estate Taxes", 2), + ("medical_expense_deduction", "Medical Expenses", 3), + ("qualified_business_income_deduction", "QBI Deduction", 4), + ] + + for var_name, display_name, order in deduction_vars: + get_or_create_variable_metadata( + session, + variable=var_name, + group=deductions_group, + display_name=display_name, + display_order=order, + units="dollars", + ) + + # Income tax + get_or_create_variable_metadata( + session, + variable="income_tax", + group=None, # Could create a tax_liability group if needed + display_name="Income Tax", + display_order=1, + units="dollars", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + + # Create filer strata as intermediate layer between 
geographic and IRS-specific strata + # All IRS data represents only tax filers, not the entire population + filer_strata = {"national": None, "state": {}, "district": {}} + + # National filer stratum - check if it exists first + national_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == geo_strata["national"], + Stratum.notes == "United States - Tax Filers", + ) + .first() + ) + + if not national_filer_stratum: + national_filer_stratum = Stratum( + parent_stratum_id=geo_strata["national"], + stratum_group_id=2, # Filer population group + notes="United States - Tax Filers", + ) + national_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + session.add(national_filer_stratum) + session.flush() + + filer_strata["national"] = national_filer_stratum.stratum_id + + # State filer strata + for state_fips, state_geo_stratum_id in geo_strata["state"].items(): + # Check if state filer stratum exists + state_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == state_geo_stratum_id, + Stratum.notes == f"State FIPS {state_fips} - Tax Filers", + ) + .first() + ) + + if not state_filer_stratum: + state_filer_stratum = Stratum( + parent_stratum_id=state_geo_stratum_id, + stratum_group_id=2, # Filer population group + notes=f"State FIPS {state_fips} - Tax Filers", + ) + state_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(state_fips), + ), + ] + session.add(state_filer_stratum) + session.flush() + + filer_strata["state"][state_fips] = state_filer_stratum.stratum_id + + # District filer strata + for district_geoid, district_geo_stratum_id in geo_strata[ + "district" + ].items(): + # Check if district filer stratum exists + district_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == district_geo_stratum_id, + Stratum.notes + == f"Congressional District {district_geoid} - Tax Filers", + ) + .first() + ) + + if not district_filer_stratum: + district_filer_stratum = Stratum( + parent_stratum_id=district_geo_stratum_id, + stratum_group_id=2, # Filer population group + notes=f"Congressional District {district_geoid} - Tax Filers", + ) + district_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(district_geoid), + ), + ] + session.add(district_filer_stratum) + session.flush() + + filer_strata["district"][ + district_geoid + ] = district_filer_stratum.stratum_id + + session.commit() # Load EITC data -------------------------------------------------------- eitc_data = { @@ -350,73 +686,147 @@ def load_soi_data(long_dfs, year): "3+": (long_dfs[6], long_dfs[7]), } - stratum_lookup = {"State": {}, "District": {}} + eitc_stratum_lookup = {"national": {}, "state": {}, "district": {}} for n_children in eitc_data.keys(): eitc_count_i, eitc_amount_i = eitc_data[n_children] for i in range(eitc_count_i.shape[0]): ucgid_i = eitc_count_i[["ucgid_str"]].iloc[i].values[0] - note = f"Geo: {ucgid_i}, EITC received with {n_children} children" + geo_info = parse_ucgid(ucgid_i) - if len(ucgid_i) == 9: # National. 
- new_stratum = Stratum( - parent_stratum_id=None, stratum_group_id=0, notes=note - ) - elif len(ucgid_i) == 11: # State - new_stratum = Stratum( - parent_stratum_id=stratum_lookup["National"], - stratum_group_id=0, - notes=note, + # Determine parent stratum based on geographic level - use filer strata not geo strata + if geo_info["type"] == "national": + parent_stratum_id = filer_strata["national"] + note = f"National EITC received with {n_children} children (filers)" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + elif geo_info["type"] == "state": + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] + note = f"State FIPS {geo_info['state_fips']} EITC received with {n_children} children (filers)" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ), + ] + elif geo_info["type"] == "district": + parent_stratum_id = filer_strata["district"][ + geo_info["congressional_district_geoid"] + ] + note = f"Congressional District {geo_info['congressional_district_geoid']} EITC received with {n_children} children (filers)" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(geo_info["congressional_district_geoid"]), + ), + ] + + # Check if stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == 6, + Stratum.notes == note, ) - elif len(ucgid_i) == 13: # District + .first() + ) + + if existing_stratum: + new_stratum = existing_stratum + else: new_stratum = Stratum( - parent_stratum_id=stratum_lookup["State"][ - "0400000US" + ucgid_i[9:11] - ], - stratum_group_id=0, + parent_stratum_id=parent_stratum_id, + stratum_group_id=6, # EITC strata group notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=ucgid_i, - ), - ] - if n_children == "3+": - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="eitc_child_count", - operation="greater_than", - value="2", + new_stratum.constraints_rel = constraints + if n_children == "3+": + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="eitc_child_count", + operation=">", + value="2", + ) ) - ) - else: - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="eitc_child_count", - operation="equals", - value=f"{n_children}", + else: + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="eitc_child_count", + operation="==", + value=f"{n_children}", + ) ) - ) - new_stratum.targets_rel = [ - Target( - variable="eitc", - period=year, - value=eitc_amount_i.iloc[i][["target_value"]].values[0], - source_id=irs_source.source_id, - active=True, + session.add(new_stratum) + session.flush() + + # Get both count and amount values + count_value = eitc_count_i.iloc[i][["target_value"]].values[0] + amount_value = eitc_amount_i.iloc[i][["target_value"]].values[0] + + # Check if targets already exist and update or create them + for variable, value in [ + ("tax_unit_count", count_value), + ("eitc", amount_value), + ]: + existing_target = ( + session.query(Target) 
+ .filter( + Target.stratum_id == new_stratum.stratum_id, + Target.variable == variable, + Target.period == year, + ) + .first() ) - ] + + if existing_target: + existing_target.value = value + existing_target.source_id = irs_source.source_id + else: + new_stratum.targets_rel.append( + Target( + variable=variable, + period=year, + value=value, + source_id=irs_source.source_id, + active=True, + ) + ) session.add(new_stratum) session.flush() - if len(ucgid_i) == 9: - stratum_lookup["National"] = new_stratum.stratum_id - elif len(ucgid_i) == 11: - stratum_lookup["State"][ucgid_i] = new_stratum.stratum_id + # Store lookup for later use + if geo_info["type"] == "national": + eitc_stratum_lookup["national"][ + n_children + ] = new_stratum.stratum_id + elif geo_info["type"] == "state": + key = (geo_info["state_fips"], n_children) + eitc_stratum_lookup["state"][key] = new_stratum.stratum_id + elif geo_info["type"] == "district": + key = (geo_info["congressional_district_geoid"], n_children) + eitc_stratum_lookup["district"][key] = new_stratum.stratum_id session.commit() @@ -428,30 +838,140 @@ def load_soi_data(long_dfs, year): == "adjusted_gross_income" and long_dfs[i][["breakdown_variable"]].values[0] == "one" ][0] + # IRS variables start at stratum_group_id 100 + irs_group_id_start = 100 + for j in range(8, first_agi_index, 2): count_j, amount_j = long_dfs[j], long_dfs[j + 1] + count_variable_name = count_j.iloc[0][["target_variable"]].values[ + 0 + ] # Should be tax_unit_count amount_variable_name = amount_j.iloc[0][["target_variable"]].values[0] + + # Assign a unique stratum_group_id for this IRS variable + stratum_group_id = irs_group_id_start + (j - 8) // 2 + print( - f"Loading amount data for IRS SOI data on {amount_variable_name}" + f"Loading count and amount data for IRS SOI data on {amount_variable_name} (group_id={stratum_group_id})" ) + for i in range(count_j.shape[0]): ucgid_i = count_j[["ucgid_str"]].iloc[i].values[0] + geo_info = parse_ucgid(ucgid_i) + + # Get parent filer stratum (not geographic stratum) + if geo_info["type"] == "national": + parent_stratum_id = filer_strata["national"] + geo_description = "National" + elif geo_info["type"] == "state": + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] + geo_description = f"State {geo_info['state_fips']}" + elif geo_info["type"] == "district": + parent_stratum_id = filer_strata["district"][ + geo_info["congressional_district_geoid"] + ] + geo_description = ( + f"CD {geo_info['congressional_district_geoid']}" + ) - # Reusing an existing geographic stratum - stratum = _lookup_geo_stratum(session, ucgid_i, geo_map) - amount_value = amount_j.iloc[i][["target_value"]].values[0] + # Create child stratum with constraint for this IRS variable + # Note: This stratum will have the constraint that amount_variable > 0 + note = f"{geo_description} filers with {amount_variable_name} > 0" - stratum.targets_rel.append( - Target( - variable=amount_variable_name, - period=year, - value=amount_value, - source_id=irs_source.source_id, - active=True, + # Check if child stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == stratum_group_id, ) + .first() ) - session.add(stratum) + if existing_stratum: + child_stratum = existing_stratum + else: + # Create new child stratum with constraint + child_stratum = Stratum( + parent_stratum_id=parent_stratum_id, + stratum_group_id=stratum_group_id, + notes=note, + ) + + # Add 
constraints - filer status and this IRS variable must be positive + child_stratum.constraints_rel.extend( + [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable=amount_variable_name, + operation=">", + value="0", + ), + ] + ) + + # Add geographic constraints if applicable + if geo_info["type"] == "state": + child_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ) + ) + elif geo_info["type"] == "district": + child_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str( + geo_info["congressional_district_geoid"] + ), + ) + ) + + session.add(child_stratum) + session.flush() + + count_value = count_j.iloc[i][["target_value"]].values[0] + amount_value = amount_j.iloc[i][["target_value"]].values[0] + + # Check if targets already exist and update or create them + for variable, value in [ + (count_variable_name, count_value), + (amount_variable_name, amount_value), + ]: + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == child_stratum.stratum_id, + Target.variable == variable, + Target.period == year, + ) + .first() + ) + + if existing_target: + existing_target.value = value + existing_target.source_id = irs_source.source_id + else: + child_stratum.targets_rel.append( + Target( + variable=variable, + period=year, + value=value, + source_id=irs_source.source_id, + active=True, + ) + ) + + session.add(child_stratum) session.flush() session.commit() @@ -462,16 +982,49 @@ def load_soi_data(long_dfs, year): for i in range(agi_values.shape[0]): ucgid_i = agi_values[["ucgid_str"]].iloc[i].values[0] - stratum = _lookup_geo_stratum(session, ucgid_i, geo_map) - stratum.targets_rel.append( - Target( - variable="adjusted_gross_income", - period=year, - value=agi_values.iloc[i][["target_value"]].values[0], - source_id=5, - active=True, + geo_info = parse_ucgid(ucgid_i) + + # Add target to existing FILER stratum (not geographic stratum) + if geo_info["type"] == "national": + stratum = session.get(Stratum, filer_strata["national"]) + elif geo_info["type"] == "state": + stratum = session.get( + Stratum, filer_strata["state"][geo_info["state_fips"]] + ) + elif geo_info["type"] == "district": + stratum = session.get( + Stratum, + filer_strata["district"][ + geo_info["congressional_district_geoid"] + ], ) + + # Check if target already exists + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == stratum.stratum_id, + Target.variable == "adjusted_gross_income", + Target.period == year, + ) + .first() ) + + if existing_target: + existing_target.value = agi_values.iloc[i][ + ["target_value"] + ].values[0] + existing_target.source_id = irs_source.source_id + else: + stratum.targets_rel.append( + Target( + variable="adjusted_gross_income", + period=year, + value=agi_values.iloc[i][["target_value"]].values[0], + source_id=irs_source.source_id, + active=True, + ) + ) session.add(stratum) session.flush() @@ -488,93 +1041,167 @@ def load_soi_data(long_dfs, year): agi_income_lower, agi_income_upper = AGI_STUB_TO_INCOME_RANGE[agi_stub] # Make a National Stratum for each AGI Stub even w/o associated national target - note = f"Geo: 0100000US, AGI > {agi_income_lower}, AGI < {agi_income_upper}" - nat_stratum = Stratum( - parent_stratum_id=None, stratum_group_id=0, notes=note - ) - 
nat_stratum.constraints_rel.extend( - [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value="0100000US", - ), - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation="greater_than", - value=str(agi_income_lower), - ), - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation="less_than", - value=str(agi_income_upper), - ), - ] + note = f"National filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" + + # Check if national AGI stratum already exists + nat_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == filer_strata["national"], + Stratum.stratum_group_id == 3, + Stratum.notes == note, + ) + .first() ) - session.add(nat_stratum) - session.flush() - - stratum_lookup = { - "National": nat_stratum.stratum_id, - "State": {}, - "District": {}, - } - for i in range(agi_df.shape[0]): - ucgid_i = agi_df[["ucgid_str"]].iloc[i].values[0] - note = f"Geo: {ucgid_i}, AGI > {agi_income_lower}, AGI < {agi_income_upper}" - - person_count = agi_df.iloc[i][["target_value"]].values[0] - if len(ucgid_i) == 11: # State - new_stratum = Stratum( - parent_stratum_id=stratum_lookup["National"], - stratum_group_id=0, - notes=note, - ) - elif len(ucgid_i) == 13: # District - new_stratum = Stratum( - parent_stratum_id=stratum_lookup["State"][ - "0400000US" + ucgid_i[9:11] - ], - stratum_group_id=0, - notes=note, - ) - new_stratum.constraints_rel.extend( + if not nat_stratum: + nat_stratum = Stratum( + parent_stratum_id=filer_strata["national"], + stratum_group_id=3, # Income/AGI strata group + notes=note, + ) + nat_stratum.constraints_rel.extend( [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=ucgid_i, + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", ), StratumConstraint( constraint_variable="adjusted_gross_income", - operation="greater_than", + operation=">=", value=str(agi_income_lower), ), StratumConstraint( constraint_variable="adjusted_gross_income", - operation="less_than", + operation="<", value=str(agi_income_upper), ), ] ) - new_stratum.targets_rel.append( - Target( - variable="person_count", - period=year, - value=person_count, - source_id=irs_source.source_id, - active=True, + session.add(nat_stratum) + session.flush() + + agi_stratum_lookup = { + "national": nat_stratum.stratum_id, + "state": {}, + "district": {}, + } + for i in range(agi_df.shape[0]): + ucgid_i = agi_df[["ucgid_str"]].iloc[i].values[0] + geo_info = parse_ucgid(ucgid_i) + person_count = agi_df.iloc[i][["target_value"]].values[0] + + if geo_info["type"] == "state": + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] + note = f"State FIPS {geo_info['state_fips']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ), + ] + elif geo_info["type"] == "district": + parent_stratum_id = filer_strata["district"][ + geo_info["congressional_district_geoid"] + ] + note = f"Congressional District {geo_info['congressional_district_geoid']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + 
value=str(geo_info["congressional_district_geoid"]), + ), + ] + else: + continue # Skip if not state or district (shouldn't happen, but defensive) + + # Check if stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == 3, + Stratum.notes == note, + ) + .first() + ) + + if existing_stratum: + new_stratum = existing_stratum + else: + new_stratum = Stratum( + parent_stratum_id=parent_stratum_id, + stratum_group_id=3, # Income/AGI strata group + notes=note, + ) + new_stratum.constraints_rel = constraints + new_stratum.constraints_rel.extend( + [ + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation=">=", + value=str(agi_income_lower), + ), + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation="<", + value=str(agi_income_upper), + ), + ] ) + session.add(new_stratum) + session.flush() + + # Check if target already exists and update or create it + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == new_stratum.stratum_id, + Target.variable == "person_count", + Target.period == year, + ) + .first() ) + if existing_target: + existing_target.value = person_count + existing_target.source_id = irs_source.source_id + else: + new_stratum.targets_rel.append( + Target( + variable="person_count", + period=year, + value=person_count, + source_id=irs_source.source_id, + active=True, + ) + ) + session.add(new_stratum) session.flush() - if len(ucgid_i) == 9: - stratum_lookup["National"] = new_stratum.stratum_id - elif len(ucgid_i) == 11: - stratum_lookup["State"][ucgid_i] = new_stratum.stratum_id + if geo_info["type"] == "state": + agi_stratum_lookup["state"][ + geo_info["state_fips"] + ] = new_stratum.stratum_id + elif geo_info["type"] == "district": + agi_stratum_lookup["district"][ + geo_info["congressional_district_geoid"] + ] = new_stratum.stratum_id session.commit() diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 67bf8db56..2038802bd 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -1,19 +1,28 @@ import logging -import requests +import requests import pandas as pd -from sqlmodel import Session, create_engine +import numpy as np +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( - SourceType, Stratum, StratumConstraint, Target, + SourceType, +) +from policyengine_us_data.utils.census import ( + STATE_ABBREV_TO_FIPS, + pull_acs_table, +) +from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, ) -from policyengine_us_data.utils.census import STATE_ABBREV_TO_FIPS -from policyengine_us_data.utils.db_metadata import get_or_create_source from policyengine_us_data.utils.raw_cache import ( is_cached, cache_path, @@ -25,46 +34,50 @@ logger = logging.getLogger(__name__) -def extract_medicaid_data(year): - # Census ACS survey data - census_cache = f"acs_S2704_district_{year}.json" - if is_cached(census_cache): - logger.info(f"Using cached {census_cache}") - data = load_json(census_cache) - else: - base_url = f"https://api.census.gov/data/{year}/acs/acs1/subject?get=group(S2704)" - url = f"{base_url}&for=congressional+district:*" - 
logger.info(f"Downloading ACS S2704 for {year}") - response = requests.get(url) - response.raise_for_status() - data = response.json() - save_json(census_cache, data) - - headers = data[0] - data_rows = data[1:] - cd_survey_df = pd.DataFrame(data_rows, columns=headers) - - # CMS Medicaid administrative data +def extract_administrative_medicaid_data(year): cms_cache = f"medicaid_enrollment_{year}.csv" if is_cached(cms_cache): logger.info(f"Using cached {cms_cache}") - state_admin_df = pd.read_csv(cache_path(cms_cache)) - else: - item = "6165f45b-ca93-5bb5-9d06-db29c692a360" - logger.info("Downloading Medicaid enrollment from CMS") - response = requests.get( - f"https://data.medicaid.gov/api/1/metastore/schemas/dataset/items/{item}?show-reference-ids=false" - ) - metadata = response.json() - data_url = metadata["distribution"][0]["data"]["downloadURL"] - state_admin_df = pd.read_csv(data_url) - state_admin_df.to_csv(cache_path(cms_cache), index=False) + return pd.read_csv(cache_path(cms_cache)) - return cd_survey_df, state_admin_df + item = "6165f45b-ca93-5bb5-9d06-db29c692a360" + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Accept": "application/json", + "Accept-Language": "en-US,en;q=0.5", + } -def transform_medicaid_data(state_admin_df, cd_survey_df, year): + session = requests.Session() + session.headers.update(headers) + metadata_url = f"https://data.medicaid.gov/api/1/metastore/schemas/dataset/items/{item}?show-reference-ids=false" + print(f"Attempting to fetch Medicaid metadata from: {metadata_url}") + + response = session.get(metadata_url, timeout=30) + response.raise_for_status() + + metadata = response.json() + + if "distribution" not in metadata or len(metadata["distribution"]) == 0: + raise ValueError(f"No distribution found in metadata for item {item}") + + data_url = metadata["distribution"][0]["data"]["downloadURL"] + print(f"Downloading Medicaid data from: {data_url}") + + state_admin_df = pd.read_csv(data_url) + state_admin_df.to_csv(cache_path(cms_cache), index=False) + print( + f"Successfully downloaded {len(state_admin_df)} rows of Medicaid administrative data" + ) + return state_admin_df + + +def extract_survey_medicaid_data(year): + return pull_acs_table("S2704", "District", year) + + +def transform_administrative_medicaid_data(state_admin_df, year): reporting_period = year * 100 + 12 print(f"Reporting period is {reporting_period}") state_df = state_admin_df.loc[ @@ -79,22 +92,19 @@ def transform_medicaid_data(state_admin_df, cd_survey_df, year): state_df["FIPS"] = state_df["State Abbreviation"].map(STATE_ABBREV_TO_FIPS) - cd_df = cd_survey_df[ - ["GEO_ID", "state", "congressional district", "S2704_C02_006E"] - ] - - nc_cd_sum = cd_df.loc[cd_df.state == "37"].S2704_C02_006E.astype(int).sum() - nc_state_sum = state_df.loc[state_df.FIPS == "37"][ - "Total Medicaid Enrollment" - ].values[0] - assert nc_cd_sum > 0.5 * nc_state_sum - assert nc_cd_sum <= nc_state_sum - state_df = state_df.rename( columns={"Total Medicaid Enrollment": "medicaid_enrollment"} ) state_df["ucgid_str"] = "0400000US" + state_df["FIPS"].astype(str) + return state_df[["ucgid_str", "medicaid_enrollment"]] + + +def transform_survey_medicaid_data(cd_survey_df): + cd_df = cd_survey_df[ + ["GEO_ID", "state", "congressional district", "S2704_C02_006E"] + ] + cd_df = cd_df.rename( columns={ "S2704_C02_006E": "medicaid_enrollment", @@ -103,8 +113,7 @@ def transform_medicaid_data(state_admin_df, 
cd_survey_df, year): ) cd_df = cd_df.loc[cd_df.state != "72"] - out_cols = ["ucgid_str", "medicaid_enrollment"] - return state_df[out_cols], cd_df[out_cols] + return cd_df[["ucgid_str", "medicaid_enrollment"]] def load_medicaid_data(long_state, long_cd, year): @@ -114,9 +123,8 @@ def load_medicaid_data(long_state, long_cd, year): ) engine = create_engine(DATABASE_URL) - stratum_lookup = {} - with Session(engine) as session: + # Get or create sources admin_source = get_or_create_source( session, name="Medicaid T-MSIS", @@ -126,6 +134,7 @@ def load_medicaid_data(long_state, long_cd, year): url="https://data.medicaid.gov/", notes="State-level Medicaid enrollment from administrative records", ) + survey_source = get_or_create_source( session, name="Census ACS Table S2704", @@ -136,21 +145,45 @@ def load_medicaid_data(long_state, long_cd, year): notes="Congressional district level Medicaid coverage from ACS", ) + # Get or create Medicaid variable group + medicaid_group = get_or_create_variable_group( + session, + name="medicaid_recipients", + category="benefit", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=3, + description="Medicaid enrollment and spending", + ) + + # Create variable metadata + # Note: The actual target variable used is "person_count" with medicaid_enrolled==True constraint + # This metadata entry is kept for consistency with the actual variable being used + get_or_create_variable_metadata( + session, + variable="person_count", + group=medicaid_group, + display_name="Medicaid Enrollment", + display_order=1, + units="count", + notes="Number of people enrolled in Medicaid (person_count with medicaid_enrolled==True)", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + # National ---------------- + # Create a Medicaid stratum as child of the national geographic stratum nat_stratum = Stratum( - parent_stratum_id=None, - stratum_group_id=0, - notes="Geo: 0100000US Medicaid Enrolled", + parent_stratum_id=geo_strata["national"], + stratum_group_id=5, # Medicaid strata group + notes="National Medicaid Enrolled", ) nat_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value="0100000US", - ), StratumConstraint( constraint_variable="medicaid_enrolled", - operation="equals", + operation="==", value="True", ), ] @@ -158,29 +191,36 @@ def load_medicaid_data(long_state, long_cd, year): session.add(nat_stratum) session.flush() - stratum_lookup["National"] = nat_stratum.stratum_id + medicaid_stratum_lookup = { + "national": nat_stratum.stratum_id, + "state": {}, + } # State ------------------- - stratum_lookup["State"] = {} for _, row in long_state.iterrows(): + # Parse the UCGID to get state_fips + geo_info = parse_ucgid(row["ucgid_str"]) + state_fips = geo_info["state_fips"] + + # Get the parent geographic stratum + parent_stratum_id = geo_strata["state"][state_fips] - note = f"Geo: {row['ucgid_str']} Medicaid Enrolled" - parent_stratum_id = nat_stratum.stratum_id + note = f"State FIPS {state_fips} Medicaid Enrolled" new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=5, # Medicaid strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="state_fips", + operation="==", + value=str(state_fips), ), StratumConstraint( constraint_variable="medicaid_enrolled", - operation="equals", + 
operation="==", value="True", ), ] @@ -195,30 +235,39 @@ def load_medicaid_data(long_state, long_cd, year): ) session.add(new_stratum) session.flush() - stratum_lookup["State"][row["ucgid_str"]] = new_stratum.stratum_id + medicaid_stratum_lookup["state"][ + state_fips + ] = new_stratum.stratum_id # District ------------------- + if long_cd is None: + session.commit() + return + for _, row in long_cd.iterrows(): + # Parse the UCGID to get district info + geo_info = parse_ucgid(row["ucgid_str"]) + cd_geoid = geo_info["congressional_district_geoid"] - note = f"Geo: {row['ucgid_str']} Medicaid Enrolled" - parent_stratum_id = stratum_lookup["State"][ - f'0400000US{row["ucgid_str"][-4:-2]}' - ] + # Get the parent geographic stratum + parent_stratum_id = geo_strata["district"][cd_geoid] + + note = f"Congressional District {cd_geoid} Medicaid Enrolled" new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=5, # Medicaid strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), ), StratumConstraint( constraint_variable="medicaid_enrolled", - operation="equals", + operation="==", value="True", ), ] @@ -237,17 +286,37 @@ def load_medicaid_data(long_state, long_cd, year): session.commit() -if __name__ == "__main__": - - year = 2023 +def main(): + year = 2024 # Extract ------------------------------ - cd_survey_df, state_admin_df = extract_medicaid_data(year) + state_admin_df = extract_administrative_medicaid_data(year) + + # TODO: Re-enable CD survey Medicaid targets once we handle the 119th + # Congress district codes (5001900US) vs 118th Congress (5001800US) + # mismatch. The 2024 ACS uses 119th Congress GEO_IDs but the DB + # geographic strata use 118th Congress codes. Need a remapping step. + # When re-enabling, also restore the NC validation assert below. 
+ # + # cd_survey_df = extract_survey_medicaid_data(year) + # long_cd = transform_survey_medicaid_data(cd_survey_df) + # nc_cd_sum = ( + # long_cd.loc[long_cd.ucgid_str.str.contains("5001800US37")] + # .medicaid_enrollment.astype(int) + # .sum() + # ) + # nc_state_sum = long_state.loc[long_state.ucgid_str == "0400000US37"][ + # "medicaid_enrollment" + # ].values[0] + # assert nc_cd_sum > 0.5 * nc_state_sum + # assert nc_cd_sum <= nc_state_sum # Transform ------------------- - long_state, long_cd = transform_medicaid_data( - state_admin_df, cd_survey_df, year - ) + long_state = transform_administrative_medicaid_data(state_admin_df, year) - # Load ----------------------- - load_medicaid_data(long_state, long_cd, year) + # Load (state admin only, no CD survey) --- + load_medicaid_data(long_state, long_cd=None, year=year) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index 0a03add3d..5cb910d5b 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ -22,22 +22,24 @@ def extract_national_targets(): dict Dictionary containing: - direct_sum_targets: Variables that can be summed directly - - tax_filer_targets: Tax-related variables requiring filer - constraint - - conditional_count_targets: Enrollment counts requiring - constraints + - tax_filer_targets: Tax-related variables requiring filer constraint + - conditional_count_targets: Enrollment counts requiring constraints - cbo_targets: List of CBO projection targets - treasury_targets: List of Treasury/JCT targets """ + # Initialize PolicyEngine for parameter access from policyengine_us import Microsimulation sim = Microsimulation( dataset="hf://policyengine/policyengine-us-data/cps_2023.h5" ) + # Direct sum targets - these are regular variables that can be summed + # Store with their actual source year (2024 for hardcoded values from loss.py) HARDCODED_YEAR = 2024 + # Separate tax-related targets that need filer constraint tax_filer_targets = [ { "variable": "salt_deduction", @@ -81,24 +83,21 @@ def extract_national_targets(): "variable": "alimony_income", "value": 13e9, "source": "Survey-reported (post-TCJA grandfathered)", - "notes": "Alimony received - survey reported, " - "not tax-filer restricted", + "notes": "Alimony received - survey reported, not tax-filer restricted", "year": HARDCODED_YEAR, }, { "variable": "alimony_expense", "value": 13e9, "source": "Survey-reported (post-TCJA grandfathered)", - "notes": "Alimony paid - survey reported, " - "not tax-filer restricted", + "notes": "Alimony paid - survey reported, not tax-filer restricted", "year": HARDCODED_YEAR, }, { "variable": "medicaid", "value": 871.7e9, - "source": "https://www.cms.gov/files/document/" "highlights.pdf", - "notes": "CMS 2023 highlights document - " - "total Medicaid spending", + "source": "https://www.cms.gov/files/document/highlights.pdf", + "notes": "CMS 2023 highlights document - total Medicaid spending", "year": HARDCODED_YEAR, }, { @@ -109,10 +108,10 @@ def extract_national_targets(): "year": HARDCODED_YEAR, }, { - "variable": "health_insurance_premiums_without_" "medicare_part_b", + "variable": "health_insurance_premiums_without_medicare_part_b", "value": 385e9, "source": "MEPS/NHEA", - "notes": "Health insurance premiums excluding " "Medicare Part B", + "notes": "Health insurance premiums excluding Medicare Part B", "year": HARDCODED_YEAR, }, { @@ -189,16 +188,17 @@ def 
extract_national_targets(): "variable": "tip_income", "value": 53.2e9, "source": "IRS Form W-2 Box 7 statistics", - "notes": "Social security tips uprated 40% to account " - "for underreporting", + "notes": "Social security tips uprated 40% to account for underreporting", "year": HARDCODED_YEAR, }, ] + # Conditional count targets - these need strata with constraints + # Store with actual source year conditional_count_targets = [ { "constraint_variable": "medicaid", - "stratum_group_id": 5, + "stratum_group_id": 5, # Medicaid strata group "person_count": 72_429_055, "source": "CMS/HHS administrative data", "notes": "Medicaid enrollment count", @@ -206,7 +206,7 @@ def extract_national_targets(): }, { "constraint_variable": "aca_ptc", - "stratum_group_id": None, + "stratum_group_id": None, # Will use a generic stratum or create new group "person_count": 19_743_689, "source": "CMS marketplace data", "notes": "ACA Premium Tax Credit recipients", @@ -214,14 +214,16 @@ def extract_national_targets(): }, ] + # Add SSN card type NONE targets for multiple years + # Based on loss.py lines 445-460 ssn_none_targets_by_year = [ { "constraint_variable": "ssn_card_type", - "constraint_value": "NONE", - "stratum_group_id": 7, + "constraint_value": "NONE", # Need to specify the value we're checking for + "stratum_group_id": 7, # New group for SSN card type "person_count": 11.0e6, "source": "DHS Office of Homeland Security Statistics", - "notes": "Undocumented population estimate " "for Jan 1, 2022", + "notes": "Undocumented population estimate for Jan 1, 2022", "year": 2022, }, { @@ -229,10 +231,8 @@ def extract_national_targets(): "constraint_value": "NONE", "stratum_group_id": 7, "person_count": 12.2e6, - "source": "Center for Migration Studies " - "ACS-based residual estimate", - "notes": "Undocumented population estimate " - "(published May 2025)", + "source": "Center for Migration Studies ACS-based residual estimate", + "notes": "Undocumented population estimate (published May 2025)", "year": 2023, }, { @@ -241,8 +241,7 @@ def extract_national_targets(): "stratum_group_id": 7, "person_count": 13.0e6, "source": "Reuters synthesis of experts", - "notes": "Undocumented population central estimate " - "(~13-14 million)", + "notes": "Undocumented population central estimate (~13-14 million)", "year": 2024, }, { @@ -251,15 +250,15 @@ def extract_national_targets(): "stratum_group_id": 7, "person_count": 13.0e6, "source": "Reuters synthesis of experts", - "notes": "Same midpoint carried forward - " - "CBP data show 95% drop in border apprehensions", + "notes": "Same midpoint carried forward - CBP data show 95% drop in border apprehensions", "year": 2025, }, ] conditional_count_targets.extend(ssn_none_targets_by_year) - CBO_YEAR = 2023 + # CBO projection targets - get for a specific year + CBO_YEAR = 2023 # Year the CBO projections are for cbo_vars = [ "income_tax", "snap", @@ -285,10 +284,10 @@ def extract_national_targets(): ) except (KeyError, AttributeError) as e: print( - f"Warning: Could not extract CBO parameter " - f"for {variable_name}: {e}" + f"Warning: Could not extract CBO parameter for {variable_name}: {e}" ) + # Treasury/JCT targets (EITC) - get for a specific year TREASURY_YEAR = 2023 try: eitc_value = sim.tax_benefit_system.parameters.calibration.gov.treasury.tax_expenditures.eitc( @@ -304,7 +303,7 @@ def extract_national_targets(): } ] except (KeyError, AttributeError) as e: - print(f"Warning: Could not extract Treasury EITC " f"parameter: {e}") + print(f"Warning: Could not extract Treasury 
EITC parameter: {e}") treasury_targets = [] return { @@ -318,7 +317,7 @@ def extract_national_targets(): def transform_national_targets(raw_targets): """ - Transform extracted targets into standardized format. + Transform extracted targets into standardized format for loading. Parameters ---------- @@ -329,7 +328,13 @@ def transform_national_targets(raw_targets): ------- tuple (direct_targets_df, tax_filer_df, conditional_targets) + - direct_targets_df: DataFrame with direct sum targets + - tax_filer_df: DataFrame with tax-related targets needing filer constraint + - conditional_targets: List of conditional count targets """ + + # Process direct sum targets (non-tax items and some CBO items) + # Note: income_tax from CBO and eitc from Treasury need filer constraint cbo_non_tax = [ t for t in raw_targets["cbo_targets"] if t["variable"] != "income_tax" ] @@ -339,10 +344,11 @@ def transform_national_targets(raw_targets): all_direct_targets = raw_targets["direct_sum_targets"] + cbo_non_tax + # Tax-related targets that need filer constraint all_tax_filer_targets = ( raw_targets["tax_filer_targets"] + cbo_tax - + raw_targets["treasury_targets"] + + raw_targets["treasury_targets"] # EITC ) direct_df = ( @@ -356,6 +362,7 @@ def transform_national_targets(raw_targets): else pd.DataFrame() ) + # Conditional targets stay as list for special processing conditional_targets = raw_targets["conditional_count_targets"] return direct_df, tax_filer_df, conditional_targets @@ -378,37 +385,38 @@ def load_national_targets( """ DATABASE_URL = ( - f"sqlite:///" f"{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" ) engine = create_engine(DATABASE_URL) with Session(engine) as session: + # Get or create the calibration source calibration_source = get_or_create_source( session, name="PolicyEngine Calibration Targets", source_type=SourceType.HARDCODED, vintage="Mixed (2023-2024)", - description="National calibration targets from " - "various authoritative sources", + description="National calibration targets from various authoritative sources", url=None, - notes="Aggregated from CMS, IRS, CBO, Treasury, " - "and other federal sources", + notes="Aggregated from CMS, IRS, CBO, Treasury, and other federal sources", ) + # Get the national stratum us_stratum = ( session.query(Stratum) - .filter(Stratum.parent_stratum_id == None) # noqa: E711 + .filter(Stratum.parent_stratum_id == None) .first() ) if not us_stratum: raise ValueError( - "National stratum not found. " - "Run create_initial_strata.py first." + "National stratum not found. Run create_initial_strata.py first." 
) + # Process direct sum targets for _, target_data in direct_targets_df.iterrows(): target_year = target_data["year"] + # Check if target already exists existing_target = ( session.query(Target) .filter( @@ -419,6 +427,7 @@ def load_national_targets( .first() ) + # Combine source info into notes notes_parts = [] if pd.notna(target_data.get("notes")): notes_parts.append(target_data["notes"]) @@ -428,10 +437,12 @@ def load_national_targets( combined_notes = " | ".join(notes_parts) if existing_target: + # Update existing target existing_target.value = target_data["value"] existing_target.notes = combined_notes print(f"Updated target: {target_data['variable']}") else: + # Create new target target = Target( stratum_id=us_stratum.stratum_id, variable=target_data["variable"], @@ -444,7 +455,9 @@ def load_national_targets( session.add(target) print(f"Added target: {target_data['variable']}") + # Process tax-related targets that need filer constraint if not tax_filer_df.empty: + # Get or create the national filer stratum national_filer_stratum = ( session.query(Stratum) .filter( @@ -455,9 +468,10 @@ def load_national_targets( ) if not national_filer_stratum: + # Create national filer stratum national_filer_stratum = Stratum( parent_stratum_id=us_stratum.stratum_id, - stratum_group_id=2, + stratum_group_id=2, # Filer population group notes="United States - Tax Filers", ) national_filer_stratum.constraints_rel = [ @@ -471,8 +485,10 @@ def load_national_targets( session.flush() print("Created national filer stratum") + # Add tax-related targets to filer stratum for _, target_data in tax_filer_df.iterrows(): target_year = target_data["year"] + # Check if target already exists existing_target = ( session.query(Target) .filter( @@ -483,23 +499,24 @@ def load_national_targets( .first() ) + # Combine source info into notes notes_parts = [] if pd.notna(target_data.get("notes")): notes_parts.append(target_data["notes"]) notes_parts.append( - f"Source: " f"{target_data.get('source', 'Unknown')}" + f"Source: {target_data.get('source', 'Unknown')}" ) combined_notes = " | ".join(notes_parts) if existing_target: + # Update existing target existing_target.value = target_data["value"] existing_target.notes = combined_notes - print( - f"Updated filer target: " f"{target_data['variable']}" - ) + print(f"Updated filer target: {target_data['variable']}") else: + # Create new target target = Target( - stratum_id=(national_filer_stratum.stratum_id), + stratum_id=national_filer_stratum.stratum_id, variable=target_data["variable"], period=target_year, value=target_data["value"], @@ -508,25 +525,29 @@ def load_national_targets( notes=combined_notes, ) session.add(target) - print(f"Added filer target: " f"{target_data['variable']}") + print(f"Added filer target: {target_data['variable']}") + # Process conditional count targets (enrollment counts) for cond_target in conditional_targets: constraint_var = cond_target["constraint_variable"] stratum_group_id = cond_target.get("stratum_group_id") target_year = cond_target["year"] + # Determine stratum group ID and constraint details if constraint_var == "medicaid": - stratum_group_id = 5 + stratum_group_id = 5 # Medicaid strata group stratum_notes = "National Medicaid Enrollment" constraint_operation = ">" constraint_value = "0" elif constraint_var == "aca_ptc": - stratum_group_id = 6 + stratum_group_id = ( + 6 # EITC group or could create new ACA group + ) stratum_notes = "National ACA Premium Tax Credit Recipients" constraint_operation = ">" constraint_value = "0" elif 
constraint_var == "ssn_card_type": - stratum_group_id = 7 + stratum_group_id = 7 # SSN card type group stratum_notes = "National Undocumented Population" constraint_operation = "=" constraint_value = cond_target.get("constraint_value", "NONE") @@ -535,6 +556,7 @@ def load_national_targets( constraint_operation = ">" constraint_value = "0" + # Check if this stratum already exists existing_stratum = ( session.query(Stratum) .filter( @@ -546,6 +568,7 @@ def load_national_targets( ) if existing_stratum: + # Update the existing target in this stratum existing_target = ( session.query(Target) .filter( @@ -558,10 +581,9 @@ def load_national_targets( if existing_target: existing_target.value = cond_target["person_count"] - print( - f"Updated enrollment target " f"for {constraint_var}" - ) + print(f"Updated enrollment target for {constraint_var}") else: + # Add new target to existing stratum new_target = Target( stratum_id=existing_stratum.stratum_id, variable="person_count", @@ -569,20 +591,19 @@ def load_national_targets( value=cond_target["person_count"], source_id=calibration_source.source_id, active=True, - notes=( - f"{cond_target['notes']} | " - f"Source: {cond_target['source']}" - ), + notes=f"{cond_target['notes']} | Source: {cond_target['source']}", ) session.add(new_target) - print(f"Added enrollment target " f"for {constraint_var}") + print(f"Added enrollment target for {constraint_var}") else: + # Create new stratum with constraint new_stratum = Stratum( parent_stratum_id=us_stratum.stratum_id, stratum_group_id=stratum_group_id, notes=stratum_notes, ) + # Add constraint new_stratum.constraints_rel = [ StratumConstraint( constraint_variable=constraint_var, @@ -591,6 +612,7 @@ def load_national_targets( ) ] + # Add target new_stratum.targets_rel = [ Target( variable="person_count", @@ -598,17 +620,13 @@ def load_national_targets( value=cond_target["person_count"], source_id=calibration_source.source_id, active=True, - notes=( - f"{cond_target['notes']} | " - f"Source: {cond_target['source']}" - ), + notes=f"{cond_target['notes']} | Source: {cond_target['source']}", ) ] session.add(new_stratum) print( - f"Created stratum and target " - f"for {constraint_var} enrollment" + f"Created stratum and target for {constraint_var} enrollment" ) session.commit() @@ -618,25 +636,28 @@ def load_national_targets( + len(tax_filer_df) + len(conditional_targets) ) - print(f"\nSuccessfully loaded {total_targets} " f"national targets") + print(f"\nSuccessfully loaded {total_targets} national targets") print(f" - {len(direct_targets_df)} direct sum targets") print(f" - {len(tax_filer_df)} tax filer targets") print( - f" - {len(conditional_targets)} enrollment count " - f"targets (as strata)" + f" - {len(conditional_targets)} enrollment count targets (as strata)" ) def main(): """Main ETL pipeline for national targets.""" + + # Extract print("Extracting national targets...") raw_targets = extract_national_targets() + # Transform print("Transforming targets...") direct_targets_df, tax_filer_df, conditional_targets = ( transform_national_targets(raw_targets) ) + # Load print("Loading targets into database...") load_national_targets(direct_targets_df, tax_filer_df, conditional_targets) diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index 6f1a64767..48c1eb832 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -6,21 +6,29 @@ import pandas as pd import numpy as np import us -from sqlmodel import Session, create_engine +from sqlmodel 
import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( - SourceType, Stratum, StratumConstraint, Target, + Source, + SourceType, + VariableGroup, + VariableMetadata, ) from policyengine_us_data.utils.census import ( pull_acs_table, STATE_NAME_TO_FIPS, ) -from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, +) from policyengine_us_data.utils.raw_cache import ( is_cached, cache_path, @@ -51,29 +59,19 @@ def extract_administrative_snap_data(year=2023): "Upgrade-Insecure-Requests": "1", } + session = requests.Session() + session.headers.update(headers) + try: - session = requests.Session() - session.headers.update(headers) - - main_page = "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap" - try: - session.get(main_page, timeout=30) - except: - pass - - logger.info("Downloading SNAP data from USDA FNS") - response = session.get(url, timeout=30, allow_redirects=True) - response.raise_for_status() - except requests.exceptions.RequestException as e: - print(f"Error downloading file: {e}") - try: - alt_url = "https://www.fns.usda.gov/sites/default/files/resource-files/snap-zip-fy69tocurrent-6.zip" - response = session.get(alt_url, timeout=30, allow_redirects=True) - response.raise_for_status() - except requests.exceptions.RequestException as e2: - print(f"Alternative URL also failed: {e2}") - return None + session.get( + "https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap", + timeout=30, + ) + except Exception: + pass + response = session.get(url, timeout=30, allow_redirects=True) + response.raise_for_status() save_bytes(cache_file, response.content) return zipfile.ZipFile(io.BytesIO(response.content)) @@ -166,9 +164,8 @@ def load_administrative_snap_data(df_states, year): ) engine = create_engine(DATABASE_URL) - stratum_lookup = {} - with Session(engine) as session: + # Get or create the administrative source admin_source = get_or_create_source( session, name="USDA FNS SNAP Data", @@ -179,21 +176,53 @@ def load_administrative_snap_data(df_states, year): notes="State-level administrative totals for households and costs", ) + # Get or create the SNAP variable group + snap_group = get_or_create_variable_group( + session, + name="snap_recipients", + category="benefit", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=2, + description="SNAP (food stamps) recipient counts and benefits", + ) + + # Get or create variable metadata + get_or_create_variable_metadata( + session, + variable="snap", + group=snap_group, + display_name="SNAP Benefits", + display_order=1, + units="dollars", + notes="Annual SNAP benefit costs", + ) + + get_or_create_variable_metadata( + session, + variable="household_count", + group=snap_group, + display_name="SNAP Household Count", + display_order=2, + units="count", + notes="Number of households receiving SNAP", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + # National ---------------- + # Create a SNAP stratum as child of the national geographic stratum nat_stratum = Stratum( - parent_stratum_id=None, - stratum_group_id=0, - notes="Geo: 0100000US Received SNAP Benefits", + 
parent_stratum_id=geo_strata["national"], + stratum_group_id=4, # SNAP strata group + notes="National Received SNAP Benefits", ) nat_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value="0100000US", - ), StratumConstraint( constraint_variable="snap", - operation="greater_than", + operation=">", value="0", ), ] @@ -202,29 +231,33 @@ def load_administrative_snap_data(df_states, year): session.add(nat_stratum) session.flush() - stratum_lookup["National"] = nat_stratum.stratum_id + snap_stratum_lookup = {"national": nat_stratum.stratum_id, "state": {}} # State ------------------- - stratum_lookup["State"] = {} for _, row in df_states.iterrows(): + # Parse the UCGID to get state_fips + geo_info = parse_ucgid(row["ucgid_str"]) + state_fips = geo_info["state_fips"] + + # Get the parent geographic stratum + parent_stratum_id = geo_strata["state"][state_fips] - note = f"Geo: {row['ucgid_str']} Received SNAP Benefits" - parent_stratum_id = nat_stratum.stratum_id + note = f"State FIPS {state_fips} Received SNAP Benefits" new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=4, # SNAP strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="state_fips", + operation="==", + value=str(state_fips), ), StratumConstraint( constraint_variable="snap", - operation="greater_than", + operation=">", value="0", ), ] @@ -249,17 +282,18 @@ def load_administrative_snap_data(df_states, year): ) session.add(new_stratum) session.flush() - stratum_lookup["State"][row["ucgid_str"]] = new_stratum.stratum_id + snap_stratum_lookup["state"][state_fips] = new_stratum.stratum_id session.commit() - return stratum_lookup + return snap_stratum_lookup -def load_survey_snap_data(survey_df, year, stratum_lookup=None): - """Use an already defined stratum_lookup to load the survey SNAP data""" +def load_survey_snap_data(survey_df, year, snap_stratum_lookup): + """Use an already defined snap_stratum_lookup to load the survey SNAP data - if stratum_lookup is None: - raise ValueError("stratum_lookup must be provided") + Note: snap_stratum_lookup should contain the SNAP strata created by + load_administrative_snap_data, so we don't recreate them. 
+ """ DATABASE_URL = ( f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" @@ -267,6 +301,7 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): engine = create_engine(DATABASE_URL) with Session(engine) as session: + # Get or create the survey source survey_source = get_or_create_source( session, name="Census ACS Table S2201", @@ -277,27 +312,36 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): notes="Congressional district level SNAP household counts from ACS", ) + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + # Create new strata for districts whose households recieve SNAP benefits district_df = survey_df.copy() for _, row in district_df.iterrows(): - note = f"Geo: {row['ucgid_str']} Received SNAP Benefits" - state_ucgid_str = "0400000US" + row["ucgid_str"][9:11] - state_stratum_id = stratum_lookup["State"][state_ucgid_str] + # Parse the UCGID to get district info + geo_info = parse_ucgid(row["ucgid_str"]) + cd_geoid = geo_info["congressional_district_geoid"] + + # Get the parent geographic stratum + parent_stratum_id = geo_strata["district"][cd_geoid] + + note = f"Congressional District {cd_geoid} Received SNAP Benefits" + new_stratum = Stratum( - parent_stratum_id=state_stratum_id, - stratum_group_id=0, + parent_stratum_id=parent_stratum_id, + stratum_group_id=4, # SNAP strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), ), StratumConstraint( constraint_variable="snap", - operation="greater_than", + operation=">", value="0", ), ] @@ -315,7 +359,7 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): session.commit() - return stratum_lookup + return snap_stratum_lookup def main(): @@ -330,8 +374,8 @@ def main(): district_survey_df = transform_survey_snap_data(raw_survey_df) # Load ----------- - stratum_lookup = load_administrative_snap_data(state_admin_df, year) - load_survey_snap_data(district_survey_df, year, stratum_lookup) + snap_stratum_lookup = load_administrative_snap_data(state_admin_df, year) + load_survey_snap_data(district_survey_df, year, snap_stratum_lookup) if __name__ == "__main__": diff --git a/policyengine_us_data/db/migrate_stratum_group_ids.py b/policyengine_us_data/db/migrate_stratum_group_ids.py new file mode 100644 index 000000000..8a0839c9f --- /dev/null +++ b/policyengine_us_data/db/migrate_stratum_group_ids.py @@ -0,0 +1,137 @@ +""" +TODO: what is this file? Do we still need it? + + +Migration script to update stratum_group_id values to represent conceptual categories. 
+ +New scheme: +- 1: Geographic (US, states, congressional districts) +- 2: Age-based strata +- 3: Income/AGI-based strata +- 4: SNAP recipient strata +- 5: Medicaid enrollment strata +- 6: EITC recipient strata +""" + +from sqlmodel import Session, create_engine, select +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, +) + + +def migrate_stratum_group_ids(): + """Update stratum_group_id values based on constraint variables.""" + + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) + engine = create_engine(DATABASE_URL) + + with Session(engine) as session: + print("Starting stratum_group_id migration...") + print("=" * 60) + + # Track updates + updates = { + "Geographic": 0, + "Age": 0, + "Income/AGI": 0, + "SNAP": 0, + "Medicaid": 0, + "EITC": 0, + } + + # Get all strata + all_strata = session.exec(select(Stratum)).unique().all() + + for stratum in all_strata: + # Get constraints for this stratum + constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == stratum.stratum_id + ) + ).all() + + # Determine new group_id based on constraints + constraint_vars = [c.constraint_variable for c in constraints] + + # Geographic strata (no demographic constraints) + if not constraint_vars or all( + cv in ["state_fips", "congressional_district_geoid"] + for cv in constraint_vars + ): + if stratum.stratum_group_id != 1: + stratum.stratum_group_id = 1 + updates["Geographic"] += 1 + + # Age strata + elif "age" in constraint_vars: + if stratum.stratum_group_id != 2: + stratum.stratum_group_id = 2 + updates["Age"] += 1 + + # Income/AGI strata + elif "adjusted_gross_income" in constraint_vars: + if stratum.stratum_group_id != 3: + stratum.stratum_group_id = 3 + updates["Income/AGI"] += 1 + + # SNAP strata + elif "snap" in constraint_vars: + if stratum.stratum_group_id != 4: + stratum.stratum_group_id = 4 + updates["SNAP"] += 1 + + # Medicaid strata + elif "medicaid_enrolled" in constraint_vars: + if stratum.stratum_group_id != 5: + stratum.stratum_group_id = 5 + updates["Medicaid"] += 1 + + # EITC strata + elif "eitc_child_count" in constraint_vars: + if stratum.stratum_group_id != 6: + stratum.stratum_group_id = 6 + updates["EITC"] += 1 + + # Commit changes + session.commit() + + # Report results + print("\nMigration complete!") + print("-" * 60) + print("Updates made:") + for category, count in updates.items(): + if count > 0: + print(f" {category:15}: {count:5} strata updated") + + # Verify final counts + print("\nFinal stratum_group_id distribution:") + print("-" * 60) + + group_names = { + 1: "Geographic", + 2: "Age", + 3: "Income/AGI", + 4: "SNAP", + 5: "Medicaid", + 6: "EITC", + } + + for group_id, name in group_names.items(): + count = len( + session.exec( + select(Stratum).where(Stratum.stratum_group_id == group_id) + ) + .unique() + .all() + ) + print(f" Group {group_id} ({name:12}): {count:5} strata") + + print("\n✅ Migration successful!") + + +if __name__ == "__main__": + migrate_stratum_group_ids() diff --git a/policyengine_us_data/db/validate_database.py b/policyengine_us_data/db/validate_database.py index 3760706b8..2fa819f29 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -20,7 +20,6 @@ if not var_name in system.variables.keys(): raise ValueError(f"{var_name} not a policyengine-us variable") -constraint_vars = 
set(stratum_constraints_df["constraint_variable"]) -print(f"Constraint variables: {sorted(constraint_vars)}") -print(f"Target variables validated: {len(set(targets_df['variable']))}") -print("Validation passed.") +for var_name in set(stratum_constraints_df["constraint_variable"]): + if not var_name in system.variables.keys(): + raise ValueError(f"{var_name} not a policyengine-us variable") diff --git a/policyengine_us_data/db/validate_hierarchy.py b/policyengine_us_data/db/validate_hierarchy.py new file mode 100644 index 000000000..c868e84e7 --- /dev/null +++ b/policyengine_us_data/db/validate_hierarchy.py @@ -0,0 +1,326 @@ +""" +Validation script to ensure the parent-child hierarchy is working correctly. +Checks geographic and age strata relationships. +""" + +import sys +from sqlmodel import Session, create_engine, select +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, +) + + +def validate_geographic_hierarchy(session): + """Validate the geographic hierarchy: US -> States -> Congressional Districts""" + + print("\n" + "=" * 60) + print("VALIDATING GEOGRAPHIC HIERARCHY") + print("=" * 60) + + errors = [] + + # Check US stratum exists and has no parent + us_stratum = session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, Stratum.parent_stratum_id == None + ) + ).first() + + if not us_stratum: + errors.append( + "ERROR: No US-level stratum found (should have parent_stratum_id = None)" + ) + else: + print( + f"✓ US stratum found: {us_stratum.notes} (ID: {us_stratum.stratum_id})" + ) + + # Check it has no constraints + us_constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == us_stratum.stratum_id + ) + ).all() + + if us_constraints: + errors.append( + f"ERROR: US stratum has {len(us_constraints)} constraints, should have 0" + ) + else: + print("✓ US stratum has no constraints (correct)") + + # Check states + states = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + Stratum.parent_stratum_id == us_stratum.stratum_id, + ) + ) + .unique() + .all() + ) + + print(f"\n✓ Found {len(states)} state strata") + if len(states) != 51: # 50 states + DC + errors.append( + f"WARNING: Expected 51 states (including DC), found {len(states)}" + ) + + # Verify each state has proper constraints + state_ids = {} + for state in states[:5]: # Sample first 5 states + constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == state.stratum_id + ) + ).all() + + state_fips_constraint = [ + c for c in constraints if c.constraint_variable == "state_fips" + ] + if not state_fips_constraint: + errors.append( + f"ERROR: State '{state.notes}' has no state_fips constraint" + ) + else: + state_ids[state.stratum_id] = state.notes + print( + f" - {state.notes}: state_fips = {state_fips_constraint[0].value}" + ) + + # Check congressional districts + print("\nChecking Congressional Districts...") + + # Count total CDs (including delegate districts) + all_cds = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + ( + Stratum.notes.like("%Congressional District%") + | Stratum.notes.like("%Delegate District%") + ), + ) + ) + .unique() + .all() + ) + + print(f"✓ Found {len(all_cds)} congressional/delegate districts") + if len(all_cds) != 436: + errors.append( + f"WARNING: Expected 436 congressional districts (including DC delegate), found {len(all_cds)}" + ) + + # Verify CDs are 
children of correct states (spot check) + wyoming_id = None + for state in states: + if "Wyoming" in state.notes: + wyoming_id = state.stratum_id + break + + if wyoming_id: + # Check Wyoming's congressional district + wyoming_cds = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + Stratum.parent_stratum_id == wyoming_id, + Stratum.notes.like("%Congressional%"), + ) + ) + .unique() + .all() + ) + + if len(wyoming_cds) != 1: + errors.append( + f"ERROR: Wyoming should have 1 CD, found {len(wyoming_cds)}" + ) + else: + print(f"✓ Wyoming has correct number of CDs: 1") + + # Verify no other state's CDs are incorrectly parented to Wyoming + wrong_parent_cds = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + Stratum.parent_stratum_id == wyoming_id, + ~Stratum.notes.like("%Wyoming%"), + Stratum.notes.like("%Congressional%"), + ) + ) + .unique() + .all() + ) + + if wrong_parent_cds: + errors.append( + f"ERROR: Found {len(wrong_parent_cds)} non-Wyoming CDs incorrectly parented to Wyoming" + ) + for cd in wrong_parent_cds[:5]: + errors.append(f" - {cd.notes}") + else: + print( + "✓ No congressional districts incorrectly parented to Wyoming" + ) + + return errors + + +def validate_demographic_strata(session): + """Validate demographic strata are properly attached to geographic strata""" + + print("\n" + "=" * 60) + print("VALIDATING DEMOGRAPHIC STRATA") + print("=" * 60) + + errors = [] + + # Group names for the new scheme + group_names = { + 2: ("Age", 18), + 3: ("Income/AGI", 9), + 4: ("SNAP", 1), + 5: ("Medicaid", 1), + 6: ("EITC", 4), + } + + # Validate each demographic group + for group_id, (name, expected_per_geo) in group_names.items(): + strata = ( + session.exec( + select(Stratum).where(Stratum.stratum_group_id == group_id) + ) + .unique() + .all() + ) + + expected_total = expected_per_geo * 488 # 488 geographic areas + print(f"\n{name} strata (group {group_id}):") + print(f" Found: {len(strata)}") + print( + f" Expected: {expected_total} ({expected_per_geo} × 488 geographic areas)" + ) + + if len(strata) != expected_total: + errors.append( + f"WARNING: {name} has {len(strata)} strata, expected {expected_total}" + ) + + # Check parent relationships for a sample of demographic strata + print("\nChecking parent relationships (sample):") + sample_strata = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id > 1 + ) # All demographic groups + ) + .unique() + .all()[:100] + ) # Take first 100 + + correct_parents = 0 + wrong_parents = 0 + no_parents = 0 + + for stratum in sample_strata: + if stratum.parent_stratum_id: + parent = session.get(Stratum, stratum.parent_stratum_id) + if parent and parent.stratum_group_id == 1: # Geographic parent + correct_parents += 1 + else: + wrong_parents += 1 + errors.append( + f"ERROR: Stratum {stratum.stratum_id} has non-geographic parent" + ) + else: + no_parents += 1 + errors.append(f"ERROR: Stratum {stratum.stratum_id} has no parent") + + print(f" Sample of {len(sample_strata)} demographic strata:") + print(f" - With geographic parent: {correct_parents}") + print(f" - With wrong parent: {wrong_parents}") + print(f" - With no parent: {no_parents}") + + return errors + + +def validate_constraint_uniqueness(session): + """Check that constraint combinations produce unique hashes""" + + print("\n" + "=" * 60) + print("VALIDATING CONSTRAINT UNIQUENESS") + print("=" * 60) + + errors = [] + + # Check for duplicate definition_hashes + all_strata = session.exec(select(Stratum)).unique().all() + 
hash_counts = {} + + for stratum in all_strata: + if stratum.definition_hash in hash_counts: + hash_counts[stratum.definition_hash].append(stratum) + else: + hash_counts[stratum.definition_hash] = [stratum] + + duplicates = { + h: strata for h, strata in hash_counts.items() if len(strata) > 1 + } + + if duplicates: + errors.append( + f"ERROR: Found {len(duplicates)} duplicate definition_hashes" + ) + for hash_val, strata in list(duplicates.items())[:3]: # Show first 3 + errors.append( + f" Hash {hash_val[:10]}... appears {len(strata)} times:" + ) + for s in strata[:3]: + errors.append(f" - ID {s.stratum_id}: {s.notes[:50]}") + else: + print(f"✓ All {len(all_strata)} strata have unique definition_hashes") + + return errors + + +def main(): + """Run all validation checks""" + + DATABASE_URL = ( + f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + ) + engine = create_engine(DATABASE_URL) + + all_errors = [] + + with Session(engine) as session: + # Run validation checks + all_errors.extend(validate_geographic_hierarchy(session)) + all_errors.extend(validate_demographic_strata(session)) + all_errors.extend(validate_constraint_uniqueness(session)) + + # Summary + print("\n" + "=" * 60) + print("VALIDATION SUMMARY") + print("=" * 60) + + if all_errors: + print(f"\n❌ Found {len(all_errors)} issues:\n") + for error in all_errors: + print(f" {error}") + sys.exit(1) + else: + print("\n✅ All validation checks passed!") + print(" - Geographic hierarchy is correct") + print(" - Demographic strata properly organized and attached") + print(" - All constraint combinations are unique") + sys.exit(0) + + +if __name__ == "__main__": + main() From 53ceece588487ad3379a917808df60aac35f47be Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 29 Jan 2026 11:06:41 -0500 Subject: [PATCH 5/8] feat: atomic parallel local area publishing with Modal Volume - Add Modal Volume staging for persistent cache - Implement parallel build workers (configurable --num-workers) - Add manifest validation with SHA256 checksums - Add retry logic with exponential backoff for HF uploads - Version files under v{version}/ paths - Update latest.json atomically after all uploads succeed - Add --skip-upload flag for build-only testing Co-Authored-By: Claude Opus 4.5 --- .github/workflows/local_area_publish.yaml | 27 +- Makefile | 2 + modal_app/local_area.py | 507 +++++++++++++++++- modal_app/worker_script.py | 109 ++++ .../publish_local_area.py | 160 ++++++ policyengine_us_data/db/etl_medicaid.py | 40 +- policyengine_us_data/utils/data_upload.py | 253 ++++++++- policyengine_us_data/utils/manifest.py | 188 +++++++ pyproject.toml | 1 + 9 files changed, 1268 insertions(+), 19 deletions(-) create mode 100644 modal_app/worker_script.py create mode 100644 policyengine_us_data/utils/manifest.py diff --git a/.github/workflows/local_area_publish.yaml b/.github/workflows/local_area_publish.yaml index e23468a69..7e756ad85 100644 --- a/.github/workflows/local_area_publish.yaml +++ b/.github/workflows/local_area_publish.yaml @@ -10,11 +10,22 @@ on: repository_dispatch: types: [calibration-updated] workflow_dispatch: + inputs: + num_workers: + description: 'Number of parallel workers' + required: false + default: '8' + type: string + skip_upload: + description: 'Skip upload (build only)' + required: false + default: false + type: boolean # Trigger strategy: # 1. Automatic: Code changes to local_area_calibration/ pushed to main # 2. repository_dispatch: Calibration workflow triggers after uploading new weights -# 3. 
workflow_dispatch: Manual trigger when you update weights/data on HF yourself +# 3. workflow_dispatch: Manual trigger with optional parameters jobs: publish-local-area: @@ -39,4 +50,16 @@ jobs: run: pip install modal - name: Run local area publishing on Modal - run: modal run modal_app/local_area.py --branch=${{ github.head_ref || github.ref_name }} + run: | + NUM_WORKERS="${{ github.event.inputs.num_workers || '8' }}" + SKIP_UPLOAD="${{ github.event.inputs.skip_upload || 'false' }}" + BRANCH="${{ github.head_ref || github.ref_name }}" + + CMD="modal run modal_app/local_area.py --branch=${BRANCH} --num-workers=${NUM_WORKERS}" + + if [ "$SKIP_UPLOAD" = "true" ]; then + CMD="${CMD} --skip-upload" + fi + + echo "Running: $CMD" + $CMD diff --git a/Makefile b/Makefile index a2297de5b..27a7b356f 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,7 @@ documentation-dev: myst start database: + rm -f policyengine_us_data/storage/calibration/policy_data.db python policyengine_us_data/db/create_database_tables.py python policyengine_us_data/db/create_initial_strata.py python policyengine_us_data/db/etl_national_targets.py @@ -64,6 +65,7 @@ database: python policyengine_us_data/db/validate_database.py database-refresh: + rm -f policyengine_us_data/storage/calibration/policy_data.db rm -rf policyengine_us_data/storage/calibration/raw_inputs/ $(MAKE) database diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 8a1bd2b83..84f82a8fa 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -1,19 +1,41 @@ +""" +Modal app for publishing local area H5 files with parallel workers. + +Architecture: +1. Coordinator partitions work across N workers +2. Workers build H5 files in parallel, writing to shared Volume +3. Validation generates manifest with checksums +4. 
Atomic upload to versioned paths, updates latest.json last + +Usage: + modal run modal_app/local_area.py --branch=main --num-workers=8 +""" + import os import subprocess +import json import modal +from pathlib import Path +from typing import List, Dict app = modal.App("policyengine-us-data-local-area") hf_secret = modal.Secret.from_name("huggingface-token") gcp_secret = modal.Secret.from_name("gcp-credentials") +staging_volume = modal.Volume.from_name( + "local-area-staging", + create_if_missing=True, +) + image = ( modal.Image.debian_slim(python_version="3.13") .apt_install("git") - .pip_install("uv") + .pip_install("uv", "tomli") ) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" +VOLUME_MOUNT = "/staging" def setup_gcp_credentials(): @@ -28,37 +50,492 @@ def setup_gcp_credentials(): return None +def setup_repo(branch: str): + """Clone repo and install dependencies.""" + repo_dir = Path("/root/policyengine-us-data") + + if not repo_dir.exists(): + os.chdir("/root") + subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True) + os.chdir("policyengine-us-data") + subprocess.run(["uv", "sync", "--locked"], check=True) + else: + os.chdir(repo_dir) + + +def get_version() -> str: + """Get package version from pyproject.toml.""" + import tomli + + with open("pyproject.toml", "rb") as f: + pyproject = tomli.load(f) + return pyproject["project"]["version"] + + +def partition_work( + states: List[str], + districts: List[str], + cities: List[str], + num_workers: int, + completed: set, +) -> List[List[Dict]]: + """Partition work items across N workers.""" + remaining = [] + + for s in states: + item_id = f"state:{s}" + if item_id not in completed: + remaining.append({"type": "state", "id": s, "weight": 5}) + + for d in districts: + item_id = f"district:{d}" + if item_id not in completed: + remaining.append({"type": "district", "id": d, "weight": 1}) + + for c in cities: + item_id = f"city:{c}" + if item_id not in completed: + remaining.append({"type": "city", "id": c, "weight": 3}) + + remaining.sort(key=lambda x: -x["weight"]) + + chunks = [[] for _ in range(num_workers)] + for i, item in enumerate(remaining): + chunks[i % num_workers].append(item) + + return [c for c in chunks if c] + + +def get_completed_from_volume(version_dir: Path) -> set: + """Scan volume to find already-built files.""" + completed = set() + + states_dir = version_dir / "states" + if states_dir.exists(): + for f in states_dir.glob("*.h5"): + completed.add(f"state:{f.stem}") + + districts_dir = version_dir / "districts" + if districts_dir.exists(): + for f in districts_dir.glob("*.h5"): + completed.add(f"district:{f.stem}") + + cities_dir = version_dir / "cities" + if cities_dir.exists(): + for f in cities_dir.glob("*.h5"): + completed.add(f"city:{f.stem}") + + return completed + + @app.function( image=image, secrets=[hf_secret, gcp_secret], - memory=8192, + volumes={VOLUME_MOUNT: staging_volume}, + memory=16384, cpu=4.0, - timeout=86400, # 24h: processes 50 states + 435 districts with checkpointing + timeout=14400, ) -def publish_all_local_areas(branch: str = "main"): +def build_areas_worker( + branch: str, + version: str, + work_items: List[Dict], + calibration_inputs: Dict[str, str], +) -> Dict: + """ + Worker function that builds a subset of H5 files. + Uses subprocess to avoid import conflicts with Modal's environment. 
+ """ setup_gcp_credentials() + setup_repo(branch) + + output_dir = Path(VOLUME_MOUNT) / version + output_dir.mkdir(parents=True, exist_ok=True) + + work_items_json = json.dumps(work_items) + + result = subprocess.run( + [ + "uv", + "run", + "python", + "modal_app/worker_script.py", + "--work-items", + work_items_json, + "--weights-path", + calibration_inputs["weights"], + "--dataset-path", + calibration_inputs["dataset"], + "--db-path", + calibration_inputs["database"], + "--output-dir", + str(output_dir), + ], + capture_output=True, + text=True, + env=os.environ.copy(), + ) + + print(result.stderr) + + if result.returncode != 0: + return { + "completed": [], + "failed": [f"{item['type']}:{item['id']}" for item in work_items], + "errors": [{"error": result.stderr}], + } + + try: + results = json.loads(result.stdout) + except json.JSONDecodeError: + results = { + "completed": [], + "failed": [], + "errors": [{"error": f"Failed to parse output: {result.stdout}"}], + } + + staging_volume.commit() + return results - os.chdir("/root") - subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True) - os.chdir("policyengine-us-data") - # Use uv sync to install exact versions from uv.lock - subprocess.run(["uv", "sync", "--locked"], check=True) - subprocess.run( +@app.function( + image=image, + secrets=[hf_secret], + volumes={VOLUME_MOUNT: staging_volume}, + memory=4096, + timeout=1800, +) +def validate_staging(branch: str, version: str) -> Dict: + """Validate all expected files and generate manifest.""" + setup_repo(branch) + + result = subprocess.run( [ "uv", "run", "python", - "policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py", + "-c", + f""" +import json +from pathlib import Path +from policyengine_us_data.utils.manifest import generate_manifest, save_manifest + +staging_dir = Path("{VOLUME_MOUNT}") +version = "{version}" +manifest = generate_manifest(staging_dir, version) +manifest_path = staging_dir / version / "manifest.json" +save_manifest(manifest, manifest_path) +print(json.dumps(manifest)) +""", ], - check=True, + capture_output=True, + text=True, env=os.environ.copy(), ) - return "Local area publishing completed successfully" + print(result.stderr) + + if result.returncode != 0: + raise RuntimeError(f"Validation failed: {result.stderr}") + + manifest = json.loads(result.stdout) + staging_volume.commit() + + print(f"Generated manifest with {len(manifest['files'])} files") + print(f" States: {manifest['totals']['states']}") + print(f" Districts: {manifest['totals']['districts']}") + print(f" Cities: {manifest['totals']['cities']}") + print( + f" Total size: {manifest['totals']['total_size_bytes'] / 1e9:.2f} GB" + ) + + return manifest + + +@app.function( + image=image, + secrets=[hf_secret, gcp_secret], + volumes={VOLUME_MOUNT: staging_volume}, + memory=8192, + timeout=14400, +) +def atomic_upload(branch: str, version: str, manifest: Dict) -> str: + """Upload all files from staging to GCS and HF atomically.""" + setup_gcp_credentials() + setup_repo(branch) + + manifest_json = json.dumps(manifest) + + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +import json +from pathlib import Path +from policyengine_us_data.utils.manifest import verify_manifest +from policyengine_us_data.utils.data_upload import ( + upload_versioned_files_to_gcs, + upload_versioned_files_to_hf, + upload_manifest_and_latest, +) + +manifest = json.loads('''{manifest_json}''') +version = "{version}" +staging_dir = Path("{VOLUME_MOUNT}") +version_dir = 
staging_dir / version + +print("Verifying manifest before upload...") +verification = verify_manifest(staging_dir, manifest) +if not verification["valid"]: + raise ValueError( + f"Manifest verification failed: " + f"{{len(verification['missing'])}} missing, " + f"{{len(verification['checksum_mismatch'])}} checksum mismatches" + ) +print(f"Verified {{verification['verified']}} files") + +files_with_paths = [] +for rel_path in manifest["files"].keys(): + local_path = version_dir / rel_path + files_with_paths.append((local_path, rel_path)) + +print(f"Uploading {{len(files_with_paths)}} files to GCS...") +gcs_count = upload_versioned_files_to_gcs(files_with_paths, version) +print(f"Uploaded {{gcs_count}} files to GCS") + +print(f"Uploading {{len(files_with_paths)}} files to HuggingFace...") +batch_size = 50 +hf_total = 0 +for i in range(0, len(files_with_paths), batch_size): + batch = files_with_paths[i : i + batch_size] + hf_count = upload_versioned_files_to_hf(batch, version) + hf_total += hf_count + print(f" Batch {{i // batch_size + 1}}: uploaded {{hf_count}} files") + +print(f"Uploaded {{hf_total}} files to HuggingFace") + +print("Updating manifest and latest.json...") +upload_manifest_and_latest(manifest, version) + +print(f"Successfully published version {{version}}") +""", + ], + text=True, + env=os.environ.copy(), + ) + + if result.returncode != 0: + raise RuntimeError(f"Upload failed: {result.stderr}") + + return f"Successfully published version {version} with {len(manifest['files'])} files" + + +@app.function( + image=image, + secrets=[hf_secret, gcp_secret], + volumes={VOLUME_MOUNT: staging_volume}, + memory=8192, + timeout=86400, +) +def coordinate_publish( + branch: str = "main", + num_workers: int = 8, + skip_upload: bool = False, +) -> str: + """Coordinate the full publishing workflow.""" + setup_gcp_credentials() + setup_repo(branch) + + version = get_version() + print(f"Publishing version {version} from branch {branch}") + print(f"Using {num_workers} parallel workers") + + staging_dir = Path(VOLUME_MOUNT) + version_dir = staging_dir / version + version_dir.mkdir(parents=True, exist_ok=True) + + calibration_dir = staging_dir / "calibration_inputs" + calibration_dir.mkdir(parents=True, exist_ok=True) + + weights_path = calibration_dir / "w_district_calibration.npy" + dataset_path = calibration_dir / "stratified_extended_cps.h5" + db_path = calibration_dir / "policy_data.db" + + if not all(p.exists() for p in [weights_path, dataset_path, db_path]): + print("Downloading calibration inputs...") + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +from policyengine_us_data.utils.huggingface import download_calibration_inputs +download_calibration_inputs("{calibration_dir}") +print("Done") +""", + ], + text=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(f"Download failed: {result.stderr}") + staging_volume.commit() + print("Calibration inputs downloaded and cached on volume") + else: + print("Using cached calibration inputs from volume") + + calibration_inputs = { + "weights": str(weights_path), + "dataset": str(dataset_path), + "database": str(db_path), + } + + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +import json +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + get_all_cds_from_database, + STATE_CODES, +) +from policyengine_us_data.datasets.cps.local_area_calibration.publish_local_area import ( + get_district_friendly_name, +) + +db_uri = 
"sqlite:///{db_path}" +cds = get_all_cds_from_database(db_uri) +states = list(STATE_CODES.values()) +districts = [get_district_friendly_name(cd) for cd in cds] +print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"]}})) +""", + ], + capture_output=True, + text=True, + env=os.environ.copy(), + ) + + if result.returncode != 0: + raise RuntimeError(f"Failed to get work items: {result.stderr}") + + work_info = json.loads(result.stdout) + states = work_info["states"] + districts = work_info["districts"] + cities = work_info["cities"] + + staging_volume.reload() + completed = get_completed_from_volume(version_dir) + print(f"Found {len(completed)} already-completed items on volume") + + work_chunks = partition_work( + states, districts, cities, num_workers, completed + ) + + total_remaining = sum(len(c) for c in work_chunks) + print( + f"Remaining work: {total_remaining} items " + f"across {len(work_chunks)} workers" + ) + + if total_remaining == 0: + print("All items already built!") + else: + print("\nSpawning workers...") + handles = [] + for i, chunk in enumerate(work_chunks): + print(f" Worker {i}: {len(chunk)} items") + handle = build_areas_worker.spawn( + branch=branch, + version=version, + work_items=chunk, + calibration_inputs=calibration_inputs, + ) + handles.append(handle) + + print("\nWaiting for workers to complete...") + all_results = [] + all_errors = [] + + for i, handle in enumerate(handles): + try: + result = handle.get() + all_results.append(result) + print( + f" Worker {i}: {len(result['completed'])} completed, " + f"{len(result['failed'])} failed" + ) + if result["errors"]: + all_errors.extend(result["errors"]) + except Exception as e: + all_errors.append({"worker": i, "error": str(e)}) + print(f" Worker {i}: CRASHED - {e}") + + total_completed = sum(len(r["completed"]) for r in all_results) + total_failed = sum(len(r["failed"]) for r in all_results) + + print(f"\nBuild summary:") + print(f" Completed: {total_completed}") + print(f" Failed: {total_failed}") + print(f" Previously completed: {len(completed)}") + + if all_errors: + print(f"\nErrors ({len(all_errors)}):") + for err in all_errors[:5]: + err_msg = err.get("error", "Unknown")[:100] + print(f" - {err.get('item', err.get('worker'))}: {err_msg}") + if len(all_errors) > 5: + print(f" ... and {len(all_errors) - 5} more") + + if total_failed > 0: + raise RuntimeError( + f"Build incomplete: {total_failed} failures. " + f"Volume preserved for retry." + ) + + if skip_upload: + print("\nSkipping upload (--skip-upload flag set)") + return f"Build complete for version {version}. Upload skipped." 
+ + print("\nValidating staging...") + manifest = validate_staging.remote(branch=branch, version=version) + + expected_total = len(states) + len(districts) + len(cities) + actual_total = ( + manifest["totals"]["states"] + + manifest["totals"]["districts"] + + manifest["totals"]["cities"] + ) + + if actual_total < expected_total: + print( + f"WARNING: Expected {expected_total} files, found {actual_total}" + ) + + print("\nStarting atomic upload...") + result = atomic_upload.remote( + branch=branch, version=version, manifest=manifest + ) + + return result @app.local_entrypoint() -def main(branch: str = "main"): - result = publish_all_local_areas.remote(branch=branch) +def main( + branch: str = "main", + num_workers: int = 8, + skip_upload: bool = False, +): + """Local entrypoint for Modal CLI.""" + result = coordinate_publish.remote( + branch=branch, + num_workers=num_workers, + skip_upload=skip_upload, + ) print(result) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py new file mode 100644 index 000000000..95217e4cf --- /dev/null +++ b/modal_app/worker_script.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +""" +Worker script for building local area H5 files. + +Called by Modal workers via subprocess to avoid import conflicts. +""" + +import argparse +import json +import sys +import traceback +import numpy as np +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--work-items", required=True, help="JSON work items") + parser.add_argument("--weights-path", required=True) + parser.add_argument("--dataset-path", required=True) + parser.add_argument("--db-path", required=True) + parser.add_argument("--output-dir", required=True) + args = parser.parse_args() + + work_items = json.loads(args.work_items) + weights_path = Path(args.weights_path) + dataset_path = Path(args.dataset_path) + db_path = Path(args.db_path) + output_dir = Path(args.output_dir) + + from policyengine_us_data.datasets.cps.local_area_calibration.publish_local_area import ( + build_state_h5, + build_district_h5, + build_city_h5, + ) + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + get_all_cds_from_database, + STATE_CODES, + ) + + db_uri = f"sqlite:///{db_path}" + cds_to_calibrate = get_all_cds_from_database(db_uri) + weights = np.load(weights_path) + + results = { + "completed": [], + "failed": [], + "errors": [], + } + + for item in work_items: + item_type = item["type"] + item_id = item["id"] + + try: + if item_type == "state": + path = build_state_h5( + state_code=item_id, + weights=weights, + cds_to_calibrate=cds_to_calibrate, + dataset_path=dataset_path, + output_dir=output_dir, + ) + elif item_type == "district": + state_code, dist_num = item_id.split("-") + geoid = None + for fips, code in STATE_CODES.items(): + if code == state_code: + geoid = f"{fips}{int(dist_num):02d}" + break + if geoid is None: + raise ValueError(f"Unknown state in district: {item_id}") + + path = build_district_h5( + cd_geoid=geoid, + weights=weights, + cds_to_calibrate=cds_to_calibrate, + dataset_path=dataset_path, + output_dir=output_dir, + ) + elif item_type == "city": + path = build_city_h5( + city_name=item_id, + weights=weights, + cds_to_calibrate=cds_to_calibrate, + dataset_path=dataset_path, + output_dir=output_dir, + ) + else: + raise ValueError(f"Unknown item type: {item_type}") + + if path: + results["completed"].append(f"{item_type}:{item_id}") + print(f"Completed {item_type}:{item_id}", file=sys.stderr) + + except Exception as 
e: + results["failed"].append(f"{item_type}:{item_id}") + results["errors"].append({ + "item": f"{item_type}:{item_id}", + "error": str(e), + "traceback": traceback.format_exc(), + }) + print(f"FAILED {item_type}:{item_id}: {e}", file=sys.stderr) + + print(json.dumps(results)) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py index e798addf2..4963f3979 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py @@ -11,6 +11,7 @@ import os import numpy as np from pathlib import Path +from typing import List, Optional, Set from policyengine_us import Microsimulation from policyengine_us_data.utils.huggingface import download_calibration_inputs @@ -73,6 +74,165 @@ def record_completed_city(city_name: str): f.write(f"{city_name}\n") +def build_state_h5( + state_code: str, + weights: np.ndarray, + cds_to_calibrate: List[str], + dataset_path: Path, + output_dir: Path, +) -> Optional[Path]: + """ + Build a single state H5 file (build only, no upload). + + Args: + state_code: Two-letter state code (e.g., "AL", "CA") + weights: Calibrated weight vector + cds_to_calibrate: Full list of CD GEOIDs from calibration + dataset_path: Path to base dataset H5 file + output_dir: Output directory for H5 file + + Returns: + Path to output H5 file if successful, None if no CDs found + """ + state_fips = None + for fips, code in STATE_CODES.items(): + if code == state_code: + state_fips = fips + break + + if state_fips is None: + print(f"Unknown state code: {state_code}") + return None + + cd_subset = [cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips] + if not cd_subset: + print(f"No CDs found for {state_code}, skipping") + return None + + states_dir = output_dir / "states" + states_dir.mkdir(parents=True, exist_ok=True) + output_path = states_dir / f"{state_code}.h5" + + print(f"\n{'='*60}") + print(f"Building {state_code} ({len(cd_subset)} CDs)") + print(f"{'='*60}") + + create_sparse_cd_stacked_dataset( + weights, + cds_to_calibrate, + cd_subset=cd_subset, + dataset_path=str(dataset_path), + output_path=str(output_path), + ) + + return output_path + + +def build_district_h5( + cd_geoid: str, + weights: np.ndarray, + cds_to_calibrate: List[str], + dataset_path: Path, + output_dir: Path, +) -> Path: + """ + Build a single district H5 file (build only, no upload). 
+ + Args: + cd_geoid: Congressional district GEOID (e.g., "0101" for AL-01) + weights: Calibrated weight vector + cds_to_calibrate: Full list of CD GEOIDs from calibration + dataset_path: Path to base dataset H5 file + output_dir: Output directory for H5 file + + Returns: + Path to output H5 file + """ + cd_int = int(cd_geoid) + state_fips = cd_int // 100 + district_num = cd_int % 100 + state_code = STATE_CODES.get(state_fips, str(state_fips)) + friendly_name = f"{state_code}-{district_num:02d}" + + districts_dir = output_dir / "districts" + districts_dir.mkdir(parents=True, exist_ok=True) + output_path = districts_dir / f"{friendly_name}.h5" + + print(f"\n{'='*60}") + print(f"Building {friendly_name}") + print(f"{'='*60}") + + create_sparse_cd_stacked_dataset( + weights, + cds_to_calibrate, + cd_subset=[cd_geoid], + dataset_path=str(dataset_path), + output_path=str(output_path), + ) + + return output_path + + +def build_city_h5( + city_name: str, + weights: np.ndarray, + cds_to_calibrate: List[str], + dataset_path: Path, + output_dir: Path, +) -> Optional[Path]: + """ + Build a city H5 file (build only, no upload). + + Currently supports NYC only. + + Args: + city_name: City name (currently only "NYC" supported) + weights: Calibrated weight vector + cds_to_calibrate: Full list of CD GEOIDs from calibration + dataset_path: Path to base dataset H5 file + output_dir: Output directory for H5 file + + Returns: + Path to output H5 file if successful, None otherwise + """ + if city_name != "NYC": + print(f"Unsupported city: {city_name}") + return None + + cd_subset = [cd for cd in cds_to_calibrate if cd in NYC_CDS] + if not cd_subset: + print("No NYC-related CDs found, skipping") + return None + + cities_dir = output_dir / "cities" + cities_dir.mkdir(parents=True, exist_ok=True) + output_path = cities_dir / "NYC.h5" + + print(f"\n{'='*60}") + print(f"Building NYC ({len(cd_subset)} CDs)") + print(f"{'='*60}") + + create_sparse_cd_stacked_dataset( + weights, + cds_to_calibrate, + cd_subset=cd_subset, + dataset_path=str(dataset_path), + output_path=str(output_path), + county_filter=NYC_COUNTIES, + ) + + return output_path + + +def get_district_friendly_name(cd_geoid: str) -> str: + """Convert GEOID to friendly name (e.g., '0101' -> 'AL-01').""" + cd_int = int(cd_geoid) + state_fips = cd_int // 100 + district_num = cd_int % 100 + state_code = STATE_CODES.get(state_fips, str(state_fips)) + return f"{state_code}-{district_num:02d}" + + def build_and_upload_states( weights_path: Path, dataset_path: Path, diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 2038802bd..ed1841447 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -88,13 +88,51 @@ def transform_administrative_medicaid_data(state_admin_df, year): "Reporting Period", "Total Medicaid Enrollment", ], - ] + ].copy() state_df["FIPS"] = state_df["State Abbreviation"].map(STATE_ABBREV_TO_FIPS) state_df = state_df.rename( columns={"Total Medicaid Enrollment": "medicaid_enrollment"} ) + + # Handle states with 0 or NaN enrollment by using most recent non-zero value + # This addresses data quality issues where some states have missing Dec data + problem_states = state_df[ + (state_df["medicaid_enrollment"] == 0) + | (state_df["medicaid_enrollment"].isna()) + ]["State Abbreviation"].tolist() + + if problem_states: + print( + f"Warning: States with 0/NaN enrollment in {reporting_period}: {problem_states}" + ) + print("Attempting to use most recent non-zero 
values...") + + for state_abbrev in problem_states: + # Find most recent non-zero final report for this state + state_history = state_admin_df[ + (state_admin_df["State Abbreviation"] == state_abbrev) + & (state_admin_df["Final Report"] == "Y") + & (state_admin_df["Total Medicaid Enrollment"] > 0) + & (state_admin_df["Reporting Period"] < reporting_period) + ].sort_values("Reporting Period", ascending=False) + + if not state_history.empty: + fallback_value = state_history.iloc[0][ + "Total Medicaid Enrollment" + ] + fallback_period = state_history.iloc[0]["Reporting Period"] + print( + f" {state_abbrev}: Using {fallback_value:,.0f} from period {fallback_period}" + ) + state_df.loc[ + state_df["State Abbreviation"] == state_abbrev, + "medicaid_enrollment", + ] = fallback_value + else: + print(f" {state_abbrev}: No historical data found, keeping 0") + state_df["ucgid_str"] = "0400000US" + state_df["FIPS"].astype(str) return state_df[["ucgid_str", "medicaid_enrollment"]] diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 039364a1f..8428fc6d9 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -1,13 +1,27 @@ -from typing import List +from typing import List, Dict, Optional, Tuple from huggingface_hub import HfApi, CommitOperationAdd from huggingface_hub.errors import RevisionNotFoundError from google.cloud import storage from pathlib import Path from importlib import metadata import google.auth +import httpx +import json import logging import os +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log, +) + +DEFAULT_HF_TIMEOUT = 300 +MAX_RETRIES = 5 +RETRY_BASE_WAIT = 30 + def upload_data_files( files: List[str], @@ -224,3 +238,240 @@ def upload_local_area_batch_to_hf( logging.info( f"Uploaded {len(operations)} files to Hugging Face {hf_repo_name} in single commit." ) + + +@retry( + stop=stop_after_attempt(MAX_RETRIES), + wait=wait_exponential(multiplier=RETRY_BASE_WAIT, min=30, max=300), + retry=retry_if_exception_type( + ( + httpx.ReadTimeout, + httpx.ConnectTimeout, + httpx.RemoteProtocolError, + ConnectionError, + ) + ), + before_sleep=before_sleep_log(logging.getLogger(), logging.WARNING), +) +def hf_create_commit_with_retry( + api: HfApi, + operations: List[CommitOperationAdd], + repo_id: str, + repo_type: str, + token: str, + commit_message: str, +): + """ + Create HuggingFace commit with retry logic for timeout errors. + + Uses exponential backoff: 30s, 60s, 120s, 240s, 300s (capped) + """ + return api.create_commit( + token=token, + repo_id=repo_id, + operations=operations, + repo_type=repo_type, + commit_message=commit_message, + ) + + +def upload_versioned_files_to_gcs( + files_with_paths: List[Tuple[Path, str]], + version: str, + gcs_bucket_name: str = "policyengine-us-data", +) -> int: + """ + Upload files to versioned paths in GCS. 
+ + Args: + files_with_paths: List of (local_path, relative_path) tuples + relative_path is like "states/AL.h5" + version: Version string (e.g., "1.56.0") + gcs_bucket_name: Target bucket name + + Returns: + Number of files uploaded + """ + credentials, project_id = google.auth.default() + storage_client = storage.Client( + credentials=credentials, project=project_id + ) + bucket = storage_client.bucket(gcs_bucket_name) + + uploaded = 0 + for local_path, rel_path in files_with_paths: + local_path = Path(local_path) + if not local_path.exists(): + logging.warning(f"File {local_path} does not exist, skipping.") + continue + + blob_name = f"v{version}/{rel_path}" + blob = bucket.blob(blob_name) + blob.upload_from_filename(local_path) + blob.metadata = {"version": version} + blob.patch() + uploaded += 1 + logging.info(f"Uploaded {blob_name} to GCS bucket {gcs_bucket_name}.") + + return uploaded + + +def upload_versioned_files_to_hf( + files_with_paths: List[Tuple[Path, str]], + version: str, + hf_repo_name: str = "policyengine/policyengine-us-data", + hf_repo_type: str = "model", + commit_message: Optional[str] = None, +) -> int: + """ + Upload files to versioned paths in HuggingFace with retry logic. + + Args: + files_with_paths: List of (local_path, relative_path) tuples + version: Version string + hf_repo_name: HuggingFace repository name + hf_repo_type: Repository type + commit_message: Optional custom commit message + + Returns: + Number of files uploaded + """ + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + + operations = [] + for local_path, rel_path in files_with_paths: + local_path = Path(local_path) + if not local_path.exists(): + logging.warning(f"File {local_path} does not exist, skipping.") + continue + operations.append( + CommitOperationAdd( + path_in_repo=f"v{version}/{rel_path}", + path_or_fileobj=str(local_path), + ) + ) + + if not operations: + logging.warning("No files to upload to HuggingFace.") + return 0 + + if commit_message is None: + commit_message = ( + f"Upload {len(operations)} files for version {version}" + ) + + hf_create_commit_with_retry( + api=api, + operations=operations, + repo_id=hf_repo_name, + repo_type=hf_repo_type, + token=token, + commit_message=commit_message, + ) + + logging.info( + f"Uploaded {len(operations)} files to HuggingFace {hf_repo_name} " + f"at v{version}/ with retry support." + ) + return len(operations) + + +def upload_manifest_and_latest( + manifest: Dict, + version: str, + previous_version: Optional[str] = None, + gcs_bucket_name: str = "policyengine-us-data", + hf_repo_name: str = "policyengine/policyengine-us-data", + hf_repo_type: str = "model", +) -> None: + """ + Upload manifest.json to versioned path, then update latest.json. + + This is the final step that makes a version "live". The latest.json + pointer is only updated after the manifest is successfully uploaded. 
+ + Args: + manifest: Manifest dictionary (from manifest.py) + version: Current version string + previous_version: Version being replaced (for history) + gcs_bucket_name: GCS bucket name + hf_repo_name: HuggingFace repository + hf_repo_type: Repository type + """ + from policyengine_us_data.utils.manifest import create_latest_pointer + + token = os.environ.get("HUGGING_FACE_TOKEN") + credentials, project_id = google.auth.default() + storage_client = storage.Client( + credentials=credentials, project=project_id + ) + bucket = storage_client.bucket(gcs_bucket_name) + api = HfApi() + + manifest_json = json.dumps(manifest, indent=2) + manifest_blob_name = f"v{version}/manifest.json" + + gcs_manifest_blob = bucket.blob(manifest_blob_name) + gcs_manifest_blob.upload_from_string( + manifest_json, content_type="application/json" + ) + logging.info(f"Uploaded {manifest_blob_name} to GCS.") + + hf_create_commit_with_retry( + api=api, + operations=[ + CommitOperationAdd( + path_in_repo=manifest_blob_name, + path_or_fileobj=manifest_json.encode("utf-8"), + ) + ], + repo_id=hf_repo_name, + repo_type=hf_repo_type, + token=token, + commit_message=f"Upload manifest for version {version}", + ) + logging.info(f"Uploaded {manifest_blob_name} to HuggingFace.") + + previous_versions = None + try: + existing_latest_blob = bucket.blob("latest.json") + if existing_latest_blob.exists(): + existing_latest = json.loads( + existing_latest_blob.download_as_string() + ) + previous_versions = existing_latest.get("previous_versions", []) + if previous_version is None: + previous_version = existing_latest.get("current_version") + except Exception as e: + logging.warning(f"Could not read existing latest.json: {e}") + + latest = create_latest_pointer( + version=version, + previous_version=previous_version, + previous_versions=previous_versions, + ) + latest_json = json.dumps(latest, indent=2) + + gcs_latest_blob = bucket.blob("latest.json") + gcs_latest_blob.upload_from_string( + latest_json, content_type="application/json" + ) + logging.info("Updated latest.json in GCS.") + + hf_create_commit_with_retry( + api=api, + operations=[ + CommitOperationAdd( + path_in_repo="latest.json", + path_or_fileobj=latest_json.encode("utf-8"), + ) + ], + repo_id=hf_repo_name, + repo_type=hf_repo_type, + token=token, + commit_message=f"Update latest.json to version {version}", + ) + logging.info("Updated latest.json in HuggingFace.") + + logging.info(f"Version {version} is now live!") diff --git a/policyengine_us_data/utils/manifest.py b/policyengine_us_data/utils/manifest.py new file mode 100644 index 000000000..831cb2147 --- /dev/null +++ b/policyengine_us_data/utils/manifest.py @@ -0,0 +1,188 @@ +""" +Manifest utilities for atomic deployment of local area H5 files. + +Provides checksum computation, manifest generation, and verification +for ensuring data integrity during uploads. +""" + +import hashlib +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + + +def compute_file_checksum(file_path: Path) -> str: + """ + Compute SHA256 checksum of a file. 
+ + Args: + file_path: Path to the file + + Returns: + Hex-encoded SHA256 hash string + """ + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256.update(chunk) + return sha256.hexdigest() + + +def generate_manifest( + staging_dir: Path, + version: str, + categories: Optional[List[str]] = None, +) -> Dict: + """ + Generate manifest.json for all H5 files in staging directory. + + Args: + staging_dir: Root staging directory (contains version subdirs) + version: Version string (e.g., "1.56.0") + categories: List of categories to include (default: states, districts, + cities) + + Returns: + Manifest dictionary with structure: + { + "version": "1.56.0", + "created_at": "2026-01-29T12:00:00Z", + "files": { + "states/AL.h5": {"sha256": "...", "size_bytes": 12345}, + ... + }, + "totals": { + "states": 50, + "districts": 435, + "cities": 1, + "total_size_bytes": 987654321 + } + } + """ + if categories is None: + categories = ["states", "districts", "cities"] + + manifest = { + "version": version, + "created_at": datetime.utcnow().isoformat() + "Z", + "files": {}, + "totals": {cat: 0 for cat in categories}, + } + manifest["totals"]["total_size_bytes"] = 0 + + version_dir = staging_dir / version + + for category in categories: + category_dir = version_dir / category + if not category_dir.exists(): + continue + + for h5_file in sorted(category_dir.glob("*.h5")): + rel_path = f"{category}/{h5_file.name}" + file_size = h5_file.stat().st_size + + manifest["files"][rel_path] = { + "sha256": compute_file_checksum(h5_file), + "size_bytes": file_size, + } + manifest["totals"][category] += 1 + manifest["totals"]["total_size_bytes"] += file_size + + return manifest + + +def verify_manifest(staging_dir: Path, manifest: Dict) -> Dict: + """ + Verify all files in manifest exist and have correct checksums. + + Args: + staging_dir: Root staging directory + manifest: Manifest dictionary to verify against + + Returns: + Verification result: + { + "valid": True/False, + "missing": ["states/AL.h5", ...], + "checksum_mismatch": ["districts/CA-01.h5", ...], + "verified": 486 + } + """ + version = manifest["version"] + version_dir = staging_dir / version + + result = { + "valid": True, + "missing": [], + "checksum_mismatch": [], + "verified": 0, + } + + for rel_path, file_info in manifest["files"].items(): + file_path = version_dir / rel_path + + if not file_path.exists(): + result["missing"].append(rel_path) + result["valid"] = False + continue + + actual_checksum = compute_file_checksum(file_path) + if actual_checksum != file_info["sha256"]: + result["checksum_mismatch"].append(rel_path) + result["valid"] = False + continue + + result["verified"] += 1 + + return result + + +def save_manifest(manifest: Dict, output_path: Path) -> None: + """Save manifest to JSON file.""" + with open(output_path, "w") as f: + json.dump(manifest, f, indent=2) + + +def load_manifest(manifest_path: Path) -> Dict: + """Load manifest from JSON file.""" + with open(manifest_path, "r") as f: + return json.load(f) + + +def create_latest_pointer( + version: str, + previous_version: Optional[str] = None, + previous_versions: Optional[List[Dict]] = None, +) -> Dict: + """ + Create latest.json pointer structure. 
+ + Args: + version: Current version to point to + previous_version: The version being replaced (will be added to history) + previous_versions: Existing version history (from old latest.json) + + Returns: + Latest pointer dictionary: + { + "current_version": "1.56.0", + "updated_at": "2026-01-29T12:00:00Z", + "manifest_url": "v1.56.0/manifest.json", + "previous_versions": [...] + } + """ + now = datetime.utcnow().isoformat() + "Z" + + history = [] + if previous_version: + history.append({"version": previous_version, "deprecated_at": now}) + if previous_versions: + history.extend(previous_versions[:9]) + + return { + "current_version": version, + "updated_at": now, + "manifest_url": f"v{version}/manifest.json", + "previous_versions": history, + } diff --git a/pyproject.toml b/pyproject.toml index 01a89d4f4..d603754ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "sqlmodel>=0.0.24", "xlrd>=2.0.2", "spm-calculator>=0.1.0", + "tenacity>=8.0.0", ] [project.optional-dependencies] From 08e851df712ddef3e6140f37c7c87d9e17183165 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 29 Jan 2026 11:08:54 -0500 Subject: [PATCH 6/8] chore: update uv.lock for tenacity dependency Co-Authored-By: Claude Opus 4.5 --- uv.lock | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/uv.lock b/uv.lock index f78a18fe3..a917d5345 100644 --- a/uv.lock +++ b/uv.lock @@ -637,6 +637,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, + { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = 
"https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -644,6 +645,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, + { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -1877,6 +1879,7 @@ dependencies = [ { name = "sqlmodel" }, { name = "statsmodels" }, { name = "tables" }, + { name = "tenacity" }, { name = "torch" }, { name = "tqdm" }, { name = "us" }, @@ -1927,6 +1930,7 @@ requires-dist = [ { name = "sqlmodel", specifier = ">=0.0.24" }, { name = "statsmodels", specifier = ">=0.14.5" }, { name = "tables", specifier = ">=3.10.2" }, + { name = "tenacity", specifier = ">=8.0.0" }, { name = "torch", specifier = ">=2.7.1" }, { name = "tqdm", specifier = ">=4.60.0" }, { name = "us", specifier = ">=2.0.0" }, From ae0237c49dd9d728db308dcf10d12c34310b43fc Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 29 Jan 2026 12:44:05 -0500 Subject: [PATCH 7/8] fix: correct calibration input paths for HuggingFace download Co-Authored-By: Claude Opus 4.5 --- modal_app/local_area.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modal_app/local_area.py 
b/modal_app/local_area.py index 84f82a8fa..c9e0624f9 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -360,9 +360,10 @@ def coordinate_publish( calibration_dir = staging_dir / "calibration_inputs" calibration_dir.mkdir(parents=True, exist_ok=True) - weights_path = calibration_dir / "w_district_calibration.npy" - dataset_path = calibration_dir / "stratified_extended_cps.h5" - db_path = calibration_dir / "policy_data.db" + # hf_hub_download preserves directory structure, so files are in calibration/ subdir + weights_path = calibration_dir / "calibration" / "w_district_calibration.npy" + dataset_path = calibration_dir / "calibration" / "stratified_extended_cps.h5" + db_path = calibration_dir / "calibration" / "policy_data.db" if not all(p.exists() for p in [weights_path, dataset_path, db_path]): print("Downloading calibration inputs...") From afb8e1fe26e455bc13e9e1c184c57bb8127146f9 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Thu, 29 Jan 2026 13:19:40 -0500 Subject: [PATCH 8/8] chore: format code and update changelog for parallel publishing Co-Authored-By: Claude Opus 4.5 --- changelog_entry.yaml | 3 +++ modal_app/local_area.py | 8 ++++++-- modal_app/worker_script.py | 12 +++++++----- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index bfc2edd01..f501ad2f7 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -14,3 +14,6 @@ added: - Added CPS_2024_Full class for full-sample 2024 CPS generation - Added raw_cache utility for Census data caching + - Added atomic parallel local area H5 publishing with Modal Volume staging + - Added manifest validation with SHA256 checksums for versioned uploads + - Added HuggingFace retry logic with exponential backoff to fix timeout errors diff --git a/modal_app/local_area.py b/modal_app/local_area.py index c9e0624f9..b2f217a0a 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -361,8 +361,12 @@ def coordinate_publish( calibration_dir.mkdir(parents=True, exist_ok=True) # hf_hub_download preserves directory structure, so files are in calibration/ subdir - weights_path = calibration_dir / "calibration" / "w_district_calibration.npy" - dataset_path = calibration_dir / "calibration" / "stratified_extended_cps.h5" + weights_path = ( + calibration_dir / "calibration" / "w_district_calibration.npy" + ) + dataset_path = ( + calibration_dir / "calibration" / "stratified_extended_cps.h5" + ) db_path = calibration_dir / "calibration" / "policy_data.db" if not all(p.exists() for p in [weights_path, dataset_path, db_path]): diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index 95217e4cf..b197260e8 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -95,11 +95,13 @@ def main(): except Exception as e: results["failed"].append(f"{item_type}:{item_id}") - results["errors"].append({ - "item": f"{item_type}:{item_id}", - "error": str(e), - "traceback": traceback.format_exc(), - }) + results["errors"].append( + { + "item": f"{item_type}:{item_id}", + "error": str(e), + "traceback": traceback.format_exc(), + } + ) print(f"FAILED {item_type}:{item_id}: {e}", file=sys.stderr) print(json.dumps(results))
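
The manifest workflow added in PATCH 5/8 can also be exercised locally, outside Modal. The sketch below is illustrative only: the `/tmp/staging` path, the `1.56.0` version string, and the `1.55.0` prior version are placeholder assumptions rather than repository defaults, and it presumes the build workers have already written H5 files under `<staging>/<version>/{states,districts,cities}/`. The function names and return shapes come from the new `policyengine_us_data/utils/manifest.py` above.

```python
# Minimal local sketch of the manifest round-trip (assumptions noted above).
from pathlib import Path

from policyengine_us_data.utils.manifest import (
    generate_manifest,
    save_manifest,
    load_manifest,
    verify_manifest,
    create_latest_pointer,
)

staging_dir = Path("/tmp/staging")  # assumed staging root, populated by workers
version = "1.56.0"                  # assumed version string

# Hash every built H5 file and record per-category counts and total size.
manifest = generate_manifest(staging_dir, version)
save_manifest(manifest, staging_dir / version / "manifest.json")

# Re-load and verify before publishing, mirroring the atomic_upload step.
manifest = load_manifest(staging_dir / version / "manifest.json")
result = verify_manifest(staging_dir, manifest)
if not result["valid"]:
    raise RuntimeError(
        f"{len(result['missing'])} missing files, "
        f"{len(result['checksum_mismatch'])} checksum mismatches"
    )
print(f"Verified {result['verified']} of {len(manifest['files'])} files")

# Only after verification would latest.json be rewritten to point here.
latest = create_latest_pointer(
    version=version,
    previous_version="1.55.0",  # illustrative prior version
)
print(latest["current_version"], latest["manifest_url"])
```

This verify-before-publish ordering is what keeps `latest.json` from ever pointing at a partially uploaded version: the pointer is only rewritten after the manifest and all versioned files have been verified and uploaded.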