
diagnostics

Diagnostic utilities for assessing balance and weight quality.

balance_report(X, A, weights)

Generate comprehensive balance report.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | Array, shape (n_samples, n_features) | Covariates | required |
| `A` | Array, shape (n_samples, 1) or (n_samples,) | Treatments | required |
| `weights` | Array, shape (n_samples,) | Sample weights | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `report` | dict | Comprehensive balance report (keys listed below) |

The report contains:

- `smd`: Array of SMD per covariate
- `max_smd`: Maximum absolute SMD across covariates
- `mean_smd`: Mean absolute SMD across covariates
- `ess`: Effective sample size
- `ess_ratio`: ESS / n_samples
- `weight_stats`: Dictionary of weight distribution statistics
- `n_samples`: Number of samples
- `n_features`: Number of features
- `treatment_type`: 'binary' or 'continuous'

Notes

This function provides a complete overview of balance quality after weighting, useful for reporting and model diagnostics.

Source code in src/stochpw/diagnostics/advanced.py
def balance_report(
    X: Array, A: Array, weights: Array
) -> dict[str, float | Array | dict[str, float] | str | int]:
    """
    Generate comprehensive balance report.

    Parameters
    ----------
    X : jax.Array, shape (n_samples, n_features)
        Covariates
    A : jax.Array, shape (n_samples, 1) or (n_samples,)
        Treatments
    weights : jax.Array, shape (n_samples,)
        Sample weights

    Returns
    -------
    report : dict
        Comprehensive balance report with:
        - smd: Array of SMD per covariate
        - max_smd: Maximum absolute SMD across covariates
        - mean_smd: Mean absolute SMD across covariates
        - ess: Effective sample size
        - ess_ratio: ESS / n_samples
        - weight_stats: Dictionary of weight distribution statistics
        - n_samples: Number of samples
        - n_features: Number of features
        - treatment_type: 'binary' or 'continuous'

    Notes
    -----
    This function provides a complete overview of balance quality after
    weighting, useful for reporting and model diagnostics.
    """
    from .balance import standardized_mean_difference
    from .weights import effective_sample_size as ess_fn

    # Ensure A is 1D for type detection
    A_flat = A.squeeze() if A.ndim == 2 else A

    # Detect treatment type
    unique_a = jnp.unique(A_flat)
    is_binary = len(unique_a) == 2
    treatment_type = "binary" if is_binary else "continuous"

    # Compute SMD
    smd = standardized_mean_difference(X, A, weights)
    max_smd = float(jnp.max(jnp.abs(smd)))
    mean_smd = float(jnp.mean(jnp.abs(smd)))

    # Compute ESS
    ess = float(ess_fn(weights))
    n_samples = len(weights)
    ess_ratio = ess / n_samples

    # Weight statistics
    w_stats = weight_statistics(weights)

    return {
        "smd": smd,
        "max_smd": max_smd,
        "mean_smd": mean_smd,
        "ess": ess,
        "ess_ratio": ess_ratio,
        "weight_stats": w_stats,
        "n_samples": n_samples,
        "n_features": X.shape[1],
        "treatment_type": treatment_type,
    }
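As a quick illustration of what the report aggregates, here is a minimal NumPy sketch of the binary-treatment path (NumPy stands in for jax.numpy; this mirrors, but is not, the stochpw implementation):

```python
import numpy as np

# Toy data: 2 covariates, binary treatment, uniform weights
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
A = np.array([0, 0, 1, 1])
w = np.ones(4)

# Per-group weighted means and pooled std -> SMD per covariate
w0, w1 = w * (A == 0), w * (A == 1)
m0 = np.average(X, axis=0, weights=w0)
m1 = np.average(X, axis=0, weights=w1)
v0 = np.sum(w0[:, None] * (X - m0) ** 2, axis=0) / w0.sum()
v1 = np.sum(w1[:, None] * (X - m1) ** 2, axis=0) / w1.sum()
smd = (m1 - m0) / np.sqrt((v0 + v1) / 2)

# ESS and a summary dict mirroring balance_report's keys
ess = w.sum() ** 2 / np.sum(w ** 2)
report = {
    "max_smd": float(np.max(np.abs(smd))),
    "ess": float(ess),
    "ess_ratio": float(ess / len(w)),
    "n_samples": len(w),
    "treatment_type": "binary",
}
print(report["ess_ratio"])  # uniform weights -> 1.0
```

With uniform weights the ESS ratio is exactly 1, which is the baseline against which degraded weight distributions are judged.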

calibration_curve(discriminator_probs, true_labels, num_bins=10)

Compute calibration curve for discriminator predictions.

A well-calibrated discriminator should have predicted probabilities that match the true frequencies of the labels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `discriminator_probs` | Array, shape (n_samples,) | Predicted probabilities from discriminator (values between 0 and 1) | required |
| `true_labels` | Array, shape (n_samples,) | True binary labels (0 or 1) | required |
| `num_bins` | int | Number of bins to divide probability range into | 10 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `bin_centers` | Array, shape (num_bins,) | Center of each probability bin |
| `true_frequencies` | Array, shape (num_bins,) | Actual frequency of positive class in each bin |
| `counts` | Array, shape (num_bins,) | Number of samples in each bin |

Notes

Perfect calibration means true_frequencies == bin_centers for all bins.

Source code in src/stochpw/diagnostics/advanced.py
def calibration_curve(
    discriminator_probs: Array, true_labels: Array, num_bins: int = 10
) -> tuple[Array, Array, Array]:
    """
    Compute calibration curve for discriminator predictions.

    A well-calibrated discriminator should have predicted probabilities
    that match the true frequencies of the labels.

    Parameters
    ----------
    discriminator_probs : jax.Array, shape (n_samples,)
        Predicted probabilities from discriminator (values between 0 and 1)
    true_labels : jax.Array, shape (n_samples,)
        True binary labels (0 or 1)
    num_bins : int, default=10
        Number of bins to divide probability range into

    Returns
    -------
    bin_centers : jax.Array, shape (num_bins,)
        Center of each probability bin
    true_frequencies : jax.Array, shape (num_bins,)
        Actual frequency of positive class in each bin
    counts : jax.Array, shape (num_bins,)
        Number of samples in each bin

    Notes
    -----
    Perfect calibration means true_frequencies == bin_centers for all bins.
    """
    # Create bins
    bin_edges = jnp.linspace(0, 1, num_bins + 1)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Digitize predictions into bins
    bin_indices = jnp.digitize(discriminator_probs, bin_edges[1:-1])

    # Compute true frequency and count per bin
    true_frequencies = jnp.zeros(num_bins)
    counts = jnp.zeros(num_bins)

    for i in range(num_bins):
        mask = bin_indices == i
        count = jnp.sum(mask)
        counts = counts.at[i].set(count)

        if count > 0:
            freq = jnp.sum(true_labels[mask]) / count
            true_frequencies = true_frequencies.at[i].set(freq)
        else:
            # For empty bins, use bin center as default
            true_frequencies = true_frequencies.at[i].set(bin_centers[i])

    return bin_centers, true_frequencies, counts
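The binning logic can be sketched in plain NumPy (a simplified stand-in for the JAX implementation above; `calibration_curve_np` is an illustrative name):

```python
import numpy as np

def calibration_curve_np(probs, labels, num_bins=5):
    """Bin predicted probabilities and compare against empirical frequencies."""
    edges = np.linspace(0, 1, num_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    # digitize against the interior edges maps values into bins 0..num_bins-1
    idx = np.digitize(probs, edges[1:-1])
    freqs, counts = np.empty(num_bins), np.zeros(num_bins)
    for i in range(num_bins):
        mask = idx == i
        counts[i] = mask.sum()
        # Empty bins fall back to the bin center (treated as perfectly calibrated)
        freqs[i] = labels[mask].mean() if counts[i] > 0 else centers[i]
    return centers, freqs, counts

# A perfectly calibrated toy example: probability 0.5 -> half the labels are 1
probs = np.array([0.1, 0.1, 0.5, 0.5, 0.9, 0.9])
labels = np.array([0, 0, 0, 1, 1, 1])
centers, freqs, counts = calibration_curve_np(probs, labels)
print(freqs)
```

For this toy input every occupied bin's frequency matches its center, which is what "perfect calibration" means in the Notes above.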

effective_sample_size(weights)

Compute effective sample size (ESS).

ESS = (sum w)^2 / sum(w^2)

Lower values indicate more extreme weights (fewer "effective" samples). ESS = n means uniform weights.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weights` | Array, shape (n_samples,) | Sample weights | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `ess` | Array (scalar) | Effective sample size |

Source code in src/stochpw/diagnostics/weights.py
@jax.jit
def effective_sample_size(weights: Array) -> Array:
    """
    Compute effective sample size (ESS).

    ESS = (sum w)^2 / sum(w^2)

    Lower values indicate more extreme weights (fewer "effective" samples).
    ESS = n means uniform weights.

    Parameters
    ----------
    weights : jax.Array, shape (n_samples,)
        Sample weights

    Returns
    -------
    ess : jax.Array (scalar)
        Effective sample size
    """
    return jnp.sum(weights) ** 2 / jnp.sum(weights**2)
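The formula is easy to verify directly. A small NumPy sketch (NumPy in place of jax.numpy):

```python
import numpy as np

def ess(w):
    # ESS = (sum w)^2 / sum(w^2)
    return w.sum() ** 2 / np.sum(w ** 2)

uniform = np.ones(100)
print(ess(uniform))  # uniform weights: ESS equals n (100.0)

# One dominant weight drives ESS toward 1
extreme = np.array([100.0] + [0.01] * 99)
print(ess(extreme))
```

The second case shows why ESS is a useful red flag: nominally 100 samples, but effectively close to a single one.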

maximum_mean_discrepancy(X, A, weights, sigma=None)

Compute Maximum Mean Discrepancy (MMD) between weighted treatment groups.

MMD measures distributional distance between groups using a kernel-based approach. Unlike SMD which compares means feature-by-feature, MMD captures higher-order moments and interactions between features.

For binary treatment, computes MMD between weighted treatment and control groups. For continuous treatment, this function is not applicable and returns NaN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | Array, shape (n_samples, n_features) | Covariates | required |
| `A` | Array, shape (n_samples, 1) or (n_samples,) | Treatments (must be binary) | required |
| `weights` | Array, shape (n_samples,) | Sample weights | required |
| `sigma` | float | Bandwidth parameter for RBF kernel. If None, uses median heuristic: sigma = median(pairwise distances) / sqrt(2) | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `mmd` | float | MMD statistic (non-negative, 0 means identical distributions) |

Notes

The MMD is defined as:

    MMD^2 = E[k(X_1, X_1')] - 2 E[k(X_1, X_0)] + E[k(X_0, X_0')]

where X_1, X_1' are drawn from the treated group and X_0, X_0' from the control group. This implementation uses weighted expectations based on the provided weights.

References

Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). A kernel two-sample test. Journal of Machine Learning Research, 13(1), 723-773.

Examples:

>>> import jax.numpy as jnp
>>> from stochpw import maximum_mean_discrepancy
>>> X = jnp.array([[1, 2], [2, 3], [3, 4], [4, 5]])
>>> A = jnp.array([0, 0, 1, 1])
>>> weights = jnp.ones(4)
>>> mmd = maximum_mean_discrepancy(X, A, weights)
Source code in src/stochpw/diagnostics/balance.py
def maximum_mean_discrepancy(
    X: Array, A: Array, weights: Array, sigma: float | None = None
) -> float:
    """
    Compute Maximum Mean Discrepancy (MMD) between weighted treatment groups.

    MMD measures distributional distance between groups using a kernel-based approach.
    Unlike SMD which compares means feature-by-feature, MMD captures higher-order
    moments and interactions between features.

    For binary treatment, computes MMD between weighted treatment and control groups.
    For continuous treatment, this function is not applicable and returns NaN.

    Parameters
    ----------
    X : jax.Array, shape (n_samples, n_features)
        Covariates
    A : jax.Array, shape (n_samples, 1) or (n_samples,)
        Treatments (must be binary)
    weights : jax.Array, shape (n_samples,)
        Sample weights
    sigma : float, optional
        Bandwidth parameter for RBF kernel. If None, uses median heuristic:
        sigma = median(pairwise distances) / sqrt(2)

    Returns
    -------
    mmd : float
        MMD statistic (non-negative, 0 means identical distributions)

    Notes
    -----
    The MMD is defined as:

    .. math::
        MMD^2 = E[k(X_1, X_1')] - 2E[k(X_1, X_0)] + E[k(X_0, X_0')]

    where X_1, X_1' are from the treated group and X_0, X_0' are from control.
    This implementation uses weighted expectations based on the provided weights.

    References
    ----------
    Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012).
    A kernel two-sample test. Journal of Machine Learning Research, 13(1), 723-773.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from stochpw import maximum_mean_discrepancy
    >>> X = jnp.array([[1, 2], [2, 3], [3, 4], [4, 5]])
    >>> A = jnp.array([0, 0, 1, 1])
    >>> weights = jnp.ones(4)
    >>> mmd = maximum_mean_discrepancy(X, A, weights)
    """
    # Ensure A is 1D
    a_1d = A.squeeze() if A.ndim == 2 else A

    # Check if A is binary
    unique_a = jnp.unique(a_1d)
    is_binary = len(unique_a) == 2

    if not is_binary:
        # MMD only applicable for binary treatment
        return float("nan")

    # Split into treatment groups
    a0, a1 = unique_a[0], unique_a[1]
    mask_0 = a_1d == a0
    mask_1 = a_1d == a1

    X_0 = X[mask_0]
    X_1 = X[mask_1]
    weights_0 = weights[mask_0]
    weights_1 = weights[mask_1]

    # Normalize weights within each group
    weights_0_norm = weights_0 / jnp.sum(weights_0)
    weights_1_norm = weights_1 / jnp.sum(weights_1)

    # Compute bandwidth using median heuristic if not provided
    sigma_val: float
    if sigma is None:
        # Sample a subset for efficiency if dataset is large
        n_samples = min(1000, X.shape[0])
        if X.shape[0] > n_samples:
            indices = jnp.linspace(0, X.shape[0] - 1, n_samples, dtype=jnp.int32)
            X_sample = X[indices]
        else:
            X_sample = X

        # Compute pairwise distances
        dists_sq = jnp.sum(X_sample**2, axis=1, keepdims=True)
        pairwise_dists_sq = dists_sq + dists_sq.T - 2 * jnp.dot(X_sample, X_sample.T)
        pairwise_dists = jnp.sqrt(jnp.maximum(pairwise_dists_sq, 0))

        # Median heuristic
        positive_dists = pairwise_dists[pairwise_dists > 0]
        if positive_dists.size > 0:
            median_dist = float(jnp.median(positive_dists))
            sigma_val = float(median_dist / jnp.sqrt(2.0))
        else:
            # All distances are zero (constant features) - use default
            sigma_val = 1.0

        # Avoid zero or very small sigma
        sigma_val = float(jnp.maximum(sigma_val, 0.1))
    else:
        sigma_val = sigma

    # Compute kernel matrices
    K_00 = _rbf_kernel(X_0, X_0, sigma_val)
    K_11 = _rbf_kernel(X_1, X_1, sigma_val)
    K_01 = _rbf_kernel(X_0, X_1, sigma_val)

    # Compute weighted kernel expectations
    # E[k(X_0, X_0')] = sum_i sum_j w_i w_j k(x_i, x_j)
    term_00 = jnp.sum(weights_0_norm[:, None] * weights_0_norm[None, :] * K_00)
    term_11 = jnp.sum(weights_1_norm[:, None] * weights_1_norm[None, :] * K_11)
    term_01 = jnp.sum(weights_0_norm[:, None] * weights_1_norm[None, :] * K_01)

    # MMD^2 = E[k(X_1,X_1')] - 2*E[k(X_0,X_1)] + E[k(X_0,X_0')]
    mmd_sq = term_11 - 2 * term_01 + term_00

    # Return MMD (take square root, ensure non-negative)
    mmd = jnp.sqrt(jnp.maximum(mmd_sq, 0.0))

    return float(mmd)
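A compact NumPy re-implementation of the weighted MMD core makes the sanity checks explicit (illustrative only; `rbf_kernel` and `weighted_mmd` are hypothetical names, and the median-heuristic bandwidth selection is omitted):

```python
import numpy as np

def rbf_kernel(X, Y, sigma):
    # Pairwise squared distances, then the Gaussian kernel
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma ** 2))

def weighted_mmd(X0, X1, w0, w1, sigma=1.0):
    # Normalize weights within each group, then take weighted kernel means
    w0, w1 = w0 / w0.sum(), w1 / w1.sum()
    t00 = w0 @ rbf_kernel(X0, X0, sigma) @ w0
    t11 = w1 @ rbf_kernel(X1, X1, sigma) @ w1
    t01 = w0 @ rbf_kernel(X0, X1, sigma) @ w1
    # MMD^2 = E[k(X1,X1')] - 2 E[k(X0,X1)] + E[k(X0,X0')], clamped at 0
    return float(np.sqrt(max(t11 - 2 * t01 + t00, 0.0)))

rng = np.random.default_rng(0)
X0 = rng.normal(size=(50, 2))
w = np.ones(50)
print(weighted_mmd(X0, X0.copy(), w, w))   # identical groups -> 0.0
print(weighted_mmd(X0, X0 + 3.0, w, w))    # shifted group -> clearly positive
```

Identical groups give zero, and a mean shift that SMD would also catch gives a positive MMD; unlike SMD, the kernel form would also react to higher-order differences.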

roc_curve(weights, true_labels, max_points=100)

Compute ROC curve from weights for discriminator performance.

Given weights w(x,a), infers eta(x,a) = w(x,a) / (1 + w(x,a)) and computes the ROC curve for discriminating between observed and permuted data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weights` | Array, shape (n_samples,) | Sample weights from permutation weighting | required |
| `true_labels` | Array, shape (n_samples,) | True binary labels (0 = observed, 1 = permuted) | required |
| `max_points` | int | Maximum number of points in the ROC curve (for computational efficiency) | 100 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `fpr` | Array | False positive rates at each threshold |
| `tpr` | Array | True positive rates at each threshold |
| `thresholds` | Array | Thresholds used to compute fpr and tpr |

Notes

The ROC curve is the most important diagnostic for discriminator quality. A good discriminator will have high AUC (area under curve), indicating it can successfully distinguish between observed and permuted data.

The discriminator probability eta is inferred from weights as: eta(x,a) = w(x,a) / (1 + w(x,a))

Examples:

>>> # After fitting a weighter
>>> weights = weighter.predict(X, A)
>>> # Create permuted data and labels
>>> weights_perm = weighter.predict(X, A_permuted)
>>> all_weights = jnp.concatenate([weights, weights_perm])
>>> labels = jnp.concatenate([jnp.zeros(len(weights)), jnp.ones(len(weights_perm))])
>>> fpr, tpr, thresholds = roc_curve(all_weights, labels)
>>> auc = jnp.trapezoid(tpr, fpr)
Source code in src/stochpw/diagnostics/advanced.py
def roc_curve(
    weights: Array, true_labels: Array, max_points: int = 100
) -> tuple[Array, Array, Array]:
    """
    Compute ROC curve from weights for discriminator performance.

    Given weights w(x,a), infers eta(x,a) = w(x,a) / (1 + w(x,a)) and
    computes the ROC curve for discriminating between observed and
    permuted data.

    Parameters
    ----------
    weights : jax.Array, shape (n_samples,)
        Sample weights from permutation weighting
    true_labels : jax.Array, shape (n_samples,)
        True binary labels (0=observed, 1=permuted)
    max_points : int, default=100
        Maximum number of points in the ROC curve (for computational efficiency)

    Returns
    -------
    fpr : jax.Array
        False positive rates at each threshold
    tpr : jax.Array
        True positive rates at each threshold
    thresholds : jax.Array
        Thresholds used to compute fpr and tpr

    Notes
    -----
    The ROC curve is the most important diagnostic for discriminator quality.
    A good discriminator will have high AUC (area under curve), indicating
    it can successfully distinguish between observed and permuted data.

    The discriminator probability eta is inferred from weights as:
        eta(x,a) = w(x,a) / (1 + w(x,a))

    Examples
    --------
    >>> # After fitting a weighter
    >>> weights = weighter.predict(X, A)
    >>> # Create permuted data and labels
    >>> weights_perm = weighter.predict(X, A_permuted)
    >>> all_weights = jnp.concatenate([weights, weights_perm])
    >>> labels = jnp.concatenate([jnp.zeros(len(weights)), jnp.ones(len(weights_perm))])
    >>> fpr, tpr, thresholds = roc_curve(all_weights, labels)
    >>> auc = jnp.trapezoid(tpr, fpr)
    """
    # Infer eta from weights: eta = w / (1 + w)
    eta = weights / (1.0 + weights)

    # Use linearly spaced thresholds for efficiency
    min_eta = float(jnp.min(eta))
    max_eta = float(jnp.max(eta))
    thresholds = jnp.linspace(max_eta + 1e-6, min_eta - 1e-6, max_points)

    # Compute TPR and FPR for each threshold using vectorized operations
    n_positive = float(jnp.sum(true_labels == 1))
    n_negative = float(jnp.sum(true_labels == 0))

    # Vectorized computation
    # For each threshold, count predictions >= threshold
    predictions_matrix = eta[:, jnp.newaxis] >= thresholds[jnp.newaxis, :]

    # True positives: predictions==1 AND true_labels==1
    tp = jnp.sum(predictions_matrix & (true_labels[:, jnp.newaxis] == 1), axis=0)

    # False positives: predictions==1 AND true_labels==0
    fp = jnp.sum(predictions_matrix & (true_labels[:, jnp.newaxis] == 0), axis=0)

    # Compute rates
    tpr_array = tp / n_positive if n_positive > 0 else jnp.zeros_like(tp)
    fpr_array = fp / n_negative if n_negative > 0 else jnp.zeros_like(fp)

    return fpr_array, tpr_array, thresholds
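The mapping from weights to discriminator probabilities and the vectorized threshold sweep can be sketched in NumPy (illustrative; the AUC is accumulated by hand rather than via a library call):

```python
import numpy as np

# Weights > 1 imply the discriminator leans toward the "permuted" class:
# eta = w / (1 + w) is the implied probability of that class.
w = np.array([0.25, 0.5, 2.0, 4.0])
labels = np.array([0, 0, 1, 1])  # 0 = observed, 1 = permuted
eta = w / (1.0 + w)

# Descending thresholds so fpr/tpr are nondecreasing along the curve
thresholds = np.linspace(eta.max() + 1e-6, eta.min() - 1e-6, 50)
pred = eta[:, None] >= thresholds[None, :]
tpr = (pred & (labels[:, None] == 1)).sum(0) / (labels == 1).sum()
fpr = (pred & (labels[:, None] == 0)).sum(0) / (labels == 0).sum()

# Trapezoidal AUC over the (fpr, tpr) curve
auc = np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2)
print(auc)  # perfectly separated groups -> 1.0
```

Here the permuted samples all have larger weights than the observed ones, so the implied discriminator separates the classes perfectly and the AUC is 1.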

standardized_mean_difference(X, A, weights)

Compute weighted standardized mean difference for each covariate.

For binary treatment, computes SMD between weighted treatment groups. For continuous treatment, computes weighted correlation with covariates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | Array, shape (n_samples, n_features) | Covariates | required |
| `A` | Array, shape (n_samples, 1) or (n_samples,) | Treatments | required |
| `weights` | Array, shape (n_samples,) | Sample weights | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `smd` | Array, shape (n_features,) | SMD or correlation for each covariate |

Source code in src/stochpw/diagnostics/balance.py
def standardized_mean_difference(X: Array, A: Array, weights: Array) -> Array:
    """
    Compute weighted standardized mean difference for each covariate.

    For binary treatment, computes SMD between weighted treatment groups.
    For continuous treatment, computes weighted correlation with covariates.

    Parameters
    ----------
    X : jax.Array, shape (n_samples, n_features)
        Covariates
    A : jax.Array, shape (n_samples, 1) or (n_samples,)
        Treatments
    weights : jax.Array, shape (n_samples,)
        Sample weights

    Returns
    -------
    smd : jax.Array, shape (n_features,)
        SMD or correlation for each covariate
    """
    # Ensure A is 1D for this computation
    a_1d = A.squeeze() if A.ndim == 2 else A

    # Check if A is binary
    unique_a = jnp.unique(a_1d)
    is_binary = len(unique_a) == 2

    if is_binary:
        # Binary treatment: compute SMD
        a0, a1 = unique_a[0], unique_a[1]
        mask_0 = a_1d == a0
        mask_1 = a_1d == a1

        # Weighted means
        weights_0 = weights * mask_0
        weights_1 = weights * mask_1

        sum_weights_0 = jnp.sum(weights_0)
        sum_weights_1 = jnp.sum(weights_1)

        mean_0 = jnp.average(X, axis=0, weights=weights_0)
        mean_1 = jnp.average(X, axis=0, weights=weights_1)

        # Weighted standard deviations
        var_0 = jnp.sum(weights_0[:, None] * (X - mean_0) ** 2, axis=0) / (sum_weights_0 + 1e-10)
        var_1 = jnp.sum(weights_1[:, None] * (X - mean_1) ** 2, axis=0) / (sum_weights_1 + 1e-10)

        # Pooled standard deviation
        pooled_std = jnp.sqrt((var_0 + var_1) / 2)

        # SMD
        smd = (mean_1 - mean_0) / (pooled_std + 1e-10)

    else:
        # Continuous treatment: compute weighted correlation
        # Normalize weights
        w_norm = weights / jnp.sum(weights)

        # Weighted means
        mean_a = jnp.sum(w_norm * a_1d)
        mean_X = jnp.sum(w_norm[:, None] * X, axis=0)

        # Weighted covariance
        cov = jnp.sum(w_norm[:, None] * (a_1d[:, None] - mean_a) * (X - mean_X), axis=0)

        # Weighted standard deviations
        std_a = jnp.sqrt(jnp.sum(w_norm * (a_1d - mean_a) ** 2))
        std_X = jnp.sqrt(jnp.sum(w_norm[:, None] * (X - mean_X) ** 2, axis=0))

        # Correlation
        smd = cov / (std_a * std_X + 1e-10)

    return smd
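For the continuous-treatment branch, the weighted correlation reduces to the ordinary Pearson correlation when weights are uniform, which gives a convenient check (NumPy sketch; `weighted_corr` is an illustrative name):

```python
import numpy as np

def weighted_corr(X, a, w):
    # Weighted Pearson correlation between treatment a and each column of X
    w = w / w.sum()
    ma = np.sum(w * a)
    mX = np.sum(w[:, None] * X, axis=0)
    cov = np.sum(w[:, None] * (a[:, None] - ma) * (X - mX), axis=0)
    sa = np.sqrt(np.sum(w * (a - ma) ** 2))
    sX = np.sqrt(np.sum(w[:, None] * (X - mX) ** 2, axis=0))
    return cov / (sa * sX)

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
a = X[:, 0] + 0.1 * rng.normal(size=200)  # a tracks the first covariate only
w = np.ones(200)
r = weighted_corr(X, a, w)
print(r)  # first entry near 1, second near 0
```

With uniform weights this matches `np.corrcoef`; non-uniform weights shift the means and variances toward the heavily weighted samples.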

standardized_mean_difference_se(X, A, weights)

Compute standard errors for standardized mean differences.

Uses the bootstrap-style approximation for weighted SMD standard errors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | Array, shape (n_samples, n_features) | Covariates | required |
| `A` | Array, shape (n_samples, 1) or (n_samples,) | Treatments | required |
| `weights` | Array, shape (n_samples,) | Sample weights | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `se` | Array, shape (n_features,) | Standard error for each covariate's SMD |

Source code in src/stochpw/diagnostics/balance.py
def standardized_mean_difference_se(X: Array, A: Array, weights: Array) -> Array:
    """
    Compute standard errors for standardized mean differences.

    Uses the bootstrap-style approximation for weighted SMD standard errors.

    Parameters
    ----------
    X : jax.Array, shape (n_samples, n_features)
        Covariates
    A : jax.Array, shape (n_samples, 1) or (n_samples,)
        Treatments
    weights : jax.Array, shape (n_samples,)
        Sample weights

    Returns
    -------
    se : jax.Array, shape (n_features,)
        Standard error for each covariate's SMD
    """
    # Ensure A is 1D
    a_1d = A.squeeze() if A.ndim == 2 else A

    # Check if A is binary
    unique_a = jnp.unique(a_1d)
    is_binary = len(unique_a) == 2

    if is_binary:
        # Binary treatment: bootstrap-style SE
        a0, a1 = unique_a[0], unique_a[1]
        mask_0 = a_1d == a0
        mask_1 = a_1d == a1

        # Effective sample sizes
        weights_0 = weights * mask_0
        weights_1 = weights * mask_1
        ess_0 = jnp.sum(weights_0) ** 2 / jnp.sum(weights_0**2)
        ess_1 = jnp.sum(weights_1) ** 2 / jnp.sum(weights_1**2)

        # Approximate SE using ESS
        # SE(SMD) ≈ sqrt(1/n_0 + 1/n_1 + SMD²/(2*(n_0 + n_1)))
        # Use ESS instead of n
        smd = standardized_mean_difference(X, A, weights)
        se = jnp.sqrt(1.0 / ess_0 + 1.0 / ess_1 + smd**2 / (2.0 * (ess_0 + ess_1)))

    else:
        # Continuous treatment: approximate SE for correlation
        n_eff = jnp.sum(weights) ** 2 / jnp.sum(weights**2)
        # SE(correlation) ≈ 1/sqrt(n)
        se = jnp.ones(X.shape[1]) / jnp.sqrt(n_eff)

    return se
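The approximation is straightforward to evaluate by hand. A NumPy sketch of the SE formula with ESS plugged in for the group sizes (`smd_se` is an illustrative name):

```python
import numpy as np

def smd_se(ess0, ess1, smd):
    # SE(SMD) ~ sqrt(1/n0 + 1/n1 + smd^2 / (2*(n0 + n1))), with ESS in place of n
    return np.sqrt(1.0 / ess0 + 1.0 / ess1 + smd ** 2 / (2.0 * (ess0 + ess1)))

# Balanced groups of effective size 50 and SMD = 0:
print(smd_se(50.0, 50.0, 0.0))  # sqrt(2/50) = 0.2
```

Larger SMDs and smaller effective sample sizes both inflate the standard error, as the formula suggests.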

weight_statistics(weights)

Compute comprehensive statistics about weight distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weights` | Array, shape (n_samples,) | Sample weights | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `stats` | dict | Dictionary with weight statistics (keys listed below) |

The dictionary contains:

- `mean`: Mean weight
- `std`: Standard deviation of weights
- `min`: Minimum weight
- `max`: Maximum weight
- `cv`: Coefficient of variation (std/mean)
- `entropy`: Entropy of normalized weights
- `max_ratio`: Ratio of max to min weight
- `n_extreme`: Number of weights > 10x mean

Notes

Useful for diagnosing weight quality and potential issues with extreme or highly variable weights.

Source code in src/stochpw/diagnostics/advanced.py
def weight_statistics(weights: Array) -> dict[str, float]:
    """
    Compute comprehensive statistics about weight distribution.

    Parameters
    ----------
    weights : jax.Array, shape (n_samples,)
        Sample weights

    Returns
    -------
    stats : dict
        Dictionary with weight statistics:
        - mean: Mean weight
        - std: Standard deviation of weights
        - min: Minimum weight
        - max: Maximum weight
        - cv: Coefficient of variation (std/mean)
        - entropy: Entropy of normalized weights
        - max_ratio: Ratio of max to min weight
        - n_extreme: Number of weights > 10x mean

    Notes
    -----
    Useful for diagnosing weight quality and potential issues with
    extreme or highly variable weights.
    """
    mean_w = float(jnp.mean(weights))
    std_w = float(jnp.std(weights))
    min_w = float(jnp.min(weights))
    max_w = float(jnp.max(weights))

    # Coefficient of variation
    cv = std_w / mean_w if mean_w > 0 else float("inf")

    # Entropy of normalized weights
    w_norm = weights / jnp.sum(weights)
    # Add small constant to avoid log(0)
    entropy = float(-jnp.sum(w_norm * jnp.log(w_norm + 1e-10)))

    # Maximum weight ratio
    max_ratio = max_w / min_w if min_w > 0 else float("inf")

    # Number of extreme weights (> 10x mean)
    n_extreme = int(jnp.sum(weights > 10 * mean_w))

    return {
        "mean": mean_w,
        "std": std_w,
        "min": min_w,
        "max": max_w,
        "cv": cv,
        "entropy": entropy,
        "max_ratio": max_ratio,
        "n_extreme": n_extreme,
    }
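A condensed NumPy sketch of a few of these statistics, checked on uniform weights (illustrative; `weight_stats` is not the library function):

```python
import numpy as np

def weight_stats(w):
    mean, std = float(w.mean()), float(w.std())
    wn = w / w.sum()
    return {
        "cv": std / mean,                                    # relative spread
        "entropy": float(-np.sum(wn * np.log(wn + 1e-10))),  # max = log(n) for uniform
        "max_ratio": float(w.max() / w.min()),
        "n_extreme": int(np.sum(w > 10 * mean)),
    }

uniform = np.ones(8)
s = weight_stats(uniform)
print(s["cv"], s["n_extreme"])  # 0.0 and 0; entropy close to log(8)
```

Uniform weights hit the benign extremes of every statistic: zero coefficient of variation, maximal entropy log(n), unit max ratio, and no extreme weights.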