Skip to content

Comprehensive Diagnostics Demo

This example demonstrates comprehensive diagnostics in stochpw:

  1. Balance reports
  2. Weight statistics
  3. ROC curves (most important discriminator diagnostic)
  4. Calibration curves
  5. Visualization with plotnine
import time

import jax
import jax.numpy as jnp
import optax

from stochpw import (
    PermutationWeighter,
    balance_report,
    calibration_curve,
    roc_curve,
    standardized_mean_difference,
    weight_statistics,
)
from stochpw.plotting import (
    plot_balance_diagnostics,
    plot_calibration_curve,
    plot_roc_curve,
    plot_weight_distribution,
)

Generate Data with Confounding

def generate_confounded_data(n: int = 1000, seed: int = 42):
    """Simulate a binary treatment whose assignment is confounded by covariates.

    Five standard-normal covariates are drawn for every unit. Treatment is then
    assigned with probability ``sigmoid(1.5 * X[:, 0] + X[:, 1] - 0.5)``, so the
    first two covariates drive treatment uptake (the confounding).

    Args:
        n: Number of units to simulate.
        seed: Integer seed for ``jax.random.PRNGKey``; the same seed reproduces
            the same draws.

    Returns:
        Tuple ``(X, A)`` where ``X`` has shape ``(n, 5)`` and ``A`` has shape
        ``(n,)`` holding 0.0/1.0 treatment indicators.
    """
    # A three-way split is kept even though only two keys are consumed:
    # changing the number of splits would change the derived keys and with
    # them the generated data.
    cov_key, treat_key, _spare_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    X = jax.random.normal(cov_key, (n, 5))

    treatment_logit = 1.5 * X[:, 0] + X[:, 1] - 0.5
    uniform_draws = jax.random.uniform(treat_key, (n,))
    A = (uniform_draws < jax.nn.sigmoid(treatment_logit)).astype(float)
    return X, A


# Wall-clock start, reported at the end of the demo
start_time = time.time()

banner = "=" * 70
for header_line in (banner, "Comprehensive Diagnostics Demo", banner):
    print(header_line)

# Draw the confounded toy dataset used by every later step
X, A = generate_confounded_data(n=1000, seed=42)
treated_share = jnp.mean(A)
print(f"\nGenerated data: X.shape={X.shape}, A.shape={A.shape}")
print(f"Treatment balance: {treated_share:.2%} treated")
======================================================================
Comprehensive Diagnostics Demo
======================================================================



Generated data: X.shape=(1000, 5), A.shape=(1000,)
Treatment balance: 44.10% treated

Step 1: Initial Balance Assessment

rule = "=" * 70
print(f"\n{rule}\nStep 1: Initial Balance Assessment\n{rule}")

# Uniform weights reproduce the raw, unweighted sample: the "before" baseline.
uniform_weights = jnp.ones(X.shape[0])
initial_smd = standardized_mean_difference(X, A, uniform_weights)
abs_initial_smd = jnp.abs(initial_smd)
print(f"\nInitial max SMD: {jnp.max(abs_initial_smd):.4f}")
print(f"Initial mean SMD: {jnp.mean(abs_initial_smd):.4f}")

# Full balance report under the same uniform weights
initial_report = balance_report(X, A, uniform_weights)
summary_rows = [
    ("Treatment type", initial_report["treatment_type"]),
    ("Number of features", initial_report["n_features"]),
    ("Number of samples", initial_report["n_samples"]),
]
print()
for row_label, row_value in summary_rows:
    print(f"{row_label}: {row_value}")
======================================================================
Step 1: Initial Balance Assessment
======================================================================



Initial max SMD: 1.1802
Initial mean SMD: 0.4308



Treatment type: binary
Number of features: 5
Number of samples: 1000

Step 2: Fit Permutation Weighter

rule = "=" * 70
print(f"\n{rule}\nStep 2: Fit Permutation Weighter\n{rule}")

# Optimizer and training schedule handed to PermutationWeighter
opt = optax.rmsprop(learning_rate=0.1)
weighter = PermutationWeighter(
    num_epochs=50,
    batch_size=250,
    random_state=42,
    optimizer=opt,
)

