diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29b..bfd1545c 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,5 @@
+- bump: minor
+  changes:
+    added:
+      - Deduplication logic in SparseMatrixBuilder (option to remove duplicate targets or to select the most specific geographic level).
+      - Entity-aware target calculations for correct entity counts.
\ No newline at end of file
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py
index aa954aba..5d63d6c7 100644
--- a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py
+++ b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py
@@ -610,3 +610,78 @@ def calculate_spm_thresholds_for_cd(
         thresholds[i] = base * equiv_scale * geoadj
 
     return thresholds
+
+
+def build_concept_id(variable: str, constraints: List[str]) -> str:
+    """
+    Build a normalized concept ID from a variable and its constraints.
+
+    The concept ID uniquely identifies a calibration target "concept"
+    based on the variable being measured and its non-geographic constraints.
+
+    Args:
+        variable: Target variable name (e.g., "person_count", "snap")
+        constraints: List of constraint strings (e.g., ["age>=5", "age<18"])
+
+    Returns:
+        Normalized concept ID string
+
+    Examples:
+        >>> build_concept_id("person_count", ["age>=5", "age<18"])
+        'person_count_age_lt_18_age_gte_5'
+        >>> build_concept_id("snap", ["snap>0"])
+        'snap_snap_gt_0'
+        >>> build_concept_id("snap", [])
+        'snap'
+    """
+    if not constraints:
+        return variable
+
+    # Normalize and sort constraints for consistent, order-independent IDs
+    normalized = []
+    for c in sorted(constraints):
+        c_norm = (
+            c.replace(">=", "_gte_")
+            .replace("<=", "_lte_")
+            .replace(">", "_gt_")
+            .replace("<", "_lt_")
+            .replace("==", "_eq_")
+            .replace("=", "_eq_")
+            .replace(" ", "")
+        )
+        normalized.append(c_norm)
+
+    return f"{variable}_{'_'.join(normalized)}"
+
+
+def extract_constraints_from_row(
+    row: pd.Series, exclude_geo: bool = True
+) -> List[str]:
+    """
+    Extract the constraint list from a target row's constraint_info column.
+
+    Args:
+        row: DataFrame row with 'constraint_info' column containing
+            pipe-separated constraints (e.g., "age>=5|age<18|state_fips=6")
+        exclude_geo: If True, filter out geographic constraints
+            (state_fips, congressional_district_geoid, tax_unit_is_filer)
+
+    Returns:
+        List of constraint strings like ["age>=5", "age<18"]
+    """
+    if "constraint_info" not in row or pd.isna(row["constraint_info"]):
+        return []
+
+    constraints = row["constraint_info"].split("|")
+
+    if exclude_geo:
+        geo_vars = [
+            "state_fips",
+            "congressional_district_geoid",
+            "tax_unit_is_filer",
+        ]
+        constraints = [
+            c for c in constraints if not any(geo in c for geo in geo_vars)
+        ]
+
+    return constraints
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py
index ee3d3847..af58d446 100644
--- a/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py
+++ b/policyengine_us_data/datasets/cps/local_area_calibration/fit_calibration_weights.py
@@ -111,11 +111,20 @@
             "snap",
         ],
     },
+    deduplicate=True,
+    dedup_mode="within_geography",
 )
 
-print(f"Matrix shape: {X_sparse.shape}")
-print(f"Targets: {len(targets_df)}")
+# Print concept and deduplication summaries
+builder.print_concept_summary()
+builder.print_dedup_summary()
+print(f"\nMatrix shape: {X_sparse.shape}")
+print(f"Targets after deduplication: {len(targets_df)}")
+
+# ============================================================================
+# STEP 2: FILTER TO ACHIEVABLE TARGETS
+# ============================================================================
 # Filter to achievable targets (rows with non-zero data)
 row_sums = np.array(X_sparse.sum(axis=1)).flatten()
 achievable_mask = row_sums > 0
@@ -128,7 +137,7 @@
 targets_df = targets_df[achievable_mask].reset_index(drop=True)
 X_sparse = X_sparse[achievable_mask, :]
 
-print(f"Filtered matrix shape: {X_sparse.shape}")
+print(f"Final matrix shape: {X_sparse.shape}")
 
 # Extract target vector and names
 targets = targets_df["value"].values
@@ -138,14 +147,14 @@
 ]
 
 # ============================================================================
-# STEP 2: INITIALIZE WEIGHTS
+# STEP 3: INITIALIZE WEIGHTS
 # ============================================================================
 initial_weights = np.ones(X_sparse.shape[1]) * 100
 print(f"\nInitial weights shape: {initial_weights.shape}")
 print(f"Initial weights sum: {initial_weights.sum():,.0f}")
 
 # ============================================================================
-# STEP 3: CREATE MODEL
+# STEP 4: CREATE MODEL
 # ============================================================================
 print("\nCreating SparseCalibrationWeights model...")
 model = SparseCalibrationWeights(
@@ -161,7 +170,7 @@
 )
 
 # ============================================================================
-# STEP 4: TRAIN IN CHUNKS
+# STEP 5: TRAIN IN CHUNKS
 # ============================================================================
 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 calibration_log = pd.DataFrame()
@@ -204,7 +213,7 @@
     calibration_log = pd.concat([calibration_log, chunk_df], ignore_index=True)
 
 # ============================================================================
-# STEP 5: EXTRACT AND SAVE WEIGHTS
+# STEP 6: EXTRACT AND SAVE WEIGHTS
 # ============================================================================
 with torch.no_grad():
     w = model.get_weights(deterministic=True).cpu().numpy()
@@ -224,7 +233,7 @@
print(f"LOG_PATH:{log_path}") # ============================================================================ -# STEP 6: VERIFY PREDICTIONS +# STEP 7: VERIFY PREDICTIONS # ============================================================================ print("\n" + "=" * 60) print("PREDICTION VERIFICATION") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index b12629fb..c6520300 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -12,7 +12,8 @@ import numpy as np import pandas as pd from scipy import sparse -from sqlalchemy import create_engine, text +from dataclasses import dataclass +from sqlalchemy import create_engine logger = logging.getLogger(__name__) @@ -20,9 +21,21 @@ get_calculated_variables, apply_op, _get_geo_level, + build_concept_id, + extract_constraints_from_row, ) +@dataclass +class ConceptDuplicateWarning: + """Warning when multiple values exist for the same concept.""" + + concept_id: str + duplicates: List[dict] + selected: dict + reason: str + + class SparseMatrixBuilder: """Build sparse calibration matrices for geo-stacking.""" @@ -40,6 +53,12 @@ def __init__( self.dataset_path = dataset_path self._entity_rel_cache = None + # Populated after build_matrix() with deduplicate=True + self.concept_summary: Optional[pd.DataFrame] = None + self.dedup_warnings: List[ConceptDuplicateWarning] = [] + self.targets_before_dedup: Optional[pd.DataFrame] = None + self.targets_after_dedup: Optional[pd.DataFrame] = None + def _build_entity_relationship(self, sim) -> pd.DataFrame: """ Build entity relationship DataFrame mapping persons to all entity IDs. @@ -138,6 +157,117 @@ def _evaluate_constraints_entity_aware( return household_mask + def _calculate_target_values_entity_aware( + self, + state_sim, + target_variable: str, + non_geo_constraints: List[dict], + geo_mask: np.ndarray, + n_households: int, + ) -> np.ndarray: + """ + Calculate target values at household level, handling count targets. 
+ + For count targets (*_count): Count entities per household satisfying + constraints + For value targets: Sum values at household level (existing behavior) + + Args: + state_sim: Microsimulation with state_fips set + target_variable: The target variable name (e.g., "snap", + "person_count") + non_geo_constraints: List of constraint dicts (geographic + constraints should be pre-filtered) + geo_mask: Boolean mask array for geographic filtering (household + level) + n_households: Number of households + + Returns: + Float array of target values at household level + """ + is_count_target = target_variable.endswith("_count") + + if not is_count_target: + # Value target: use existing entity-aware constraint evaluation + entity_mask = self._evaluate_constraints_entity_aware( + state_sim, non_geo_constraints, n_households + ) + mask = geo_mask & entity_mask + + target_values = state_sim.calculate( + target_variable, map_to="household" + ).values + return (target_values * mask).astype(np.float32) + + # Count target: need to count entities satisfying constraints + entity_rel = self._build_entity_relationship(state_sim) + n_persons = len(entity_rel) + + # Evaluate constraints at person level (don't aggregate to HH yet) + person_mask = np.ones(n_persons, dtype=bool) + for c in non_geo_constraints: + constraint_values = state_sim.calculate( + c["variable"], map_to="person" + ).values + person_mask &= apply_op( + constraint_values, c["operation"], c["value"] + ) + + # Get target entity from variable definition + target_entity = state_sim.tax_benefit_system.variables[ + target_variable + ].entity.key + + household_ids = state_sim.calculate( + "household_id", map_to="household" + ).values + geo_mask_map = dict(zip(household_ids, geo_mask)) + + if target_entity == "household": + # household_count: 1 per qualifying household + if non_geo_constraints: + entity_mask = self._evaluate_constraints_entity_aware( + state_sim, non_geo_constraints, n_households + ) + return (geo_mask & entity_mask).astype(np.float32) + return geo_mask.astype(np.float32) + + if target_entity == "person": + # Count persons satisfying constraints per household + entity_rel["satisfies"] = person_mask + entity_rel["geo_ok"] = entity_rel["household_id"].map(geo_mask_map) + filtered = entity_rel[ + entity_rel["satisfies"] & entity_rel["geo_ok"] + ] + counts = filtered.groupby("household_id")["person_id"].nunique() + else: + # For tax_unit, spm_unit: aggregate person mask to entity, then + # count + entity_id_col = f"{target_entity}_id" + entity_rel["satisfies"] = person_mask + entity_satisfies = entity_rel.groupby(entity_id_col)[ + "satisfies" + ].any() + + entity_rel_unique = entity_rel[ + ["household_id", entity_id_col] + ].drop_duplicates() + entity_rel_unique["entity_ok"] = entity_rel_unique[ + entity_id_col + ].map(entity_satisfies) + entity_rel_unique["geo_ok"] = entity_rel_unique[ + "household_id" + ].map(geo_mask_map) + filtered = entity_rel_unique[ + entity_rel_unique["entity_ok"] & entity_rel_unique["geo_ok"] + ] + counts = filtered.groupby("household_id")[entity_id_col].nunique() + + # Build result aligned with household order + return np.array( + [counts.get(hh_id, 0) for hh_id in household_ids], dtype=np.float32 + ) + def _query_targets(self, target_filter: dict) -> pd.DataFrame: """Query targets based on filter criteria using OR logic.""" or_conditions = [] @@ -159,11 +289,10 @@ def _query_targets(self, target_filter: dict) -> pd.DataFrame: or_conditions.append(f"t.stratum_id IN ({ids})") if not or_conditions: - raise 
ValueError( - "target_filter must specify at least one filter criterion" - ) - - where_clause = " OR ".join(f"({c})" for c in or_conditions) + # No filter criteria: fetch all targets + where_clause = "1=1" + else: + where_clause = " OR ".join(f"({c})" for c in or_conditions) query = f""" SELECT t.target_id, t.stratum_id, t.variable, t.value, t.period, @@ -198,6 +327,235 @@ def _get_geographic_id(self, stratum_id: int) -> str: return c["value"] return "US" + def _get_constraint_info(self, stratum_id: int) -> str: + """Build pipe-separated constraint string for concept identification.""" + constraints = self._get_constraints(stratum_id) + parts = [] + for c in constraints: + op = "==" if c["operation"] == "=" else c["operation"] + parts.append(f"{c['variable']}{op}{c['value']}") + return "|".join(parts) if parts else None + + def _deduplicate_targets( + self, + targets_df: pd.DataFrame, + mode: str = "within_geography", + priority_column: str = "geo_priority", + ) -> pd.DataFrame: + """ + Deduplicate targets by concept before matrix building. + + Stores results in instance attributes for later inspection: + - self.concept_summary: DataFrame summarizing concepts + - self.dedup_warnings: List of ConceptDuplicateWarning + - self.targets_before_dedup: Original targets DataFrame + - self.targets_after_dedup: Deduplicated targets DataFrame + + Args: + targets_df: DataFrame with target rows including geographic_id + and constraint_info columns + mode: Deduplication mode ("within_geography" or + "hierarchical_fallback") + priority_column: Column to sort by when selecting among + duplicates. Lower values = higher priority. + + Returns: + Deduplicated DataFrame with reset index + """ + df = targets_df.copy() + + # Add geo_priority if not present (CD=1, State=2, National=3) + if priority_column not in df.columns: + df["geo_priority"] = df["geographic_id"].apply( + lambda g: 3 if g == "US" else (1 if int(g) >= 100 else 2) + ) + priority_column = "geo_priority" + + # Build concept_id for each row + df["_concept_id"] = df.apply( + lambda row: build_concept_id( + row["variable"], + extract_constraints_from_row(row, exclude_geo=True), + ), + axis=1, + ) + + # Store concept summary + self.concept_summary = df.groupby("_concept_id").agg( + count=("_concept_id", "size"), + variable=("variable", "first"), + geos=("geographic_id", lambda x: list(x.unique())), + ) + + # Store original for comparison + self.targets_before_dedup = df.copy() + + # Determine deduplication key based on mode + if mode == "within_geography": + if "geographic_id" not in df.columns: + raise ValueError( + "Mode 'within_geography' requires 'geographic_id' column" + ) + dedupe_key = ["_concept_id", "geographic_id"] + elif mode == "hierarchical_fallback": + dedupe_key = ["_concept_id"] + else: + raise ValueError( + f"Unknown mode '{mode}'. 
Use 'within_geography' or " + "'hierarchical_fallback'" + ) + + # Find and process duplicates + warnings = [] + duplicate_mask = df.duplicated(subset=dedupe_key, keep=False) + duplicates_df = df[duplicate_mask] + + if len(duplicates_df) > 0: + for key_vals, group in duplicates_df.groupby(dedupe_key): + if len(group) <= 1: + continue + + dup_list = [] + for _, dup_row in group.iterrows(): + dup_list.append( + { + "geographic_id": dup_row.get("geographic_id", "?"), + "source": dup_row.get("source_name", "?"), + "period": dup_row.get("period", "?"), + "value": dup_row.get("value", "?"), + "stratum_id": dup_row.get("stratum_id", "?"), + } + ) + + sorted_group = group.sort_values(priority_column) + selected_row = sorted_group.iloc[0] + selected = { + "geographic_id": selected_row.get("geographic_id", "?"), + "source": selected_row.get("source_name", "?"), + "period": selected_row.get("period", "?"), + "value": selected_row.get("value", "?"), + } + + concept_id = ( + key_vals if isinstance(key_vals, str) else key_vals[0] + ) + warnings.append( + ConceptDuplicateWarning( + concept_id=concept_id, + duplicates=dup_list, + selected=selected, + reason=f"Selected by lowest {priority_column}", + ) + ) + + self.dedup_warnings = warnings + + # Deduplicate: sort by key + priority, keep first per key + sort_cols = ( + dedupe_key + [priority_column] + if priority_column in df.columns + else dedupe_key + ) + df_sorted = df.sort_values(sort_cols) + df_deduped = df_sorted.drop_duplicates(subset=dedupe_key, keep="first") + + # Clean up temporary column + df_deduped = df_deduped.drop(columns=["_concept_id"]) + + self.targets_after_dedup = df_deduped.copy() + + return df_deduped.reset_index(drop=True) + + def print_concept_summary(self) -> None: + """ + Print detailed concept summary from the last build_matrix() call. + + Call this after build_matrix() to see what concepts were found. + """ + if self.concept_summary is None: + print("No concept summary available. Run build_matrix() first.") + return + + print("\n" + "=" * 60) + print("CONCEPT SUMMARY") + print("=" * 60) + + n_targets = ( + len(self.targets_before_dedup) + if self.targets_before_dedup is not None + else 0 + ) + print( + f"Found {len(self.concept_summary)} unique concepts " + f"from {n_targets} targets:\n" + ) + + for concept_id, row in self.concept_summary.iterrows(): + n_geos = len(row["geos"]) + print(f" {concept_id}") + print( + f" Variable: {row['variable']}, " + f"Targets: {row['count']}, Geographies: {n_geos}" + ) + + def print_dedup_summary(self) -> None: + """ + Print deduplication summary from the last build_matrix() call. + + Call this after build_matrix() to see what duplicates were removed. + """ + if self.targets_before_dedup is None: + print("No dedup summary available. 
Run build_matrix() first.") + return + + print("\n" + "=" * 60) + print("DEDUPLICATION SUMMARY") + print("=" * 60) + + before = len(self.targets_before_dedup) + after = ( + len(self.targets_after_dedup) + if self.targets_after_dedup is not None + else 0 + ) + removed = before - after + + print(f"Total targets queried: {before}") + print(f"Targets after deduplication: {after}") + print(f"Duplicates removed: {removed}") + + if self.dedup_warnings: + print(f"\nDuplicate groups resolved ({len(self.dedup_warnings)}):") + for w in self.dedup_warnings: + print(f"\n Concept: {w.concept_id}") + sel_val = w.selected["value"] + sel_val_str = ( + f"{sel_val:,.0f}" + if isinstance(sel_val, (int, float)) + else str(sel_val) + ) + print( + f" Selected: geo={w.selected['geographic_id']}, " + f"value={sel_val_str}" + ) + print(f" Removed ({len(w.duplicates) - 1}):") + for dup in w.duplicates: + if ( + dup["value"] != w.selected["value"] + or dup["geographic_id"] != w.selected["geographic_id"] + ): + dup_val = dup["value"] + dup_val_str = ( + f"{dup_val:,.0f}" + if isinstance(dup_val, (int, float)) + else str(dup_val) + ) + print( + f" - geo={dup['geographic_id']}, " + f"value={dup_val_str}, " + f"source={dup.get('source', '?')}" + ) + def _create_state_sim(self, state: int, n_households: int): """Create a fresh simulation with state_fips set to given state.""" from policyengine_us import Microsimulation @@ -213,19 +571,34 @@ def _create_state_sim(self, state: int, n_households: int): return state_sim def build_matrix( - self, sim, target_filter: dict + self, + sim, + target_filter: dict, + deduplicate: bool = True, + dedup_mode: str = "within_geography", ) -> Tuple[pd.DataFrame, sparse.csr_matrix, Dict[str, List[str]]]: """ Build sparse calibration matrix. Args: - sim: Microsimulation instance (used for household_ids, or as template) + sim: Microsimulation instance (used for household_ids, or + as template) target_filter: Dict specifying which targets to include - {"stratum_group_ids": [4]} for SNAP targets - {"target_ids": [123, 456]} for specific targets + - an empty dict {} will fetch all targets + deduplicate: If True, deduplicate targets by concept before + building the matrix (default True) + dedup_mode: Deduplication mode - "within_geography" (default) + removes duplicates with same concept AND geography, or + "hierarchical_fallback" keeps most specific geography + per concept Returns: Tuple of (targets_df, X_sparse, household_id_mapping) + + After calling this method, you can use print_concept_summary() and + print_dedup_summary() to see details about concepts and deduplication. 
""" household_ids = sim.calculate( "household_id", map_to="household" @@ -235,16 +608,24 @@ def build_matrix( n_cols = n_households * n_cds targets_df = self._query_targets(target_filter) - n_targets = len(targets_df) - if n_targets == 0: + if len(targets_df) == 0: raise ValueError("No targets found matching filter") targets_df["geographic_id"] = targets_df["stratum_id"].apply( self._get_geographic_id ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + self._get_constraint_info + ) + + # Deduplicate targets by concept before building matrix + if deduplicate: + targets_df = self._deduplicate_targets(targets_df, mode=dedup_mode) + + n_targets = len(targets_df) - # Sort by (geo_level, variable, geographic_id) for contiguous group rows + # Sort by (geo_level, variable, geographic_id) for contiguous group targets_df["_geo_level"] = targets_df["geographic_id"].apply( _get_geo_level ) @@ -316,22 +697,20 @@ def build_matrix( if not geo_mask.any(): continue - # Evaluate non-geographic constraints at entity level - entity_mask = self._evaluate_constraints_entity_aware( - state_sim, non_geo_constraints, n_households + # Calculate target values with entity-aware handling + # This properly handles count targets (*_count) by counting + # entities rather than summing values + masked_values = self._calculate_target_values_entity_aware( + state_sim, + target["variable"], + non_geo_constraints, + geo_mask, + n_households, ) - # Combine geographic and entity-aware masks - mask = geo_mask & entity_mask - - if not mask.any(): + if not masked_values.any(): continue - target_values = state_sim.calculate( - target["variable"], map_to="household" - ).values - masked_values = (target_values * mask).astype(np.float32) - nonzero = np.where(masked_values != 0)[0] if len(nonzero) > 0: X[row_idx, col_start + nonzero] = masked_values[ diff --git a/policyengine_us_data/tests/test_local_area_calibration/conftest.py b/policyengine_us_data/tests/test_local_area_calibration/conftest.py index 7abcbafb..70df400a 100644 --- a/policyengine_us_data/tests/test_local_area_calibration/conftest.py +++ b/policyengine_us_data/tests/test_local_area_calibration/conftest.py @@ -23,6 +23,11 @@ # Format: (variable_name, rtol) # variable_name as per the targets in policy_data.db # rtol is relative tolerance for comparison +# +# NOTE: Count targets (person_count, tax_unit_count) are excluded because +# they have constraints (e.g., age>=5|age<18) that make the X_sparse values +# different from raw sim.calculate() values. Count targets are tested +# separately in test_count_targets.py with controlled mock data. VARIABLES_TO_TEST = [ ("snap", 1e-2), ("income_tax", 1e-2), diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py b/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py new file mode 100644 index 00000000..57fe510e --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_concept_deduplication.py @@ -0,0 +1,439 @@ +""" +Tests for concept ID building, constraint extraction, and deduplication. + +These tests verify that: +1. Concept IDs are built consistently from variable + non-geo constraints +2. Constraints are correctly extracted from DataFrame rows +3. 
Deduplication correctly identifies and removes duplicates via the builder +""" + +import unittest +import tempfile +import os +import pandas as pd +from sqlalchemy import create_engine, text + +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + build_concept_id, + extract_constraints_from_row, +) +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +class TestBuildConceptId(unittest.TestCase): + """Test concept ID building from variable + constraints.""" + + def test_variable_only(self): + """Test concept ID with no constraints.""" + result = build_concept_id("snap", []) + self.assertEqual(result, "snap") + + def test_single_constraint(self): + """Test concept ID with single constraint.""" + result = build_concept_id("snap", ["snap>0"]) + self.assertEqual(result, "snap_snap_gt_0") + + def test_multiple_constraints_sorted(self): + """Test that constraints are sorted for consistency.""" + # Order shouldn't matter - result should be the same + result1 = build_concept_id("person_count", ["age>=5", "age<18"]) + result2 = build_concept_id("person_count", ["age<18", "age>=5"]) + self.assertEqual(result1, result2) + self.assertEqual(result1, "person_count_age_lt_18_age_gte_5") + + def test_operator_normalization(self): + """Test that operators are normalized correctly.""" + self.assertIn("_gte_", build_concept_id("x", ["a>=1"])) + self.assertIn("_lte_", build_concept_id("x", ["a<=1"])) + self.assertIn("_gt_", build_concept_id("x", ["a>1"])) + self.assertIn("_lt_", build_concept_id("x", ["a<1"])) + self.assertIn("_eq_", build_concept_id("x", ["a==1"])) + self.assertIn("_eq_", build_concept_id("x", ["a=1"])) + + def test_spaces_removed(self): + """Test that spaces are removed from constraints.""" + result = build_concept_id("x", ["age >= 5"]) + self.assertNotIn(" ", result) + + +class TestExtractConstraints(unittest.TestCase): + """Test constraint extraction from DataFrame rows.""" + + def test_no_constraint_info(self): + """Test row without constraint_info column.""" + row = pd.Series({"variable": "snap", "value": 1000}) + result = extract_constraints_from_row(row) + self.assertEqual(result, []) + + def test_null_constraint_info(self): + """Test row with null constraint_info.""" + row = pd.Series( + {"variable": "snap", "constraint_info": None, "value": 1000} + ) + result = extract_constraints_from_row(row) + self.assertEqual(result, []) + + def test_single_constraint(self): + """Test row with single constraint.""" + row = pd.Series( + {"variable": "snap", "constraint_info": "snap>0", "value": 1000} + ) + result = extract_constraints_from_row(row) + self.assertEqual(result, ["snap>0"]) + + def test_multiple_constraints(self): + """Test row with pipe-separated constraints.""" + row = pd.Series( + { + "variable": "person_count", + "constraint_info": "age>=5|age<18", + "value": 1000, + } + ) + result = extract_constraints_from_row(row) + self.assertEqual(result, ["age>=5", "age<18"]) + + def test_exclude_geo_constraints(self): + """Test that geographic constraints are excluded by default.""" + row = pd.Series( + { + "variable": "person_count", + "constraint_info": "age>=5|state_fips=6|age<18", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=True) + self.assertEqual(result, ["age>=5", "age<18"]) + self.assertNotIn("state_fips=6", result) + + def test_include_geo_constraints(self): + """Test that geographic constraints can be included.""" + row = pd.Series( + { 
+ "variable": "person_count", + "constraint_info": "age>=5|state_fips=6", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=False) + self.assertIn("state_fips=6", result) + + def test_exclude_cd_geoid(self): + """Test that CD geoid constraints are excluded.""" + row = pd.Series( + { + "variable": "snap", + "constraint_info": "snap>0|congressional_district_geoid=601", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=True) + self.assertEqual(result, ["snap>0"]) + + def test_exclude_filer_constraint(self): + """Test that tax_unit_is_filer constraint is excluded.""" + row = pd.Series( + { + "variable": "income_tax", + "constraint_info": "tax_unit_is_filer=True|income>0", + "value": 1000, + } + ) + result = extract_constraints_from_row(row, exclude_geo=True) + self.assertEqual(result, ["income>0"]) + + +class TestBuilderDeduplication(unittest.TestCase): + """Test deduplication logic through SparseMatrixBuilder.""" + + @classmethod + def setUpClass(cls): + """Create a temporary database with test data.""" + cls.temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + cls.db_path = cls.temp_db.name + cls.temp_db.close() + + cls.db_uri = f"sqlite:///{cls.db_path}" + engine = create_engine(cls.db_uri) + + # Create schema + with engine.connect() as conn: + conn.execute(text(""" + CREATE TABLE stratum_groups ( + stratum_group_id INTEGER PRIMARY KEY, + name TEXT + ) + """)) + conn.execute(text(""" + CREATE TABLE strata ( + stratum_id INTEGER PRIMARY KEY, + stratum_group_id INTEGER + ) + """)) + conn.execute(text(""" + CREATE TABLE stratum_constraints ( + constraint_id INTEGER PRIMARY KEY, + stratum_id INTEGER, + constraint_variable TEXT, + operation TEXT, + value TEXT + ) + """)) + conn.execute(text(""" + CREATE TABLE targets ( + target_id INTEGER PRIMARY KEY, + stratum_id INTEGER, + variable TEXT, + value REAL, + period INTEGER + ) + """)) + conn.commit() + + @classmethod + def tearDownClass(cls): + """Remove temporary database.""" + os.unlink(cls.db_path) + + def setUp(self): + """Clear tables before each test.""" + engine = create_engine(self.db_uri) + with engine.connect() as conn: + conn.execute(text("DELETE FROM targets")) + conn.execute(text("DELETE FROM stratum_constraints")) + conn.execute(text("DELETE FROM strata")) + conn.execute(text("DELETE FROM stratum_groups")) + conn.commit() + + def _insert_test_data(self, strata, constraints, targets): + """Helper to insert test data into database.""" + engine = create_engine(self.db_uri) + with engine.connect() as conn: + # Insert stratum groups + conn.execute( + text("INSERT OR IGNORE INTO stratum_groups VALUES (1, 'test')") + ) + + # Insert strata + for stratum_id, group_id in strata: + conn.execute( + text("INSERT INTO strata VALUES (:sid, :gid)"), + {"sid": stratum_id, "gid": group_id}, + ) + + # Insert constraints + for i, (stratum_id, var, op, val) in enumerate(constraints): + conn.execute( + text(""" + INSERT INTO stratum_constraints + VALUES (:cid, :sid, :var, :op, :val) + """), + { + "cid": i + 1, + "sid": stratum_id, + "var": var, + "op": op, + "val": val, + }, + ) + + # Insert targets + for i, (stratum_id, variable, value, period) in enumerate(targets): + conn.execute( + text(""" + INSERT INTO targets + VALUES (:tid, :sid, :var, :val, :period) + """), + { + "tid": i + 1, + "sid": stratum_id, + "var": variable, + "val": value, + "period": period, + }, + ) + + conn.commit() + + def test_no_duplicates_preserved(self): + """Test that targets with different concepts 
are all preserved.""" + # Two different variables for the same CD - should NOT deduplicate + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2023), + (2, "medicaid", 2000, 2023), + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + # Call _deduplicate_targets directly with prepared DataFrame + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 2) + self.assertEqual(len(builder.dedup_warnings), 0) + + def test_duplicate_same_geo_deduplicated(self): + """Test that same concept at same geography is deduplicated.""" + # Same variable, same CD, different periods - should deduplicate + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "601"), + ], + targets=[ + (1, "snap", 1000, 2023), + (2, "snap", 1100, 2022), # Same concept, same geo + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 1) + self.assertEqual(len(builder.dedup_warnings), 1) + + def test_same_concept_different_geos_preserved(self): + """Test that same concept at different geos is NOT deduplicated.""" + # Same variable, different CDs - should NOT deduplicate + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "congressional_district_geoid", "=", "602"), + ], + targets=[ + (1, "snap", 1000, 2023), + (2, "snap", 1100, 2023), + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601", "602"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 2) # Both kept + self.assertEqual(len(builder.dedup_warnings), 0) + + def test_different_constraints_different_concepts(self): + """Test that different constraints create different concepts.""" + # Same variable but different age constraints - different concepts + self._insert_test_data( + strata=[(1, 1), (2, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (1, "age", ">=", "5"), + (1, "age", "<", "18"), + (2, "congressional_district_geoid", "=", "601"), + (2, "age", ">=", "18"), + (2, "age", "<", "65"), + ], + targets=[ + (1, "person_count", 1000, 2023), + (2, "person_count", 2000, 2023), + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + 
targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets(targets_df) + + self.assertEqual(len(result), 2) # Different concepts + self.assertEqual(len(builder.dedup_warnings), 0) + + def test_hierarchical_fallback_keeps_most_specific(self): + """Test hierarchical fallback mode keeps CD over state over national.""" + # Same concept at CD, state, and national levels + self._insert_test_data( + strata=[(1, 1), (2, 1), (3, 1)], + constraints=[ + (1, "congressional_district_geoid", "=", "601"), + (2, "state_fips", "=", "6"), + # stratum 3 has no geo constraint = national + ], + targets=[ + (1, "snap", 1200000, 2023), # CD level + (2, "snap", 15000000, 2023), # State level + (3, "snap", 110000000000, 2023), # National level + ], + ) + + builder = SparseMatrixBuilder( + db_uri=self.db_uri, + time_period=2023, + cds_to_calibrate=["601"], + ) + + targets_df = builder._query_targets({"stratum_group_ids": [1]}) + targets_df["geographic_id"] = targets_df["stratum_id"].apply( + builder._get_geographic_id + ) + targets_df["constraint_info"] = targets_df["stratum_id"].apply( + builder._get_constraint_info + ) + + result = builder._deduplicate_targets( + targets_df, mode="hierarchical_fallback" + ) + + self.assertEqual(len(result), 1) + # CD level should be kept (geo_priority=1) + self.assertEqual(result.iloc[0]["geographic_id"], "601") + self.assertEqual(result.iloc[0]["value"], 1200000) diff --git a/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py b/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py new file mode 100644 index 00000000..46eae4eb --- /dev/null +++ b/policyengine_us_data/tests/test_local_area_calibration/test_count_targets.py @@ -0,0 +1,415 @@ +""" +Tests for count target handling in SparseMatrixBuilder. + +These tests verify that count targets (e.g., person_count, tax_unit_count) +are correctly handled by counting entities that satisfy constraints, rather +than summing values. 
+""" + +import pytest +import numpy as np +from dataclasses import dataclass + +from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import ( + SparseMatrixBuilder, +) + + +@dataclass +class MockEntity: + """Mock entity with a key attribute.""" + + key: str + + +@dataclass +class MockVariable: + """Mock variable with entity information.""" + + entity: MockEntity + + @classmethod + def create(cls, entity_key: str) -> "MockVariable": + return cls(entity=MockEntity(key=entity_key)) + + +class MockTaxBenefitSystem: + """Mock tax benefit system with variable definitions.""" + + def __init__(self): + self.variables = { + "person_count": MockVariable.create("person"), + "tax_unit_count": MockVariable.create("tax_unit"), + "household_count": MockVariable.create("household"), + "spm_unit_count": MockVariable.create("spm_unit"), + "snap": MockVariable.create("spm_unit"), + } + + +@dataclass +class MockCalculationResult: + """Mock result from simulation.calculate().""" + + values: np.ndarray + + +class MockSimulation: + """Mock simulation for testing count target calculations.""" + + def __init__(self, entity_data: dict, variable_values: dict): + """ + Args: + entity_data: Dict with person_id, household_id, tax_unit_id, + spm_unit_id arrays (all at person level) + variable_values: Dict mapping variable names to their values + at the appropriate entity level + """ + self.entity_data = entity_data + self.variable_values = variable_values + self.tax_benefit_system = MockTaxBenefitSystem() + + def calculate(self, variable: str, map_to: str = None): + """Return mock calculation result.""" + if variable in self.entity_data: + # Entity ID variables + if map_to == "person": + values = np.array(self.entity_data[variable]) + elif map_to == "household": + # Return unique household IDs + values = np.array( + sorted(set(self.entity_data["household_id"])) + ) + else: + values = np.array(self.entity_data[variable]) + elif variable in self.variable_values: + # Regular variables - return at requested level + val_data = self.variable_values[variable] + if map_to == "person": + values = np.array(val_data["person"]) + elif map_to == "household": + values = np.array(val_data["household"]) + else: + values = np.array(val_data.get("default", [])) + else: + values = np.array([]) + + return MockCalculationResult(values=values) + + +@pytest.fixture +def basic_entity_data(): + """ + Create mock entity relationships with known household compositions. 
+ + Household 1 (id=100): 3 people (ages 5, 12, 40) -> 2 aged 5-17 + Household 2 (id=200): 2 people (ages 3, 25) -> 0 aged 5-17 + Household 3 (id=300): 4 people (ages 6, 8, 10, 45) -> 3 aged 5-17 + """ + return { + "person_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "household_id": [100, 100, 100, 200, 200, 300, 300, 300, 300], + "tax_unit_id": [10, 10, 10, 20, 20, 30, 30, 30, 30], + "spm_unit_id": [ + 1000, + 1000, + 1000, + 2000, + 2000, + 3000, + 3000, + 3000, + 3000, + ], + } + + +@pytest.fixture +def basic_variable_values(): + """Variable values for basic household composition tests.""" + return { + "age": { + "person": [5, 12, 40, 3, 25, 6, 8, 10, 45], + "household": [40, 25, 45], # Not used for age constraints + }, + "person_count": { + "person": [1, 1, 1, 1, 1, 1, 1, 1, 1], + "household": [3, 2, 4], # Sum per household + }, + "snap": { + "person": [100, 100, 100, 0, 0, 200, 200, 200, 200], + "household": [300, 0, 800], + }, + } + + +@pytest.fixture +def basic_sim(basic_entity_data, basic_variable_values): + """Mock simulation with basic household compositions.""" + return MockSimulation(basic_entity_data, basic_variable_values) + + +@pytest.fixture +def builder(): + """Create a minimal SparseMatrixBuilder (won't use DB for unit tests).""" + return SparseMatrixBuilder( + db_uri="sqlite:///:memory:", + time_period=2023, + cds_to_calibrate=["101"], + ) + + +# Tests for basic count target calculation +class TestCountTargetCalculation: + """Test _calculate_target_values_entity_aware for count targets.""" + + def test_person_count_with_age_constraints(self, builder, basic_sim): + """Test person_count correctly counts persons in age range per HH.""" + # Constraints: age >= 5 AND age < 18 + constraints = [ + {"variable": "age", "operation": ">=", "value": 5}, + {"variable": "age", "operation": "<", "value": 18}, + ] + + geo_mask = np.array([True, True, True]) # All households included + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 2 people (ages 5, 12), HH2 has 0, HH3 has 3 (6,8,10) + expected = np.array([2, 0, 3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_no_constraints(self, builder, basic_sim): + """Test person_count without constraints returns all persons per HH.""" + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 3 people, HH2 has 2, HH3 has 4 + expected = np.array([3, 2, 4], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_with_geo_mask(self, builder, basic_sim): + """Test person_count respects geographic mask.""" + constraints = [ + {"variable": "age", "operation": ">=", "value": 5}, + {"variable": "age", "operation": "<", "value": 18}, + ] + + # Only include households 1 and 3 + geo_mask = np.array([True, False, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1=2, HH2=0 (masked out), HH3=3 + expected = np.array([2, 0, 3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_value_target_uses_sum(self, builder, basic_sim): + """Test that non-count targets sum values (existing behavior).""" + # SNAP is a value 
target, not a count target + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "snap", + constraints, + geo_mask, + n_households, + ) + + # Expected: Sum of snap values per household + expected = np.array([300, 0, 800], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_household_count_no_constraints(self, builder, basic_sim): + """Test household_count returns 1 for each qualifying household.""" + constraints = [] + geo_mask = np.array([True, True, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "household_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: 1 for each household in geo_mask + expected = np.array([1, 1, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_household_count_with_geo_mask(self, builder, basic_sim): + """Test household_count respects geographic mask.""" + constraints = [] + geo_mask = np.array([True, False, True]) + n_households = 3 + + result = builder._calculate_target_values_entity_aware( + basic_sim, + "household_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: 1 for HH1, 0 for HH2 (masked), 1 for HH3 + expected = np.array([1, 0, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + +# Fixtures for complex entity relationship tests +@pytest.fixture +def complex_entity_data(): + """ + Create entity data with multiple tax units per household. + + Household 1 (id=100): 4 people in 2 tax units + Tax unit 10: person 1 (age 30, filer), person 2 (age 28) + Tax unit 11: person 3 (age 65, filer), person 4 (age 62) + Household 2 (id=200): 2 people in 1 tax unit + Tax unit 20: person 5 (age 45, filer), person 6 (age 16) + """ + return { + "person_id": [1, 2, 3, 4, 5, 6], + "household_id": [100, 100, 100, 100, 200, 200], + "tax_unit_id": [10, 10, 11, 11, 20, 20], + "spm_unit_id": [1000, 1000, 1000, 1000, 2000, 2000], + } + + +@pytest.fixture +def complex_variable_values(): + """Variable values for complex entity relationship tests.""" + return { + "age": { + "person": [30, 28, 65, 62, 45, 16], + "household": [65, 45], + }, + "is_tax_unit_head": { + "person": [True, False, True, False, True, False], + "household": [2, 1], # count of heads per HH + }, + "tax_unit_count": { + "person": [1, 1, 1, 1, 1, 1], + "household": [2, 1], + }, + "person_count": { + "person": [1, 1, 1, 1, 1, 1], + "household": [4, 2], + }, + } + + +@pytest.fixture +def complex_sim(complex_entity_data, complex_variable_values): + """Mock simulation with complex entity relationships.""" + return MockSimulation(complex_entity_data, complex_variable_values) + + +# Tests for complex entity relationships +class TestCountTargetWithRealEntities: + """Test count targets with more complex entity relationships.""" + + def test_tax_unit_count_no_constraints(self, builder, complex_sim): + """Test tax_unit_count counts all tax units per household.""" + constraints = [] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "tax_unit_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 2 tax units, HH2 has 1 + expected = np.array([2, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_tax_unit_count_with_age_constraint(self, builder, complex_sim): + """Test tax_unit_count with age constraint on members.""" + 
# Count tax units that have at least one person aged >= 65 + constraints = [ + {"variable": "age", "operation": ">=", "value": 65}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "tax_unit_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 1 tax unit (TU 11) with person >=65, HH2 has 0 + expected = np.array([1, 0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_seniors(self, builder, complex_sim): + """Test person_count for seniors (age >= 65).""" + constraints = [ + {"variable": "age", "operation": ">=", "value": 65}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 1 senior (age 65), HH2 has 0 + expected = np.array([1, 0], dtype=np.float32) + np.testing.assert_array_equal(result, expected) + + def test_person_count_children(self, builder, complex_sim): + """Test person_count for children (age < 18).""" + constraints = [ + {"variable": "age", "operation": "<", "value": 18}, + ] + geo_mask = np.array([True, True]) + n_households = 2 + + result = builder._calculate_target_values_entity_aware( + complex_sim, + "person_count", + constraints, + geo_mask, + n_households, + ) + + # Expected: HH1 has 0 children, HH2 has 1 (age 16) + expected = np.array([0, 1], dtype=np.float32) + np.testing.assert_array_equal(result, expected)
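Illustrative usage of the new concept-ID helpers (editor's note: not part of the diff). This is a minimal sketch using only the functions added in calibration_utils.py above; the row values are made up for illustration.

import pandas as pd

from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
    build_concept_id,
    extract_constraints_from_row,
)

# A target row shaped like the ones SparseMatrixBuilder prepares via
# _get_constraint_info(): non-geo constraints plus a CD geo constraint.
# The value 12345 is an arbitrary illustrative number.
row = pd.Series(
    {
        "variable": "person_count",
        "constraint_info": "age>=5|age<18|congressional_district_geoid=601",
        "value": 12345,
    }
)

# Geographic constraints are stripped, so the same concept matches across
# CDs, states, and the national level.
constraints = extract_constraints_from_row(row, exclude_geo=True)
assert constraints == ["age>=5", "age<18"]

# Constraints are sorted and operators normalized, so constraint order
# does not change the concept ID.
assert build_concept_id("person_count", constraints) == (
    "person_count_age_lt_18_age_gte_5"
)
assert build_concept_id("person_count", ["age<18", "age>=5"]) == (
    "person_count_age_lt_18_age_gte_5"
)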
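A hedged sketch of the end-to-end deduplication workflow this PR enables. The constructor arguments mirror the tests above; the database URI, CD list, stratum group ID, and the bare Microsimulation() call are placeholders, not values taken from this PR.

from policyengine_us import Microsimulation
from policyengine_us_data.datasets.cps.local_area_calibration.sparse_matrix_builder import (
    SparseMatrixBuilder,
)

builder = SparseMatrixBuilder(
    db_uri="sqlite:///policy_data.db",  # placeholder database path
    time_period=2023,
    cds_to_calibrate=["601", "602"],  # placeholder CD list
)

sim = Microsimulation()  # placeholder; the fit script constructs its own sim

# Deduplicate by concept within each geography (the default), or pass
# dedup_mode="hierarchical_fallback" to keep only the most specific
# geography per concept.
targets_df, X_sparse, household_id_mapping = builder.build_matrix(
    sim,
    target_filter={"stratum_group_ids": [1]},  # placeholder filter
    deduplicate=True,
    dedup_mode="within_geography",
)

# Inspect what deduplication did before fitting weights.
builder.print_concept_summary()
builder.print_dedup_summary()
for warning in builder.dedup_warnings:
    print(warning.concept_id, warning.reason)

The two modes differ in scope: "within_geography" keeps one value per concept per geography, so a concept can still be targeted at the CD, state, and national levels simultaneously, while "hierarchical_fallback" collapses those rows to the single most specific geography (lowest geo_priority) per concept.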