Lalonde Experiment: Permutation Weighting for ATE Estimation
This example demonstrates using permutation weighting on the classic LaLonde (1986) observational dataset to estimate the average treatment effect (ATE) of a job training program on earnings.
The dataset combines experimental treatment units from the NSW program with non-experimental control units, creating confounding and selection bias that must be addressed to recover the experimental benchmark ATE of $1,794.
Reference: LaLonde, R. J. (1986). "Evaluating the Econometric Evaluations of Training Programs with Experimental Data". The American Economic Review, 76(4), 604-620.
import time
from pathlib import Path
from typing import TypedDict
import jax.numpy as jnp
import numpy as np
from jax import Array
from stochpw import (
MLPDiscriminator,
PermutationWeighter,
effective_sample_size,
standardized_mean_difference,
)
class LalondeData(TypedDict):
    """Typed dictionary describing the dataset returned by ``load_lalonde_nsw``.

    Holds the covariate matrix, treatment and outcome arrays (as JAX arrays),
    the covariate column names, and the experimental ATE benchmark.
    """

    X: Array  # covariate matrix, one row per unit
    A: Array  # treatment indicator
    Y: Array  # outcome (RE78 earnings)
    feature_names: list[str]  # names of the columns of X, in column order
    ate_benchmark: float  # experimental benchmark ATE, in dollars
Helper Functions
def load_lalonde_nsw() -> LalondeData:
    """
    Load the Lalonde NSW (National Supported Work) observational dataset.

    This dataset contains observational (non-experimental) data from the LaLonde study,
    where NSW experimental treatment units are combined with non-experimental control
    units, creating confounding and selection bias.

    The dataset contains:
    - Treatment: participation in job training program (1=treated, 0=control)
    - Outcome: real earnings in 1978 (RE78)
    - Covariates: age, education, race, marital status, pre-treatment earnings

    The file ``nsw_data.csv`` is expected to have one header row followed by
    numeric columns: the treatment indicator first, then the eight covariates
    named in ``feature_names`` (in that order), with RE78 earnings as the last
    column.

    Returns:
        LalondeData: Dictionary with keys:
            - X: Covariates array, shape (n, 8)
            - A: Treatment array, shape (n,)
            - Y: Outcome (earnings) array, shape (n,)
            - feature_names: List of covariate names
            - ate_benchmark: Experimental ATE benchmark ($1,794)

    Raises:
        FileNotFoundError: If ``nsw_data.csv`` is not found in any searched location.
        ValueError: If the file has too few columns, the columns used contain
            missing or non-numeric entries, or the treatment column is not 0/1.
    """
    # Find the data file - try multiple locations to handle different execution
    # contexts. In notebooks __file__ may not be defined, so use the cwd instead.
    try:
        current_dir = Path(__file__).parent
    except NameError:
        current_dir = Path.cwd()

    possible_paths = [
        current_dir / "nsw_data.csv",  # Same directory as script or notebook
        current_dir / "examples" / "nsw_data.csv",  # When run from project root
        Path("examples") / "nsw_data.csv",  # Relative to the working directory
    ]
    data_file = next((path for path in possible_paths if path.exists()), None)
    if data_file is None:
        raise FileNotFoundError(
            "Data file not found. Tried:\n"
            + "\n".join(f" - {p}" for p in possible_paths)
            + "\n\nPlease ensure nsw_data.csv is in the examples/ directory."
        )

    # Feature names (from dataset documentation), in the order the covariate
    # columns appear in the file, immediately after the treatment column.
    feature_names = [
        "age",
        "education",
        "black",
        "hispanic",
        "married",
        "nodegree",
        "RE74",  # earnings in 1974
        "RE75",  # earnings in 1975
    ]
    n_covariates = len(feature_names)

    # atleast_2d keeps a single-data-row file two-dimensional so the column
    # slicing below still works.
    data = np.atleast_2d(np.genfromtxt(data_file, delimiter=",", skip_header=1))

    # Treatment + covariates + RE78 must not overlap; with fewer columns the
    # last column (outcome) would coincide with a covariate.
    if data.shape[1] < n_covariates + 2:
        raise ValueError(
            f"{data_file}: expected at least {n_covariates + 2} columns "
            f"(treatment, {n_covariates} covariates, RE78), found {data.shape[1]}."
        )

    A = data[:, 0]  # Treatment indicator (first column)
    Y = data[:, -1]  # RE78 earnings (last column)
    # Covariates follow the treatment column; slicing by len(feature_names)
    # keeps the array and its column names in lockstep.
    X = data[:, 1 : 1 + n_covariates]

    # genfromtxt silently converts missing or unparseable fields to NaN, which
    # would otherwise propagate into every downstream estimate.
    for label, values in (("treatment", A), ("covariate", X), ("outcome", Y)):
        if np.isnan(values).any():
            raise ValueError(
                f"{data_file}: {label} column(s) contain missing or non-numeric values."
            )
    # Downstream ATE estimation splits units on A == 1 / A == 0, so any other
    # value would silently drop units.
    if not np.isin(A, (0.0, 1.0)).all():
        raise ValueError(f"{data_file}: treatment column must contain only 0 and 1.")

    # Experimental ATE benchmark, treated as the "true" ATE estimated from the
    # randomized experiment.
    # NOTE(review): $1,794 is the figure usually associated with the
    # Dehejia & Wahba (1999) NSW subsample rather than LaLonde (1986) itself;
    # confirm the citation.
    ate_benchmark = 1794.0

    return {
        "X": jnp.array(X),
        "A": jnp.array(A),
        "Y": jnp.array(Y),
        "feature_names": feature_names,
        "ate_benchmark": ate_benchmark,
    }