_ = weighter.fit(X, A)  # return value intentionally discarded
weights = weighter.predict(X, A)

# history_ is only populated after fitting
assert weighter.history_ is not None
loss_history = weighter.history_["loss"]
print(f"\nTraining completed in {len(loss_history)} epochs")
print(f"Final training loss: {loss_history[-1]:.4f}")
======================================================================
Step 2: Fit Permutation Weighter
======================================================================



Training completed in 50 epochs
Final training loss: 0.6573

Step 3: Balance After Weighting

rule = "=" * 70
print(f"\n{rule}\nStep 3: Balance After Weighting\n{rule}")

final_smd = standardized_mean_difference(X, A, weights)
abs_final_smd = jnp.abs(final_smd)
final_max_smd = jnp.max(abs_final_smd)
print(f"\nFinal max SMD: {final_max_smd:.4f}")
print(f"Final mean SMD: {jnp.mean(abs_final_smd):.4f}")

# Percent reduction of the worst-case |SMD| relative to the unweighted sample
smd_improvement = (1 - final_max_smd / jnp.max(jnp.abs(initial_smd))) * 100
print(f"SMD improvement: {smd_improvement:.1f}%")

# Balance report under the learned weights (includes effective sample size)
final_report = balance_report(X, A, weights)

ess_summary = (
    f"\nEffective Sample Size: {final_report['ess']:.0f} / {final_report['n_samples']}"
)
print(ess_summary)
print(f"ESS Ratio: {final_report['ess_ratio']:.2%}")
======================================================================
Step 3: Balance After Weighting
======================================================================

Final max SMD: 0.4051
Final mean SMD: 0.1781
SMD improvement: 65.7%

Effective Sample Size: 724 / 1000
ESS Ratio: 72.38%

Step 4: Weight Distribution Analysis

rule = "=" * 70
print(f"\n{rule}\nStep 4: Weight Distribution Analysis\n{rule}")

w_stats = weight_statistics(weights)

# (printed label, key in w_stats, format spec) for each summary line
stat_rows = (
    ("Mean", "mean", ".3f"),
    ("Std", "std", ".3f"),
    ("Min", "min", ".3f"),
    ("Max", "max", ".3f"),
    ("CV (std/mean)", "cv", ".3f"),
    ("Max/Min ratio", "max_ratio", ".1f"),
    ("Entropy", "entropy", ".3f"),
    ("N extreme (>10x mean)", "n_extreme", ""),
)

print("\nWeight Statistics:")
for stat_label, stat_key, stat_spec in stat_rows:
    print(f"  {stat_label}: {w_stats[stat_key]:{stat_spec}}")
======================================================================
Step 4: Weight Distribution Analysis
======================================================================

Weight Statistics:
  Mean: 0.896
  Std: 0.554
  Min: 0.054
  Max: 4.255
  CV (std/mean): 0.618
  Max/Min ratio: 79.5
  Entropy: 6.741
  N extreme (>10x mean): 0

Step 5: ROC Curve Analysis

The ROC curve is the most important discriminator diagnostic.

rule = "=" * 70
print(f"\n{rule}\nStep 5: ROC Curve Analysis\n{rule}")

# Shuffle the treatment vector with a fixed key to build the "permuted" sample
key = jax.random.PRNGKey(123)
A_perm = A[jax.random.permutation(key, len(A))]

# Weights for each unit under its observed and its permuted treatment
weights_obs = weights
weights_perm = weighter.predict(X, A_perm)

# Stack both sets; label 0 marks observed rows and 1 marks permuted rows
all_weights = jnp.concatenate([weights_obs, weights_perm])
all_labels = jnp.concatenate(
    [jnp.zeros(len(weights_obs)), jnp.ones(len(weights_perm))]
)

fpr, tpr, thresholds = roc_curve(all_weights, all_labels)
# Area under the ROC curve: tpr integrated over fpr by the trapezoidal rule
auc = float(jnp.trapezoid(tpr, fpr))

print(f"\nROC AUC: {auc:.4f}")
print(
    "\nInterpretation: AUC measures discriminator's ability to distinguish observed from permuted."
)
print("  AUC = 0.5: Random guessing (poor discriminator)")
print("  AUC = 1.0: Perfect discrimination")
print(f"  Current AUC = {auc:.4f}: ", end="")

