Merged
6 changes: 4 additions & 2 deletions CLAUDE.md
@@ -96,10 +96,12 @@ pytest tests/test_rust_backend.py -v
- `bacon_decompose()` - Convenience function for quick decomposition
- Integrated with `TwoWayFixedEffects.decompose()` method

-- **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0):
+- **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0+):
- `solve_ols()` - OLS solver using scipy's gelsy LAPACK driver (QR-based, faster than SVD)
- `compute_robust_vcov()` - Vectorized HC1 and cluster-robust variance-covariance estimation
- `compute_r_squared()` - R-squared and adjusted R-squared computation
+- `LinearRegression` - High-level OLS helper class with unified coefficient extraction and inference
+- `InferenceResult` - Dataclass container for coefficient-level inference (SE, t-stat, p-value, CI)
- Single optimization point for all estimators (reduces code duplication)
- Cluster-robust SEs use pandas groupby instead of O(n × clusters) loop

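For orientation, the variance estimators named in the `linalg.py` bullets can be sketched in plain NumPy. This is a minimal sketch under stated assumptions, not the library's code: `sketch_ols_vcov` is a hypothetical name, `np.linalg.lstsq` stands in for the gelsy-based solver, `np.add.at` stands in for the pandas groupby aggregation, and the cluster small-sample correction shown is an assumed Stata-style factor.

```python
import numpy as np


def sketch_ols_vcov(X, y, robust=True, cluster_ids=None):
    """Illustrative OLS fit returning (coefficients, residuals, vcov).

    Hypothetical helper; the name, signature, and small-sample corrections
    are assumptions, not the diff_diff implementation.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n, k = X.shape

    # Least-squares fit. np.linalg.lstsq is SVD-based; the library is
    # described as using scipy's gelsy driver instead.
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ beta
    XtX = X.T @ X

    if cluster_ids is not None:
        # Sum the score rows x_i * e_i within each cluster, without a
        # Python loop over clusters.
        _, codes = np.unique(np.asarray(cluster_ids).ravel(), return_inverse=True)
        n_clusters = int(codes.max()) + 1
        scores = np.zeros((n_clusters, k))
        np.add.at(scores, codes, X * resid[:, None])
        meat = scores.T @ scores
        # Assumed correction: G/(G-1) * (n-1)/(n-k).
        scale = n_clusters / (n_clusters - 1) * (n - 1) / (n - k)
    elif robust:
        # HC1: X' diag(e^2) X with an n/(n-k) correction.
        meat = (X * resid[:, None] ** 2).T @ X
        scale = n / (n - k)
    else:
        # Classical: sigma^2 (X'X)^{-1}, using solve() rather than inv().
        mse = resid @ resid / (n - k)
        return beta, resid, np.linalg.solve(XtX, mse * np.eye(k))

    # Sandwich (X'X)^{-1} meat (X'X)^{-1}, computed with two solves.
    half = np.linalg.solve(XtX, meat)
    vcov = scale * np.linalg.solve(XtX, half.T)
    return beta, resid, vcov
```

With one observation per cluster, the assumed cluster correction reduces to `n/(n-k)`, so the cluster-robust matrix in this sketch coincides with HC1. That makes a convenient sanity check.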
@@ -270,7 +272,7 @@ Tests mirror the source modules:
- `tests/test_sun_abraham.py` - Tests for SunAbraham interaction-weighted estimator
- `tests/test_triple_diff.py` - Tests for Triple Difference (DDD) estimator
- `tests/test_bacon.py` - Tests for Goodman-Bacon decomposition
-- `tests/test_linalg.py` - Tests for unified OLS backend and robust variance estimation
+- `tests/test_linalg.py` - Tests for unified OLS backend, robust variance estimation, LinearRegression helper, and InferenceResult
- `tests/test_utils.py` - Tests for parallel trends, robust SE, synthetic weights
- `tests/test_diagnostics.py` - Tests for placebo tests
- `tests/test_wild_bootstrap.py` - Tests for wild cluster bootstrap
2 changes: 1 addition & 1 deletion ROADMAP.md
@@ -224,7 +224,7 @@ Ongoing maintenance and developer experience.
### Code Quality

- Extract shared within-transformation logic to utils
-- Consolidate linear regression helpers
+- ~~Consolidate linear regression helpers~~ ✓ Done (v2.1): Added `LinearRegression` helper class and `InferenceResult` dataclass in `linalg.py`. All major estimators (DifferenceInDifferences, TwoWayFixedEffects, SunAbraham, TripleDifference) now use the unified helper for coefficient extraction and inference.
- Consider splitting `staggered.py` (1800+ lines)

### Documentation
7 changes: 7 additions & 0 deletions diff_diff/__init__.py
@@ -30,6 +30,10 @@
run_all_placebo_tests,
run_placebo_test,
)
+from diff_diff.linalg import (
+InferenceResult,
+LinearRegression,
+)
from diff_diff.estimators import (
DifferenceInDifferences,
MultiPeriodDiD,
@@ -199,4 +203,7 @@
"plot_pretrends_power",
# Rust backend
"HAS_RUST_BACKEND",
+# Linear algebra helpers
+"LinearRegression",
+"InferenceResult",
]
70 changes: 32 additions & 38 deletions diff_diff/estimators.py
@@ -17,7 +17,12 @@
import numpy as np
import pandas as pd

-from diff_diff.linalg import compute_r_squared, compute_robust_vcov, solve_ols
+from diff_diff.linalg import (
+LinearRegression,
+compute_r_squared,
+compute_robust_vcov,
+solve_ols,
+)
from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
from diff_diff.utils import (
WildBootstrapResults,
@@ -262,56 +267,45 @@ def fit(
X = np.column_stack([X, dummies[col].values.astype(float)])
var_names.append(col)

-# Fit OLS using unified backend
-coefficients, residuals, fitted, vcov = solve_ols(
-    X, y, return_fitted=True, return_vcov=False
-)
-r_squared = compute_r_squared(y, residuals)
-
-# Extract ATT (coefficient on interaction term)
+# Extract ATT index (coefficient on interaction term)
 att_idx = 3  # Index of interaction term
 att_var_name = f"{treatment}:{time}"
 assert var_names[att_idx] == att_var_name, (
     f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, "
     f"but found '{var_names[att_idx]}'"
 )
-att = coefficients[att_idx]

-# Compute degrees of freedom (used for analytical inference)
-df = len(y) - X.shape[1] - n_absorbed_effects
+# Always use LinearRegression for initial fit (unified code path)
+# For wild bootstrap, we don't need cluster SEs from the initial fit
+cluster_ids = data[self.cluster].values if self.cluster is not None else None
+reg = LinearRegression(
+    include_intercept=False,  # Intercept already in X
+    robust=self.robust,
+    cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None,
+    alpha=self.alpha,
+).fit(X, y, df_adjustment=n_absorbed_effects)
+
+coefficients = reg.coefficients_
+residuals = reg.residuals_
+fitted = reg.fitted_values_
+att = coefficients[att_idx]

-# Compute standard errors and inference
+# Get inference - either from bootstrap or analytical
 if self.inference == "wild_bootstrap" and self.cluster is not None:
-    # Wild cluster bootstrap for few-cluster inference
-    cluster_ids = data[self.cluster].values
+    # Override with wild cluster bootstrap inference
     se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
         X, y, residuals, cluster_ids, att_idx
     )
-elif self.cluster is not None:
-    cluster_ids = data[self.cluster].values
-    vcov = compute_robust_vcov(X, residuals, cluster_ids)
-    se = np.sqrt(vcov[att_idx, att_idx])
-    t_stat = att / se
-    p_value = compute_p_value(t_stat, df=df)
-    conf_int = compute_confidence_interval(att, se, self.alpha, df=df)
-elif self.robust:
-    vcov = compute_robust_vcov(X, residuals)
-    se = np.sqrt(vcov[att_idx, att_idx])
-    t_stat = att / se
-    p_value = compute_p_value(t_stat, df=df)
-    conf_int = compute_confidence_interval(att, se, self.alpha, df=df)
 else:
-    # Classical OLS standard errors
-    n = len(y)
-    k = X.shape[1]
-    mse = np.sum(residuals**2) / (n - k)
-    # Use solve() instead of inv() for numerical stability
-    # solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse
-    vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
-    se = np.sqrt(vcov[att_idx, att_idx])
-    t_stat = att / se
-    p_value = compute_p_value(t_stat, df=df)
-    conf_int = compute_confidence_interval(att, se, self.alpha, df=df)
+    # Use analytical inference from LinearRegression
+    vcov = reg.vcov_
+    inference = reg.get_inference(att_idx)
+    se = inference.se
+    t_stat = inference.t_stat
+    p_value = inference.p_value
+    conf_int = inference.conf_int

+r_squared = compute_r_squared(y, residuals)
+
# Count observations
n_treated = int(np.sum(d))
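
The fields read from `reg.get_inference(att_idx)` in this hunk (`se`, `t_stat`, `p_value`, `conf_int`) are the same quantities the removed branches computed by hand. The following is a minimal sketch of that bundle, assuming a two-sided t-test with `df = n - k - n_absorbed_effects` as in the removed code. `CoefInference` and `coef_inference` are hypothetical names used only for illustration; this is not the library's `InferenceResult`.

```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np
from scipy import stats


@dataclass
class CoefInference:
    """Hypothetical container mirroring the fields read in the diff."""

    coefficient: float
    se: float
    t_stat: float
    p_value: float
    conf_int: Tuple[float, float]


def coef_inference(coefficients, vcov, index, df, alpha=0.05):
    """Return inference for one coefficient from a fitted model.

    Assumes a two-sided t-test with `df` degrees of freedom; whether the
    library uses exactly this convention is an assumption.
    """
    coef = float(coefficients[index])
    se = float(np.sqrt(vcov[index, index]))
    t_stat = coef / se
    # Two-sided p-value from the t distribution's survival function.
    p_value = float(2 * stats.t.sf(abs(t_stat), df))
    # Symmetric (1 - alpha) confidence interval.
    crit = float(stats.t.ppf(1 - alpha / 2, df))
    return CoefInference(coef, se, t_stat, p_value, (coef - crit * se, coef + crit * se))
```

In the context of `fit()`, a call might look like `coef_inference(coefficients, vcov, att_idx, df=len(y) - X.shape[1] - n_absorbed_effects, alpha=self.alpha)`, which keeps the degrees-of-freedom adjustment for absorbed fixed effects in one place.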