def estimate_ate(Y: Array, A: Array, weights: Array) -> float:
    """
    Estimate the average treatment effect (ATE) as a difference of weighted means.

        ATE_hat = sum_i w_i * Y_i * [A_i = 1] / sum_i w_i * [A_i = 1]
                - sum_i w_i * Y_i * [A_i = 0] / sum_i w_i * [A_i = 0]

    With unit weights this is the naive contrast E[Y|A=1] - E[Y|A=0], which
    equals E[Y(1) - Y(0)] only when treatment is unconfounded. In observational
    data, balancing weights are what make the weighted contrast target the ATE.

    Args:
        Y: Outcome array, shape (n, 1) or (n,)
        A: Binary treatment array (1 = treated, 0 = control), shape (n, 1) or (n,)
        weights: Importance weights, shape (n,) or (n, 1)

    Returns:
        float: Estimated ATE

    Raises:
        ValueError: If Y, A and weights do not have the same number of elements,
            or if either treatment arm has zero total weight (e.g. it is empty).
    """
    y_flat = jnp.ravel(Y)
    a_flat = jnp.ravel(A)
    # Flatten the weights as well: an (n, 1) weight column indexed by a 1-D
    # mask stays 2-D and would broadcast against the 1-D outcomes into a
    # (k, k) product, silently corrupting the weighted sums.
    w_flat = jnp.ravel(weights)
    if not (y_flat.shape == a_flat.shape == w_flat.shape):
        raise ValueError(
            "Y, A and weights must contain the same number of elements; got "
            f"{y_flat.shape[0]}, {a_flat.shape[0]} and {w_flat.shape[0]}."
        )

    def _weighted_mean(mask: Array, arm: str) -> Array:
        # Weighted mean of the outcome over the units selected by `mask`.
        total_weight = jnp.sum(w_flat[mask])
        if float(total_weight) == 0.0:
            raise ValueError(f"The {arm} group has zero total weight; the ATE is undefined.")
        return jnp.sum(y_flat[mask] * w_flat[mask]) / total_weight

    mean_y1 = _weighted_mean(a_flat == 1, "treated")
    mean_y0 = _weighted_mean(a_flat == 0, "control")
    return float(mean_y1 - mean_y0)