# First cutoff the AUC strictly exceeds wins; otherwise fall back to "poor".
quality_bands = (
    (0.9, "Excellent discriminator quality"),
    (0.8, "Good discriminator quality"),
    (0.7, "Moderate discriminator quality"),
)
verdict = next(
    (message for cutoff, message in quality_bands if auc > cutoff),
    "Poor discriminator quality - consider more epochs or larger model",
)
print(verdict)
======================================================================
Step 5: ROC Curve Analysis
======================================================================



ROC AUC: 0.6853

Interpretation: AUC measures discriminator's ability to distinguish observed from permuted.
  AUC = 0.5: Random guessing (poor discriminator)
  AUC = 1.0: Perfect discrimination
  Current AUC = 0.6853: Poor discriminator quality - consider more epochs or larger model

Step 6: Discriminator Calibration

print("\n" + "=" * 70)
print("Step 6: Discriminator Calibration")
print("=" * 70)

# Number of probability bins; used both for the computation and the printed
# heading so the two cannot disagree.
n_calibration_bins = 10


def _discriminator_probs(treatment):
    """Return sigmoid of the fitted discriminator's logits for (treatment, X).

    Rebuilds the three inputs this demo passes to ``discriminator.apply``: the
    treatment as a 2-D array (a 1-D vector becomes a single column), the
    covariates ``X``, and the per-row outer product of treatment and ``X``
    flattened to shape ``(n, d_treatment * d_x)``.

    NOTE(review): these inputs are built by hand here; they presumably must
    match how PermutationWeighter constructs them internally — confirm
    against stochpw.
    """
    params = weighter.params_
    assert params is not None, "weighter must be fitted before calibration"
    treatment_2d = treatment[:, None] if treatment.ndim == 1 else treatment
    interactions = jnp.einsum("bi,bj->bij", treatment_2d, X).reshape(X.shape[0], -1)
    raw_logits = weighter.discriminator.apply(params, treatment_2d, X, interactions)
    return jax.nn.sigmoid(raw_logits)


probs = _discriminator_probs(A)  # observed (A, X) pairs
probs_perm = _discriminator_probs(A_perm)  # same permutation as the ROC step

# Label 0 = observed, 1 = permuted (same convention as the ROC analysis)
all_probs = jnp.concatenate([probs, probs_perm])
cal_labels = jnp.concatenate([jnp.zeros(len(probs)), jnp.ones(len(probs_perm))])

bin_centers, true_freqs, counts = calibration_curve(
    all_probs, cal_labels, num_bins=n_calibration_bins
)

print(f"\nCalibration Analysis ({n_calibration_bins} bins):")
print(f"{'Predicted':<12} {'Observed':<12} {'Count':<10} {'Error':<10}")
print("-" * 44)
for pred, obs, count in zip(bin_centers, true_freqs, counts):
    # Bins that received no predictions are skipped.
    if count > 0:
        error = abs(pred - obs)  # absolute gap between predicted and observed
        print(f"{pred:>10.3f}   {obs:>10.3f}   {int(count):>8}   {error:>8.3f}")
======================================================================
Step 6: Discriminator Calibration
======================================================================



Calibration Analysis (10 bins):
Predicted    Observed     Count      Error     
--------------------------------------------
     0.050        0.375          8      0.325
     0.150        0.333         63      0.183
     0.250        0.311        180      0.061
     0.350        0.358        344      0.008
     0.450        0.421        518      0.029
     0.550        0.535        400      0.015
     0.650        0.687        297      0.037
     0.750        0.808        146      0.058
     0.850        0.970         33      0.120
     0.950        1.000         11      0.050

Step 7: Before/After Comparison

rule = "=" * 70
print(f"\n{rule}\nStep 7: Before/After Comparison\n{rule}")

print(f"\n{'Metric':<30} {'Before':<15} {'After':<15} {'Improvement':<15}")
print("-" * 75)


def _as_float(report, field):
    """Read one balance_report entry as a Python float."""
    # type: ignore needed because balance_report is typed as returning a union
    return float(report[field])  # type: ignore[arg-type]