Load Dataset
start_time = time.time()

# Title banner for the whole run.
_rule = "=" * 70
print("\n".join([_rule, "Lalonde Experiment: Permutation Weighting for ATE Estimation", _rule]))

# Load observational data and unpack the pieces used by later cells.
print("\nLoading Lalonde NSW observational dataset...")
data = load_lalonde_nsw()
X = data["X"]
A = data["A"]
Y = data["Y"]
feature_names, ate_benchmark = data["feature_names"], data["ate_benchmark"]

# Arm sizes: treatment is coded 0/1, so its sum counts treated units.
n_treated = int(A.sum())
n_control = A.shape[0] - n_treated

print(
    "\n".join(
        [
            "\nDataset statistics:",
            f" Total samples: {len(X)}",
            f" Treated: {n_treated}",
            f" Control: {n_control}",
            f" Covariates: {X.shape[1]}",
            f" Covariate names: {', '.join(feature_names)}",
        ]
    )
)
======================================================================
Lalonde Experiment: Permutation Weighting for ATE Estimation
======================================================================
Loading Lalonde NSW observational dataset...
Dataset statistics:
Total samples: 458
Treated: 177
Control: 281
Covariates: 8
Covariate names: age, education, black, hispanic, married, nodegree, RE74, RE75
Experimental Benchmark (Ground Truth)
# Report the experimental (RCT) benchmark that the estimators are judged against.
_benchmark_rule = "=" * 70
for _benchmark_line in (
    f"\n{_benchmark_rule}",
    "Experimental Benchmark (Ground Truth)",
    _benchmark_rule,
    f" Experimental ATE: ${ate_benchmark:.2f}",
    " (From the original randomized controlled trial)",
):
    print(_benchmark_line)
======================================================================
Experimental Benchmark (Ground Truth)
======================================================================
Experimental ATE: $1794.00
(From the original randomized controlled trial)
Naive Estimate (No Adjustment)
print(f"\n{'=' * 70}")
print("Naive Estimate (No Adjustment)")
print(f"{'=' * 70}")

# Unit weights give the unadjusted difference in mean outcomes between arms.
weights_naive = jnp.ones(len(X))
ate_naive = estimate_ate(Y, A, weights_naive)
naive_error = ate_naive - ate_benchmark
naive_pct_error = (naive_error / ate_benchmark) * 100
print(f" Naive ATE: ${ate_naive:.2f}")
print(f" Error: ${naive_error:.2f} ({naive_pct_error:+.1f}%)")

# Check initial balance. SMDs are signed, so the imbalance threshold applies
# to their absolute value.
smd_naive = standardized_mean_difference(X, A, weights_naive)
print("\n Covariate balance:")
print(f" Max |SMD|: {jnp.abs(smd_naive).max():.3f}")
print(" (|SMD| > 0.1 indicates imbalance)")
print("\n Per-covariate imbalance:")
for name, smd_val in zip(feature_names, smd_naive):
    print(f" {name:12s}: {smd_val:+.3f}")
======================================================================
Naive Estimate (No Adjustment)
======================================================================
Naive ATE: $4224.48
Error: $2430.48 (+135.5%)
Covariate balance:
Max |SMD|: 0.899
(Values > 0.1 indicate imbalance)
Per-covariate imbalance:
age : +0.175
education : +0.323
black : +0.279
hispanic : -0.087
married : -0.220
nodegree : -0.086
RE74 : -0.884
RE75 : -0.899
Permutation Weighting with Simple MLP
print(f"\n{'=' * 70}")
print("Permutation Weighting (Simple MLP)")
print(f"{'=' * 70}")

# Standardize covariates before fitting the discriminator. RE74/RE75 are raw
# dollar amounts (tens of thousands) while the other covariates are small
# counts or 0/1 indicators; feeding them unscaled to a gradient-trained MLP
# makes optimization poorly conditioned. The weights estimate a density ratio,
# which a per-column affine rescaling of X leaves unchanged, so fitting and
# predicting on the standardized covariates targets the same weights.
# NOTE(review): if PermutationWeighter already rescales inputs internally,
# this is redundant but harmless (standardizing twice equals standardizing once).
x_center = X.mean(axis=0)
x_spread = X.std(axis=0)
X_scaled = (X - x_center) / jnp.where(x_spread > 0, x_spread, 1.0)

# Fit with a simple MLP architecture
mlp_simple = MLPDiscriminator(hidden_dims=[3])
weighter_simple = PermutationWeighter(
    discriminator=mlp_simple,
    num_epochs=500,
    batch_size=len(X),  # Full batch
    random_state=42,
)
print("\nFitting weighter...")
_ = weighter_simple.fit(X_scaled, A)
weights_simple = weighter_simple.predict(X_scaled, A)

# Estimate ATE
ate_pw_simple = estimate_ate(Y, A, weights_simple)
pw_error_simple = ate_pw_simple - ate_benchmark
pw_pct_error_simple = (pw_error_simple / ate_benchmark) * 100
print(f"\n Permutation-weighted ATE: ${ate_pw_simple:.2f}")
print(f" Error: ${pw_error_simple:.2f} ({pw_pct_error_simple:+.1f}%)")

# Check balance improvement on the original covariates. A negative
# improvement means weighting increased the worst-case imbalance.
smd_pw_simple = standardized_mean_difference(X, A, weights_simple)
print("\n Covariate balance after weighting:")
print(f" Max |SMD|: {jnp.abs(smd_pw_simple).max():.3f}")
balance_improvement = (1 - jnp.abs(smd_pw_simple).max() / jnp.abs(smd_naive).max()) * 100
print(f" Balance improvement: {balance_improvement:.1f}%")

# ESS
ess_simple = effective_sample_size(weights_simple)
ess_ratio_simple = ess_simple / len(weights_simple)
print("\n Effective sample size:")
print(f" ESS: {ess_simple:.0f} / {len(weights_simple)} ({ess_ratio_simple:.1%})")
======================================================================
Permutation Weighting (Simple MLP)
======================================================================
Fitting weighter...
Permutation-weighted ATE: $2512.67
Error: $718.67 (+40.1%)
Covariate balance after weighting:
Max |SMD|: 3.033
Balance improvement: -237.5%
Effective sample size:
ESS: 56 / 458 (12.2%)
Permutation Weighting with Larger MLP
print(f"\n{'=' * 70}")
print("Permutation Weighting (Larger MLP)")
print(f"{'=' * 70}")

# Standardize covariates before fitting the discriminator. RE74/RE75 are raw
# dollar amounts (tens of thousands) while the other covariates are small
# counts or 0/1 indicators; feeding them unscaled to a gradient-trained MLP
# makes optimization poorly conditioned. The weights estimate a density ratio,
# which a per-column affine rescaling of X leaves unchanged, so fitting and
# predicting on the standardized covariates targets the same weights.
# NOTE(review): if PermutationWeighter already rescales inputs internally,
# this is redundant but harmless (standardizing twice equals standardizing once).
x_center = X.mean(axis=0)
x_spread = X.std(axis=0)
X_scaled = (X - x_center) / jnp.where(x_spread > 0, x_spread, 1.0)

# Try a larger architecture
mlp_large = MLPDiscriminator(hidden_dims=[32, 16])
weighter_large = PermutationWeighter(
    discriminator=mlp_large,
    num_epochs=500,
    batch_size=len(X),  # Full batch
    random_state=42,
)
print("\nFitting weighter...")
_ = weighter_large.fit(X_scaled, A)
weights_large = weighter_large.predict(X_scaled, A)