# Percent changes of the weighted sample relative to the unweighted baseline
max_smd_imp = (
    1 - _as_float(final_report, "max_smd") / _as_float(initial_report, "max_smd")
) * 100
mean_smd_imp = (
    1 - _as_float(final_report, "mean_smd") / _as_float(initial_report, "mean_smd")
) * 100
ess_change = (_as_float(final_report, "ess") / _as_float(initial_report, "ess") - 1) * 100


def _comparison_row(metric, before, after, spec, change_cell):
    """Format one table row: metric name, before/after values, change column."""
    return f"{metric:<30} {before:>13{spec}}  {after:>13{spec}}  {change_cell}"


comparison_rows = (
    ("Max SMD", initial_report["max_smd"], final_report["max_smd"], ".4f",
     f"{max_smd_imp:>12.1f}%"),
    ("Mean SMD", initial_report["mean_smd"], final_report["mean_smd"], ".4f",
     f"{mean_smd_imp:>12.1f}%"),
    ("ESS", initial_report["ess"], final_report["ess"], ".0f",
     f"{ess_change:>12.1f}%"),
    ("ESS Ratio", initial_report["ess_ratio"], final_report["ess_ratio"], ".2%",
     f"{'-':>15}"),
)
for comparison_row in comparison_rows:
    print(_comparison_row(*comparison_row))
======================================================================
Step 7: Before/After Comparison
======================================================================

Metric                         Before          After           Improvement    
---------------------------------------------------------------------------
Max SMD                               1.1802         0.4051          65.7%
Mean SMD                              0.4308         0.1781          58.7%
ESS                                     1000            724         -27.6%
ESS Ratio                            100.00%         72.38%                -

Step 8: Create Visualizations

rule = "=" * 70
print(f"\n{rule}\nStep 8: Creating Visualizations\n{rule}")

# ROC curve first: the headline discriminator diagnostic
plot_roc = plot_roc_curve(fpr, tpr, auc)
plot_roc.save("roc_curve.png", dpi=150, width=8, height=8)
print("\n✓ Saved: roc_curve.png (MOST IMPORTANT DIAGNOSTIC)")

# Covariate balance plot, one label per column of X
feature_labels = [f"X{i}" for i in range(X.shape[1])]
plot_balance = plot_balance_diagnostics(X, A, weights, feature_names=feature_labels)
plot_balance.save("balance_diagnostics.png", dpi=150, width=10, height=6)
print("✓ Saved: balance_diagnostics.png (with 95% confidence intervals)")

plot_weights = plot_weight_distribution(weights)
plot_weights.save("weight_distribution.png", dpi=150, width=8, height=6)
print("✓ Saved: weight_distribution.png")

plot_cal = plot_calibration_curve(bin_centers, true_freqs, counts)
plot_cal.save("calibration_curve.png", dpi=150, width=8, height=8)
print("✓ Saved: calibration_curve.png")

print("\nVisualization files saved to current directory")

print("\n" + rule)
print("Demo Complete!")
# Total runtime measured from the start_time recorded at the top of the demo
elapsed_time = time.time() - start_time
print(f"Total execution time: {elapsed_time:.2f} seconds")
print(rule)
======================================================================
Step 8: Creating Visualizations
======================================================================


/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:623: PlotnineWarning: Saving 8 x 8 in image.
/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:624: PlotnineWarning: Filename: roc_curve.png



✓ Saved: roc_curve.png (MOST IMPORTANT DIAGNOSTIC)


/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:623: PlotnineWarning: Saving 10 x 6 in image.
/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:624: PlotnineWarning: Filename: balance_diagnostics.png
/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:623: PlotnineWarning: Saving 8 x 6 in image.
/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:624: PlotnineWarning: Filename: weight_distribution.png


✓ Saved: balance_diagnostics.png (with 95% confidence intervals)
✓ Saved: weight_distribution.png


✓ Saved: calibration_curve.png

Visualization files saved to current directory

======================================================================
Demo Complete!
Total execution time: 9.53 seconds
======================================================================


/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:623: PlotnineWarning: Saving 8 x 8 in image.
/Users/drewd/GitHub/_packages/stochpw/.venv/lib/python3.12/site-packages/plotnine/ggplot.py:624: PlotnineWarning: Filename: calibration_curve.png

View source on GitHub