# Estimate ATE
ate_pw_large = estimate_ate(Y, A, weights_large)
pw_error_large = ate_pw_large - ate_benchmark
pw_pct_error_large = (pw_error_large / ate_benchmark) * 100
print(f"\n Permutation-weighted ATE: ${ate_pw_large:.2f}")
print(f" Error: ${pw_error_large:.2f} ({pw_pct_error_large:+.1f}%)")

# Check balance on the original covariates. A negative improvement means
# weighting increased the worst-case imbalance.
smd_pw_large = standardized_mean_difference(X, A, weights_large)
print("\n Covariate balance after weighting:")
print(f" Max |SMD|: {jnp.abs(smd_pw_large).max():.3f}")
balance_improvement_large = (1 - jnp.abs(smd_pw_large).max() / jnp.abs(smd_naive).max()) * 100
print(f" Balance improvement: {balance_improvement_large:.1f}%")

# ESS
ess_large = effective_sample_size(weights_large)
ess_ratio_large = ess_large / len(weights_large)
print("\n Effective sample size:")
print(f" ESS: {ess_large:.0f} / {len(weights_large)} ({ess_ratio_large:.1%})")
======================================================================
Permutation Weighting (Larger MLP)
======================================================================
Fitting weighter...
Permutation-weighted ATE: $-633.69
Error: $-2427.69 (-135.3%)
Covariate balance after weighting:
Max |SMD|: 7.756
Balance improvement: -763.2%
Effective sample size:
ESS: 2 / 458 (0.5%)
Summary Comparison
print(f"\n{'=' * 70}")
print("Summary Comparison")
print(f"{'=' * 70}")

# Header and rows share column widths so the table lines up:
# "$" + a 12-wide number = 13 characters; a 10-wide percentage + "%" = 11.
print(f"\n{'Method':<30} {'ATE Estimate':>13} {'Error':>13} {'% Error':>11}")
print("-" * 70)
print(f"{'Experimental (Benchmark)':<30} ${ate_benchmark:>12.2f} {'---':>13} {'---':>11}")
summary_rows = [
    ("Naive (Unadjusted)", ate_naive, naive_error, naive_pct_error),
    ("PW (Simple MLP)", ate_pw_simple, pw_error_simple, pw_pct_error_simple),
    ("PW (Larger MLP)", ate_pw_large, pw_error_large, pw_pct_error_large),
]
for row_label, row_ate, row_error, row_pct_error in summary_rows:
    print(
        f"{row_label:<30} ${row_ate:>12.2f} "
        f"${row_error:>12.2f} {row_pct_error:>10.1f}%"
    )

# Reduction in absolute estimation error relative to the unadjusted estimate
# (positive = closer to the experimental benchmark than the naive estimate).
improvement_over_naive = abs(naive_error) - abs(pw_error_simple)
improvement_over_naive_large = abs(naive_error) - abs(pw_error_large)
print("\n Improvement over naive (reduction in |error|):")
print(f" PW (Simple MLP): ${improvement_over_naive:.2f}")
print(f" PW (Larger MLP): ${improvement_over_naive_large:.2f}")

print(f"\n{'=' * 70}")
print("✓ Lalonde experiment completed successfully!")
elapsed_time = time.time() - start_time
print(f"⏱ Total execution time: {elapsed_time:.2f} seconds")
print(f"{'=' * 70}")
======================================================================
Summary Comparison
======================================================================
Method ATE Estimate Error % Error
------------------------------------------------------------------------
Experimental (Benchmark) $ 1794.00 --- ---
Naive (Unadjusted) $ 4224.48 $ 2430.48 135.5%
PW (Simple MLP) $ 2512.67 $ 718.67 40.1%
PW (Larger MLP) $ -633.69 $ -2427.69 -135.3%
Improvement over naive:
Improvement over naive: $1711.81
======================================================================
✓ Lalonde experiment completed successfully!
⏱ Total execution time: 15.01 seconds
======================================================================