# Source code for ssbc.validation

"""Validation utilities for prediction interval operational bounds.

This module provides tools to empirically validate the theoretical bounds
by running simulations with fixed calibration thresholds on independent test sets.
Validates all reported methods (analytical, exact, hoeffding) when available.

Enforces strict mathematical consistency between the generative model, calibration
statistics, and predictive validation following the framework in
docs/operational_bounds_mathematical_framework.md.
"""

from typing import Any

import numpy as np
from joblib import Parallel, delayed

from ssbc.utils import evaluate_test_dataset
from ssbc.validation_math import (
    validate_metric_mathematical_consistency,
)


def _validate_single_trial(
    trial_idx: int,
    simulator: Any,
    test_size: int,
    threshold_0: float,
    threshold_1: float,
    seed: int | None = None,
) -> dict[str, Any]:
    """Run a single validation trial with fixed calibration thresholds.

    A test set is drawn from the simulator, evaluated with
    ``evaluate_test_dataset`` at the two fixed thresholds, and the resulting
    counts are converted into rates. Rates under the ``"marginal"`` key that
    carry a ``class0``/``class1`` suffix are joint rates normalized by the
    total test-set size, not by the class size.

    Parameters
    ----------
    trial_idx : int
        Trial index; added to ``seed`` to derive a per-trial seed.
    simulator : Any
        Data generator exposing ``generate(test_size)`` that returns
        ``(labels, probs)``. When ``seed`` is given it must also expose
        ``p_class1``, ``a0``, ``b0``, ``a1`` and ``b1`` and accept
        ``p_class1``, ``beta_params_class0``, ``beta_params_class1`` and
        ``seed`` as constructor keywords, because a fresh instance is built
        from those values.
    test_size : int
        Size of test set
    threshold_0 : float
        Fixed threshold for class 0
    threshold_1 : float
        Fixed threshold for class 1
    seed : int, optional
        Base random seed. If ``None``, the passed ``simulator`` is used
        directly and no reseeding takes place.

    Returns
    -------
    dict
        Trial results with the keys:

        - ``"marginal"``: overall ``singleton``/``doublet``/``abstention``/
          ``singleton_error`` rates, plus per-true-class joint rates
          (``*_rate_class{0,1}``, ``singleton_error_class{0,1}``,
          ``singleton_correct_class{0,1}``) and per-predicted-class joint
          rates (``singleton_{error,correct}_pred_class{0,1}``), all
          normalized by the total test-set size.
        - ``"class_0"`` / ``"class_1"``: the per-class rates reported by
          ``evaluate_test_dataset`` together with ``n_samples``.
        - ``"n_test"``: the test-set size reported by
          ``evaluate_test_dataset``.

    Notes
    -----
    When ``seed`` is not ``None`` this function reseeds the *global* NumPy
    random state with ``seed + trial_idx`` as a side effect.
    """
    if seed is not None:
        trial_seed = seed + trial_idx
        # Reseed the global state first, then build an independent simulator
        # that carries the same per-trial seed.
        np.random.seed(trial_seed)
        trial_simulator = type(simulator)(
            p_class1=simulator.p_class1,
            beta_params_class0=(simulator.a0, simulator.b0),
            beta_params_class1=(simulator.a1, simulator.b1),
            seed=trial_seed,
        )
    else:
        trial_simulator = simulator

    # Generate independent test set
    labels_test, probs_test = trial_simulator.generate(test_size)

    # Delegate the basic rate computation to the shared evaluation utility.
    eval_results = evaluate_test_dataset(
        test_labels=labels_test,
        test_probs=probs_test,
        threshold_0=threshold_0,
        threshold_1=threshold_1,
    )
    marginal = eval_results["marginal"]
    class_0 = eval_results["class_0"]
    class_1 = eval_results["class_1"]
    n_test = eval_results["n_test"]

    # Build prediction sets to compute predicted-class rates.
    from ssbc.utils import build_mondrian_prediction_sets

    prediction_sets = build_mondrian_prediction_sets(probs_test, threshold_0, threshold_1, return_lists=True)

    def _share_of_total(count: float) -> float:
        """Normalize a count by the total test-set size (0.0 for an empty test set)."""
        return count / n_test if n_test > 0 else 0.0

    def _class_joint_rate(stats: dict[str, Any], count_key: str) -> float:
        """Joint rate for a per-class count, normalized by the total test-set size."""
        return stats[count_key] / n_test if n_test > 0 else 0.0

    def _singleton_error_correct_split(stats: dict[str, Any]) -> tuple[float, float]:
        """Return (error, correct) singleton joint rates for one true class."""
        n_singletons = stats["n_singletons"]
        if n_singletons > 0:
            error_rate = stats["singleton_error_rate"]
            return (
                (n_singletons * error_rate) / n_test,
                (n_singletons * (1 - error_rate)) / n_test,
            )
        # Without singletons the per-class error rate is not used at all.
        return 0.0, 0.0

    def _class_summary(stats: dict[str, Any]) -> dict[str, Any]:
        """Per-class rates as reported by the evaluation utility, plus n_samples."""
        return {
            "singleton": stats["singleton_rate"],
            "doublet": stats["doublet_rate"],
            "abstention": stats["abstention_rate"],
            "singleton_error": stats["singleton_error_rate"],
            "n_samples": stats["n_samples"],  # Store for computing test set class prevalence
        }

    # Tally singleton outcomes by *predicted* class: P(pred=X, singleton, error/correct).
    # Only the error/correct counts are needed; singletons predicting any
    # other label are ignored, as are doublets and abstentions.
    n_pred_error = {0: 0, 1: 0}
    n_pred_correct = {0: 0, 1: 0}
    for pred_set, true_label in zip(prediction_sets, labels_test, strict=False):
        if len(pred_set) != 1:
            continue
        predicted_class = pred_set[0]
        if predicted_class == 0:
            bucket = 0
        elif predicted_class == 1:
            bucket = 1
        else:
            continue
        if true_label == predicted_class:
            n_pred_correct[bucket] += 1
        else:
            n_pred_error[bucket] += 1

    # Normalize by total test set size
    error_pred_class0 = _share_of_total(n_pred_error[0])
    error_pred_class1 = _share_of_total(n_pred_error[1])
    correct_pred_class0 = _share_of_total(n_pred_correct[0])
    correct_pred_class1 = _share_of_total(n_pred_correct[1])

    # Per-true-class rates normalized by total test set size (not by class size)
    singleton_rate_class0 = _class_joint_rate(class_0, "n_singletons")
    singleton_rate_class1 = _class_joint_rate(class_1, "n_singletons")
    doublet_rate_class0 = _class_joint_rate(class_0, "n_doublets")
    doublet_rate_class1 = _class_joint_rate(class_1, "n_doublets")
    abstention_rate_class0 = _class_joint_rate(class_0, "n_abstentions")
    abstention_rate_class1 = _class_joint_rate(class_1, "n_abstentions")

    # Joint error/correct rates: P(error|correct AND singleton AND class=c), normalized by total
    err_c0_norm, correct_c0_norm = _singleton_error_correct_split(class_0)
    err_c1_norm, correct_c1_norm = _singleton_error_correct_split(class_1)

    return {
        "marginal": {
            "singleton": marginal["singleton_rate"],
            "doublet": marginal["doublet_rate"],
            "abstention": marginal["abstention_rate"],
            "singleton_error": marginal["singleton_error_rate"],
            # Class-specific rates normalized by total
            "singleton_rate_class0": singleton_rate_class0,
            "singleton_rate_class1": singleton_rate_class1,
            "doublet_rate_class0": doublet_rate_class0,
            "doublet_rate_class1": doublet_rate_class1,
            "abstention_rate_class0": abstention_rate_class0,
            "abstention_rate_class1": abstention_rate_class1,
            # Class-specific singleton error rates (normalized by total)
            "singleton_error_class0": err_c0_norm,
            "singleton_error_class1": err_c1_norm,
            # Class-specific singleton correct rates (normalized by total)
            "singleton_correct_class0": correct_c0_norm,
            "singleton_correct_class1": correct_c1_norm,
            # Error/correct rates when singleton is assigned to a specific class (normalized by total)
            "singleton_error_pred_class0": error_pred_class0,
            "singleton_error_pred_class1": error_pred_class1,
            "singleton_correct_pred_class0": correct_pred_class0,
            "singleton_correct_pred_class1": correct_pred_class1,
        },
        "class_0": _class_summary(class_0),
        "class_1": _class_summary(class_1),
        "n_test": n_test,  # Store test size for computing prevalence
    }


[docs] def validate_pac_bounds( report: dict[str, Any], simulator: Any, test_size: int, n_trials: int = 1000, seed: int | None = None, verbose: bool = True, n_jobs: int = -1, ) -> dict[str, Any]: """Empirically validate prediction interval operational bounds. Takes a PAC report from generate_rigorous_pac_report() and validates that the theoretical bounds actually hold in practice by: 1. Extracting the FIXED thresholds from calibration 2. Running n_trials simulations with fresh test sets 3. Measuring empirical coverage of all reported bounds (analytical, exact, hoeffding) When the report includes method comparison (prediction_method="all"), validates all three methods separately. Otherwise, validates only the selected method. Parameters ---------- report : dict Output from generate_rigorous_pac_report() simulator : DataGenerator Simulator to generate independent test data (e.g., BinaryClassifierSimulator) test_size : int Size of each test set n_trials : int, default=1000 Number of independent trials seed : int, optional Random seed for reproducibility verbose : bool, default=True Print validation progress n_jobs : int, default=-1 Number of parallel jobs for trial execution. -1 = use all cores (default), 1 = single-threaded, N = use N cores. 
Returns ------- dict Validation results with: - 'marginal': Marginal operational rates and coverage - 'class_0': Class 0 operational rates and coverage - 'class_1': Class 1 operational rates and coverage Each containing: - 'singleton', 'doublet', 'abstention', 'singleton_error' dicts with: - 'rates': Array of rates across trials - 'mean': Mean rate - 'quantiles': Quantiles (5%, 25%, 50%, 75%, 95%) - 'bounds': Selected/default bounds from report - 'expected': Expected rate from report - 'empirical_coverage': Fraction of trials within selected bounds - 'method_validations': Dict of method-specific validations (when available): - 'analytical': {bounds, empirical_coverage} - 'exact': {bounds, empirical_coverage} - 'hoeffding': {bounds, empirical_coverage} Examples -------- >>> from ssbc import BinaryClassifierSimulator, generate_rigorous_pac_report, validate_pac_bounds >>> sim = BinaryClassifierSimulator(p_class1=0.2, seed=42) >>> labels, probs = sim.generate(100) >>> report = generate_rigorous_pac_report(labels, probs, delta=0.10) >>> validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000) >>> print(f"Singleton coverage: {validation['marginal']['singleton']['empirical_coverage']:.1%}") Notes ----- This function is useful for: - Verifying theoretical PAC guarantees empirically - Understanding the tightness of bounds - Debugging issues with bounds calculation - Generating validation plots for papers/reports The empirical coverage should be ≥ PAC level (1 - δ) for rigorous bounds. 
""" # Extract FIXED thresholds from calibration threshold_0 = report["calibration_result"][0]["threshold"] threshold_1 = report["calibration_result"][1]["threshold"] if verbose: print(f"Using fixed thresholds: q̂₀={threshold_0:.4f}, q̂₁={threshold_1:.4f}") print(f"Running {n_trials} trials with test_size={test_size}...") if n_jobs == -1: print("Using all available CPU cores for parallel execution") elif n_jobs == 1: print("Using single-threaded execution") else: print(f"Using {n_jobs} CPU cores for parallel execution") # Run trials in parallel, with safe fallback to serial if system forbids multiprocessing def _safe_parallel_map(n_jobs_local: int): try: return Parallel(n_jobs=n_jobs_local)( delayed(_validate_single_trial)(trial_idx, simulator, test_size, threshold_0, threshold_1, seed) for trial_idx in range(n_trials) ) except Exception: test_size_int = int(test_size) if not np.isnan(test_size) else 100 return [ _validate_single_trial(trial_idx, simulator, test_size_int, threshold_0, threshold_1, seed) for trial_idx in range(n_trials) ] trial_results = _safe_parallel_map(n_jobs) # Extract results from parallel execution marginal_singleton_rates = [result["marginal"]["singleton"] for result in trial_results] marginal_doublet_rates = [result["marginal"]["doublet"] for result in trial_results] marginal_abstention_rates = [result["marginal"]["abstention"] for result in trial_results] # Class-specific rates normalized by total marginal_singleton_rate_class0 = [result["marginal"]["singleton_rate_class0"] for result in trial_results] marginal_singleton_rate_class1 = [result["marginal"]["singleton_rate_class1"] for result in trial_results] marginal_doublet_rate_class0 = [result["marginal"]["doublet_rate_class0"] for result in trial_results] marginal_doublet_rate_class1 = [result["marginal"]["doublet_rate_class1"] for result in trial_results] marginal_abstention_rate_class0 = [result["marginal"]["abstention_rate_class0"] for result in trial_results] 
marginal_abstention_rate_class1 = [result["marginal"]["abstention_rate_class1"] for result in trial_results] # Class-specific singleton error metrics (normalized by total) marginal_error_class0_norm = [result["marginal"]["singleton_error_class0"] for result in trial_results] marginal_error_class1_norm = [result["marginal"]["singleton_error_class1"] for result in trial_results] # Class-specific singleton correct metrics (normalized by total) marginal_correct_class0_norm = [result["marginal"]["singleton_correct_class0"] for result in trial_results] marginal_correct_class1_norm = [result["marginal"]["singleton_correct_class1"] for result in trial_results] # Error/correct rates when singleton is assigned to a specific class (normalized by total) marginal_error_pred_class0 = [result["marginal"]["singleton_error_pred_class0"] for result in trial_results] marginal_error_pred_class1 = [result["marginal"]["singleton_error_pred_class1"] for result in trial_results] marginal_correct_pred_class0 = [result["marginal"]["singleton_correct_pred_class0"] for result in trial_results] marginal_correct_pred_class1 = [result["marginal"]["singleton_correct_pred_class1"] for result in trial_results] class_0_singleton_rates = [result["class_0"]["singleton"] for result in trial_results] class_0_doublet_rates = [result["class_0"]["doublet"] for result in trial_results] class_0_abstention_rates = [result["class_0"]["abstention"] for result in trial_results] class_0_singleton_error_rates = [result["class_0"]["singleton_error"] for result in trial_results] class_1_singleton_rates = [result["class_1"]["singleton"] for result in trial_results] class_1_doublet_rates = [result["class_1"]["doublet"] for result in trial_results] class_1_abstention_rates = [result["class_1"]["abstention"] for result in trial_results] class_1_singleton_error_rates = [result["class_1"]["singleton_error"] for result in trial_results] # Convert to arrays marginal_singleton_rates = np.array(marginal_singleton_rates) 
marginal_doublet_rates = np.array(marginal_doublet_rates) marginal_abstention_rates = np.array(marginal_abstention_rates) # Class-specific rates normalized by total marginal_singleton_rate_class0 = np.array(marginal_singleton_rate_class0) marginal_singleton_rate_class1 = np.array(marginal_singleton_rate_class1) marginal_doublet_rate_class0 = np.array(marginal_doublet_rate_class0) marginal_doublet_rate_class1 = np.array(marginal_doublet_rate_class1) marginal_abstention_rate_class0 = np.array(marginal_abstention_rate_class0) marginal_abstention_rate_class1 = np.array(marginal_abstention_rate_class1) # Note: marginal_singleton_error_rates is NOT extracted because PAC bounds don't compute it marginal_error_class0_norm = np.array(marginal_error_class0_norm) marginal_error_class1_norm = np.array(marginal_error_class1_norm) # Class-specific singleton correct metrics (normalized by total) marginal_correct_class0_norm = np.array(marginal_correct_class0_norm) marginal_correct_class1_norm = np.array(marginal_correct_class1_norm) # Error/correct rates when singleton is assigned to a specific class (normalized by total) marginal_error_pred_class0 = np.array(marginal_error_pred_class0) marginal_error_pred_class1 = np.array(marginal_error_pred_class1) marginal_correct_pred_class0 = np.array(marginal_correct_pred_class0) marginal_correct_pred_class1 = np.array(marginal_correct_pred_class1) class_0_singleton_rates = np.array(class_0_singleton_rates) class_0_doublet_rates = np.array(class_0_doublet_rates) class_0_abstention_rates = np.array(class_0_abstention_rates) class_0_singleton_error_rates = np.array(class_0_singleton_error_rates) class_1_singleton_rates = np.array(class_1_singleton_rates) class_1_doublet_rates = np.array(class_1_doublet_rates) class_1_abstention_rates = np.array(class_1_abstention_rates) class_1_singleton_error_rates = np.array(class_1_singleton_error_rates) # Helper functions def check_coverage(rates: np.ndarray, bounds: tuple[float, float]) -> float: 
"""Check what fraction of rates fall within bounds.""" lower, upper = bounds within = np.sum((rates >= lower) & (rates <= upper)) return within / len(rates) def check_coverage_with_nan(rates: np.ndarray, bounds: tuple[float, float]) -> float: """Check coverage, ignoring NaN values.""" lower, upper = bounds valid = ~np.isnan(rates) if np.sum(valid) == 0: return np.nan rates_valid = rates[valid] within = np.sum((rates_valid >= lower) & (rates_valid <= upper)) return within / len(rates_valid) def compute_quantiles(rates: np.ndarray) -> dict[str, float]: """Compute quantiles, handling NaN.""" valid = rates[~np.isnan(rates)] if np.any(np.isnan(rates)) else rates if len(valid) == 0: return { "q025": np.nan, "q05": np.nan, "q25": np.nan, "q50": np.nan, "q75": np.nan, "q95": np.nan, "q975": np.nan, } return { "q025": float(np.percentile(valid, 2.5)), "q05": float(np.percentile(valid, 5)), "q25": float(np.percentile(valid, 25)), "q50": float(np.percentile(valid, 50)), "q75": float(np.percentile(valid, 75)), "q95": float(np.percentile(valid, 95)), "q975": float(np.percentile(valid, 97.5)), } # Get bounds from report pac_marg = report["pac_bounds_marginal"] pac_0 = report["pac_bounds_class_0"] pac_1 = report["pac_bounds_class_1"] # Extract levels from report params = report.get("parameters", {}) pac_level_marginal = params.get("pac_level_marginal", 0.90) # Default if missing pac_level_0 = params.get("pac_level_0", 0.90) pac_level_1 = params.get("pac_level_1", 0.90) ci_level = params.get("ci_level", 0.95) # Helper function to extract method comparison bounds def extract_method_bounds(pac_bounds: dict, metric_key: str) -> dict[str, tuple[float, float]]: """Extract bounds for all methods (analytical, exact, hoeffding) if available. Returns dict mapping method names to (lower, upper) bounds. If comparison not available, returns only the selected method bounds. 
""" method_bounds = {} # Get the default selected bounds rate_key = f"{metric_key}_rate_bounds" default_bounds = pac_bounds.get(rate_key, (np.nan, np.nan)) # Check if method comparison is available loo_diag = pac_bounds.get("loo_diagnostics", {}) metric_diag = loo_diag.get(metric_key, {}) if loo_diag else {} if metric_diag and "comparison" in metric_diag: # Extract bounds for all methods from comparison table comp = metric_diag["comparison"] for method_name, method_lower, method_upper in zip( comp["method"], comp["lower"], comp["upper"], strict=False ): # Normalize method names method_key = method_name.lower().replace(" ", "_") if "analytical" in method_key: method_bounds["analytical"] = (method_lower, method_upper) elif "exact" in method_key: method_bounds["exact"] = (method_lower, method_upper) elif "hoeffding" in method_key: method_bounds["hoeffding"] = (method_lower, method_upper) else: # No comparison available - use selected bounds with default method name selected_method = metric_diag.get("selected_method", "selected") if metric_diag else "selected" method_bounds[selected_method] = default_bounds return method_bounds def extract_method_bounds_by_keys( pac_bounds: dict, rate_bounds_key: str, diagnostics_key: str | None = None, ) -> dict[str, tuple[float, float]]: """Extract bounds using explicit keys in pac_bounds (marginal extras). 
rate_bounds_key: e.g., 'singleton_error_rate_class0_bounds' diagnostics_key: e.g., 'singleton_error_class0' inside loo_diagnostics """ method_bounds: dict[str, tuple[float, float]] = {} default_bounds = pac_bounds.get(rate_bounds_key, (np.nan, np.nan)) if diagnostics_key is not None: loo_diag = pac_bounds.get("loo_diagnostics", {}) metric_diag = loo_diag.get(diagnostics_key, {}) if loo_diag else {} if metric_diag and "comparison" in metric_diag: comp = metric_diag["comparison"] for method_name, method_lower, method_upper in zip( comp["method"], comp["lower"], comp["upper"], strict=False ): method_key = method_name.lower().replace(" ", "_") if "analytical" in method_key: method_bounds["analytical"] = (method_lower, method_upper) elif "exact" in method_key: method_bounds["exact"] = (method_lower, method_upper) elif "hoeffding" in method_key: method_bounds["hoeffding"] = (method_lower, method_upper) else: selected_method = metric_diag.get("selected_method", "selected") if metric_diag else "selected" method_bounds[selected_method] = default_bounds else: # No diagnostics key provided; just return selected bounds under 'selected' method_bounds["selected"] = default_bounds return method_bounds # Helper function to validate a metric across all methods def validate_metric_all_methods( rates: np.ndarray, pac_bounds: dict, metric_key: str, use_nan_check: bool = False, scope: str = "marginal", report: dict[str, Any] | None = None, ) -> dict[str, Any]: """Validate a metric across all reported methods with mathematical consistency checks. 
Parameters ---------- rates : np.ndarray Test rates from validation trials pac_bounds : dict PAC bounds dictionary for the scope metric_key : str Metric identifier (e.g., "singleton", "doublet") use_nan_check : bool Whether to handle NaN values scope : str Scope: "marginal", "class_0", or "class_1" report : dict, optional Full report for mathematical consistency validation Returns ------- dict Validation result including mathematical consistency checks """ # Extract bounds for all methods method_bounds_map = extract_method_bounds(pac_bounds, metric_key) # Get default bounds (selected method) rate_key = f"{metric_key}_rate_bounds" default_bounds = pac_bounds.get(rate_key, (np.nan, np.nan)) # Get expected rate expected_key = f"expected_{metric_key}_rate" expected = pac_bounds.get(expected_key, np.nan) # Validate each method method_validations = {} for method_name, bounds in method_bounds_map.items(): if use_nan_check: coverage = check_coverage_with_nan(rates, bounds) else: coverage = check_coverage(rates, bounds) method_validations[method_name] = { "bounds": bounds, "empirical_coverage": coverage, } # Mathematical consistency validation math_consistency = None if report is not None: # Map metric_key to full metric name for validation_math # Note: Global marginal rates (singleton, doublet, abstention) mix classes # and don't have valid Bernoulli event definitions - skip validation for these if scope == "marginal" and metric_key in ["singleton", "doublet", "abstention"]: # Still extract calibration info for display, even though we can't validate # Extract n_cal from calibration_result (total calibration size) try: if "calibration_result" in report and len(report["calibration_result"]) > 0: if 0 in report["calibration_result"] and 1 in report["calibration_result"]: n_cal = report["calibration_result"][0].get("n", 0) + report["calibration_result"][1].get( "n", 0 ) else: n_cal = report["calibration_result"][0].get("n", None) else: n_cal = None # Get test_size from 
parameters params = report.get("parameters", {}) n_test = params.get("test_size", None) # k_cal is not meaningful for global marginal rates (mixes classes) k_cal = None except Exception: k_cal = None n_cal = None n_test = None math_consistency = { "overall_valid": False, "message": "⚠️ Global marginal rates mix classes - no valid Bernoulli event definition", "event_definition": "N/A (mixes class distributions)", "k_cal": k_cal, "n_cal": n_cal, "n_test": n_test, "denominator_alignment": { "valid": False, "message": "N/A (not a valid Bernoulli event)", "issues": ["Global marginal rates cannot be validated as single Bernoulli events"], }, } else: # Map metric_key to full metric name for validation_math full_metric_key = metric_key if scope != "marginal": # For per-class scopes, map to conditional metric names # scope is "class_0" or "class_1", need to convert to "class0" or "class1" class_num = scope.replace("class_", "class") if metric_key == "singleton": full_metric_key = f"singleton_{class_num}" elif metric_key == "doublet": full_metric_key = f"doublet_{class_num}" elif metric_key == "abstention": full_metric_key = f"abstention_{class_num}" elif metric_key == "singleton_error": full_metric_key = f"singleton_error_{class_num}" try: math_consistency = validate_metric_mathematical_consistency( full_metric_key, scope, report, rates, ci_level ) except Exception as e: # Log the actual error for debugging import traceback error_msg = f"⚠️ Mathematical consistency check failed: {str(e)}" math_consistency = { "overall_valid": False, "message": error_msg, "event_definition": "Unknown", "error": str(e), "traceback": traceback.format_exc(), } result = { "rates": rates, "mean": np.nanmean(rates) if use_nan_check else np.mean(rates), "quantiles": compute_quantiles(rates), "bounds": default_bounds, # Selected/default bounds "expected": expected, "empirical_coverage": check_coverage_with_nan(rates, default_bounds) if use_nan_check else check_coverage(rates, default_bounds), 
"method_validations": method_validations, # All method-specific validations } if math_consistency is not None: result["mathematical_consistency"] = math_consistency return result def validate_metric_by_keys( rates: np.ndarray, pac_bounds: dict, rate_bounds_key: str, diagnostics_key: str, expected_key: str | None = None, use_nan_check: bool = False, scope: str = "marginal", report: dict[str, Any] | None = None, ) -> dict[str, Any]: """Validate a metric using explicit pac_bounds keys with mathematical consistency checks. Parameters ---------- rates : np.ndarray Test rates from validation trials pac_bounds : dict PAC bounds dictionary for the scope rate_bounds_key : str Key for bounds in pac_bounds (e.g., "singleton_rate_class0_bounds") diagnostics_key : str Key in loo_diagnostics (e.g., "singleton_class0") expected_key : str, optional Key for expected rate use_nan_check : bool Whether to handle NaN values scope : str Scope: "marginal", "class_0", or "class_1" report : dict, optional Full report for mathematical consistency validation Returns ------- dict Validation result including mathematical consistency checks """ method_bounds_map = extract_method_bounds_by_keys(pac_bounds, rate_bounds_key, diagnostics_key) default_bounds = pac_bounds.get(rate_bounds_key, (np.nan, np.nan)) expected = pac_bounds.get(expected_key, np.nan) if expected_key else np.nan method_validations = {} for method_name, bounds in method_bounds_map.items(): coverage = check_coverage_with_nan(rates, bounds) if use_nan_check else check_coverage(rates, bounds) method_validations[method_name] = {"bounds": bounds, "empirical_coverage": coverage} # Mathematical consistency validation # Extract metric_key from rate_bounds_key (e.g., "singleton_rate_class0_bounds" -> "singleton_rate_class0") metric_key = rate_bounds_key.replace("_bounds", "") # Don't remove "_rate" - it's part of the key (e.g., "singleton_rate_class0") math_consistency = None if report is not None: try: math_consistency = 
validate_metric_mathematical_consistency(metric_key, scope, report, rates, ci_level) except Exception as e: # Log the actual error for debugging error_msg = f"⚠️ Mathematical consistency check failed: {str(e)}" math_consistency = { "overall_valid": False, "message": error_msg, "event_definition": "Unknown", "k_cal": "N/A", "n_cal": "N/A", "n_test": "N/A", "error": str(e), } result = { "rates": rates, "mean": np.nanmean(rates) if use_nan_check else np.mean(rates), "quantiles": compute_quantiles(rates), "bounds": default_bounds, "expected": expected, "empirical_coverage": check_coverage_with_nan(rates, default_bounds) if use_nan_check else check_coverage(rates, default_bounds), "method_validations": method_validations, } if math_consistency is not None: result["mathematical_consistency"] = math_consistency return result result = { "n_trials": n_trials, "test_size": test_size, "threshold_0": threshold_0, "threshold_1": threshold_1, "pac_level_marginal": pac_level_marginal, "pac_level_0": pac_level_0, "pac_level_1": pac_level_1, "ci_level": ci_level, "marginal": { "singleton": validate_metric_all_methods( marginal_singleton_rates, pac_marg, "singleton", use_nan_check=False, scope="marginal", report=report, ), "doublet": validate_metric_all_methods( marginal_doublet_rates, pac_marg, "doublet", use_nan_check=False, scope="marginal", report=report ), "abstention": validate_metric_all_methods( marginal_abstention_rates, pac_marg, "abstention", use_nan_check=False, scope="marginal", report=report, ), # Class-specific rates normalized by total "singleton_rate_class0": validate_metric_by_keys( marginal_singleton_rate_class0, pac_marg, "singleton_rate_class0_bounds", "singleton_class0", expected_key="expected_singleton_rate_class0", use_nan_check=False, scope="marginal", report=report, ), "singleton_rate_class1": validate_metric_by_keys( marginal_singleton_rate_class1, pac_marg, "singleton_rate_class1_bounds", "singleton_class1", expected_key="expected_singleton_rate_class1", 
use_nan_check=False, scope="marginal", report=report, ), "doublet_rate_class0": validate_metric_by_keys( marginal_doublet_rate_class0, pac_marg, "doublet_rate_class0_bounds", "doublet_class0", expected_key="expected_doublet_rate_class0", use_nan_check=False, scope="marginal", report=report, ), "doublet_rate_class1": validate_metric_by_keys( marginal_doublet_rate_class1, pac_marg, "doublet_rate_class1_bounds", "doublet_class1", expected_key="expected_doublet_rate_class1", use_nan_check=False, scope="marginal", report=report, ), "abstention_rate_class0": validate_metric_by_keys( marginal_abstention_rate_class0, pac_marg, "abstention_rate_class0_bounds", "abstention_class0", expected_key="expected_abstention_rate_class0", use_nan_check=False, scope="marginal", report=report, ), "abstention_rate_class1": validate_metric_by_keys( marginal_abstention_rate_class1, pac_marg, "abstention_rate_class1_bounds", "abstention_class1", expected_key="expected_abstention_rate_class1", use_nan_check=False, scope="marginal", report=report, ), # Class-conditional singleton error variants # Note: We do NOT validate marginal singleton_error because it mixes two different # distributions (class 0 and class 1) which cannot be justified statistically. 
"singleton_error_class0": validate_metric_by_keys( marginal_error_class0_norm, pac_marg, "singleton_error_rate_class0_bounds", "singleton_error_class0", expected_key="expected_singleton_error_rate_class0", use_nan_check=False, scope="marginal", report=report, ), "singleton_error_class1": validate_metric_by_keys( marginal_error_class1_norm, pac_marg, "singleton_error_rate_class1_bounds", "singleton_error_class1", expected_key="expected_singleton_error_rate_class1", use_nan_check=False, scope="marginal", report=report, ), "singleton_correct_class0": validate_metric_by_keys( marginal_correct_class0_norm, pac_marg, "singleton_correct_rate_class0_bounds", "singleton_correct_class0", expected_key="expected_singleton_correct_rate_class0", use_nan_check=False, scope="marginal", report=report, ), "singleton_correct_class1": validate_metric_by_keys( marginal_correct_class1_norm, pac_marg, "singleton_correct_rate_class1_bounds", "singleton_correct_class1", expected_key="expected_singleton_correct_rate_class1", use_nan_check=False, scope="marginal", report=report, ), # Error/correct rates when singleton is assigned to a specific class (normalized by total) "singleton_error_pred_class0": validate_metric_by_keys( marginal_error_pred_class0, pac_marg, "singleton_error_rate_pred_class0_bounds", "singleton_error_pred_class0", expected_key="expected_singleton_error_rate_pred_class0", use_nan_check=False, scope="marginal", report=report, ), "singleton_error_pred_class1": validate_metric_by_keys( marginal_error_pred_class1, pac_marg, "singleton_error_rate_pred_class1_bounds", "singleton_error_pred_class1", expected_key="expected_singleton_error_rate_pred_class1", use_nan_check=False, scope="marginal", report=report, ), "singleton_correct_pred_class0": validate_metric_by_keys( marginal_correct_pred_class0, pac_marg, "singleton_correct_rate_pred_class0_bounds", "singleton_correct_pred_class0", expected_key="expected_singleton_correct_rate_pred_class0", use_nan_check=False, 
scope="marginal", report=report, ), "singleton_correct_pred_class1": validate_metric_by_keys( marginal_correct_pred_class1, pac_marg, "singleton_correct_rate_pred_class1_bounds", "singleton_correct_pred_class1", expected_key="expected_singleton_correct_rate_pred_class1", use_nan_check=False, scope="marginal", report=report, ), }, "class_0": { "singleton": validate_metric_all_methods( class_0_singleton_rates, pac_0, "singleton", use_nan_check=False, scope="class_0", report=report ), "doublet": validate_metric_all_methods( class_0_doublet_rates, pac_0, "doublet", use_nan_check=False, scope="class_0", report=report ), "abstention": validate_metric_all_methods( class_0_abstention_rates, pac_0, "abstention", use_nan_check=False, scope="class_0", report=report ), "singleton_error": validate_metric_all_methods( class_0_singleton_error_rates, pac_0, "singleton_error", use_nan_check=True, scope="class_0", report=report, ), }, "class_1": { "singleton": validate_metric_all_methods( class_1_singleton_rates, pac_1, "singleton", use_nan_check=False, scope="class_1", report=report ), "doublet": validate_metric_all_methods( class_1_doublet_rates, pac_1, "doublet", use_nan_check=False, scope="class_1", report=report ), "abstention": validate_metric_all_methods( class_1_abstention_rates, pac_1, "abstention", use_nan_check=False, scope="class_1", report=report ), "singleton_error": validate_metric_all_methods( class_1_singleton_error_rates, pac_1, "singleton_error", use_nan_check=True, scope="class_1", report=report, ), }, } # Add probability consistency checks for joint rates # For each class y: q_y^sing + q_y^dbl + q_y^abs = p_y_test ± ε # NOTE: p_y_test should be computed directly from test set class counts, NOT from the sum! # The sum should equal the test set class prevalence by the law of total probability. 
prob_consistency_checks = {} for class_label in [0, 1]: # Extract joint rates from marginal validation marginal_dict = result.get("marginal", {}) if isinstance(marginal_dict, dict): q_sing = marginal_dict.get(f"singleton_rate_class{class_label}", {}).get("mean", np.nan) q_dbl = marginal_dict.get(f"doublet_rate_class{class_label}", {}).get("mean", np.nan) q_abs = marginal_dict.get(f"abstention_rate_class{class_label}", {}).get("mean", np.nan) else: q_sing = q_dbl = q_abs = np.nan # Compute sum of joint rates (should equal p_y_test) sum_joint_rates = q_sing + q_dbl + q_abs # Compute test set class prevalence DIRECTLY from class counts (not from sum!) # Extract n_samples from trial results where we stored it class_rates_key = f"class_{class_label}" test_size_val = result.get("test_size") test_size: float = float(test_size_val) if test_size_val is not None else np.nan # Extract n_samples for this class from each trial n_samples_per_trial = [trial_result[class_rates_key]["n_samples"] for trial_result in trial_results] n_test_per_trial = [trial_result.get("n_test", test_size) for trial_result in trial_results] # Compute p_y_test as mean class prevalence across trials if len(n_samples_per_trial) > 0 and not np.isnan(test_size): # Use actual n_test from each trial if available, otherwise use test_size p_y_test_trials = [ n_samples / n_test for n_samples, n_test in zip(n_samples_per_trial, n_test_per_trial, strict=False) ] p_y_test = np.mean(p_y_test_trials) else: # Fallback: use sum (which should equal p_y_test by definition) p_y_test = sum_joint_rates # Also get calibration class prevalence for comparison p_y_calibration = np.nan if "calibration_result" in report: n_class = report["calibration_result"][class_label].get("n", 0) n_total = ( report["calibration_result"][0].get("n", 0) + report["calibration_result"][1].get("n", 0) if 0 in report["calibration_result"] and 1 in report["calibration_result"] else n_class ) p_y_calibration = n_class / n_total if n_total > 0 else 
np.nan # Verify the sum equals p_y_test (law of total probability) # The sum should equal p_y_test by definition: q_y^sing + q_y^dbl + q_y^abs = P(Y=y) = p_y_test # We verify this holds numerically (within floating point precision) tolerance_numerical = 1e-5 # Numerical precision check (allowing for averaging across trials) sum_valid = abs(sum_joint_rates - p_y_test) < tolerance_numerical # Compare to calibration as a sanity check (test set class distribution may differ due to random sampling) difference_from_calibration = np.nan if not np.isnan(p_y_calibration): difference_from_calibration = abs(p_y_test - p_y_calibration) # The main check is that sum = p_y_test (sum_valid), calibration comparison is just informational valid = sum_valid difference = difference_from_calibration else: # If no calibration data, just verify sum equals p_y_test valid = sum_valid and (0.0 < p_y_test < 1.0) difference = abs(sum_joint_rates - p_y_test) if not sum_valid else 0.0 # Build message if valid: msg = f"✅ Probability consistency valid (sum={sum_joint_rates:.6f} = test p_y={p_y_test:.6f}" if not np.isnan(p_y_calibration): msg += f", calibration p_y={p_y_calibration:.6f}, difference={difference_from_calibration:.6f}" msg += ")" else: diff_sum = abs(sum_joint_rates - p_y_test) msg = ( f"❌ Probability consistency violated: sum={sum_joint_rates:.6f} " f"≠ test p_y={p_y_test:.6f} (difference={diff_sum:.6f})" ) if not np.isnan(p_y_calibration): msg += f", calibration p_y={p_y_calibration:.6f}" prob_consistency_checks[f"class_{class_label}"] = { "valid": valid, "q_sing": q_sing, "q_dbl": q_dbl, "q_abs": q_abs, "sum": sum_joint_rates, # Sum of joint rates (should equal p_y_test) "p_y_test": p_y_test, # Test set class prevalence (computed from class counts) "p_y_calibration": p_y_calibration if not np.isnan(p_y_calibration) else None, # Calibration class prevalence (for comparison) "difference": difference, # Difference from calibration (if available) or numerical difference "message": msg, 
} result["probability_consistency"] = prob_consistency_checks return result
[docs] def plot_validation_bounds( validation: dict[str, Any], metric: str = "singleton", show_detail: bool = True, main_figsize: tuple[int, int] = (18, 5), detail_figsize: tuple[int, int] = (18, 12), bins: int = 50, method_colors: dict[str, tuple[str, str]] | None = None, return_figs: bool = False, ) -> tuple | None: """Plot empirical distributions with prediction interval bounds for all methods. Creates visualization comparing empirical rates against bounds from analytical, exact, and hoeffding methods when available. Parameters ---------- validation : dict Output from validate_pac_bounds() containing validation results metric : str, default="singleton" Which metric to plot. Options: "singleton", "doublet", "abstention", "singleton_error" show_detail : bool, default=True If True, also create detailed 3x3 grid showing each method separately main_figsize : tuple[int, int], default=(18, 5) Figure size for main comparison plot (width, height in inches) detail_figsize : tuple[int, int], default=(18, 12) Figure size for detailed method comparison grid (width, height in inches) bins : int, default=50 Number of bins for histograms method_colors : dict or None, default=None Custom colors and linestyles for methods. Dict mapping method names to (color, linestyle) tuples. If None, uses default colors: - "analytical": ("#2E86AB", "solid") # Blue - "exact": ("#A23B72", "dashed") # Purple - "hoeffding": ("#F18F01", "dashdot") # Orange return_figs : bool, default=False If True, returns matplotlib Figure objects for further customization. Returns (fig_main, fig_detail) or (fig_main, None) if show_detail=False. If False, calls plt.show() and returns None. 
Returns ------- tuple or None If return_figs=True: - (fig_main, fig_detail) if show_detail=True - (fig_main, None) if show_detail=False If return_figs=False: None (displays plots directly) Examples -------- >>> from ssbc import validate_pac_bounds, plot_validation_bounds >>> validation = validate_pac_bounds(report, sim, test_size=1000, n_trials=1000) >>> plot_validation_bounds(validation, metric="singleton") >>> # Or get figure objects for customization >>> fig_main, fig_detail = plot_validation_bounds( ... validation, metric="singleton", return_figs=True ... ) >>> fig_main.savefig("validation_main.png") Notes ----- The main plot shows all three methods overlaid on the same histogram for easy comparison. The detailed plot shows each method separately in a 3x3 grid. Both plots include: - Empirical distribution histogram - Method-specific bounds (when method comparison available) - Expected value from LOO-CV - Empirical mean from validation trials - Coverage percentages for each method """ try: import matplotlib.pyplot as plt except ImportError as err: raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib") from err # Validate metric exists valid_metrics = [ "singleton", "doublet", "abstention", "singleton_error", # New marginal-only metrics "singleton_error_class0", "singleton_error_class1", "singleton_correct_class0", "singleton_correct_class1", "singleton_error_pred_class0", "singleton_error_pred_class1", "singleton_correct_pred_class0", "singleton_correct_pred_class1", ] if metric not in valid_metrics: raise ValueError(f"metric must be one of {valid_metrics}, got '{metric}'") # Check that metric exists in validation if metric not in validation["marginal"]: raise ValueError( f"Metric '{metric}' not found in validation results. 
" f"Available metrics: {list(validation['marginal'].keys())}" ) # Default method colors if not provided if method_colors is None: method_colors = { "analytical": ("#2E86AB", "solid"), # Blue "exact": ("#A23B72", "dashed"), # Purple "hoeffding": ("#F18F01", "dashdot"), # Orange } scopes = [("marginal", "Marginal"), ("class_0", "Class 0"), ("class_1", "Class 1")] method_order = ["analytical", "exact", "hoeffding"] # ===== Main comparison plot ===== fig_main, axes = plt.subplots(1, 3, figsize=main_figsize) for idx, (scope, title) in enumerate(scopes): ax = axes[idx] rates = validation[scope][metric]["rates"] expected = validation[scope][metric]["expected"] # Plot histogram ax.hist(rates, bins=bins, alpha=0.6, edgecolor="black", color="gray", label="Empirical distribution") # Plot bounds for each method if available if "method_validations" in validation[scope][metric] and validation[scope][metric]["method_validations"]: for method_name in method_order: if method_name in validation[scope][metric]["method_validations"]: method_val = validation[scope][metric]["method_validations"][method_name] method_bounds = method_val["bounds"] method_coverage = method_val["empirical_coverage"] color, linestyle = method_colors.get(method_name, ("black", "solid")) # Plot method bounds ax.axvline( method_bounds[0], color=color, linestyle=linestyle, linewidth=2, alpha=0.8, label=f"{method_name.capitalize()} lower", ) ax.axvline( method_bounds[1], color=color, linestyle=linestyle, linewidth=2, alpha=0.8, label=f"{method_name.capitalize()} upper (cov: {method_coverage:.1%})", ) else: # Fallback to selected bounds if method comparison not available bounds = validation[scope][metric]["bounds"] coverage = validation[scope][metric]["empirical_coverage"] ax.axvline(bounds[0], color="r", linestyle="--", linewidth=2, label="Lower bound") ax.axvline(bounds[1], color="r", linestyle="--", linewidth=2, label=f"Upper bound (cov: {coverage:.1%})") # Plot expected and empirical mean ax.axvline(expected, 
color="g", linestyle="-", linewidth=2.5, label="Expected (LOO)", zorder=5) ax.axvline(np.mean(rates), color="darkorange", linestyle=":", linewidth=2, label="Empirical mean", zorder=5) # Plot requested 95% quantile lines (2.5% and 97.5%) for the empirical distribution q = validation[scope][metric]["quantiles"] ax.axvline(q["q025"], color="#555555", linestyle="--", linewidth=1.5, label="2.5% quantile", zorder=4) ax.axvline(q["q975"], color="#555555", linestyle="--", linewidth=1.5, label="97.5% quantile", zorder=4) ax.set_xlabel(f"{metric.replace('_', ' ').title()} Rate", fontsize=11) ax.set_ylabel("Frequency", fontsize=11) # Build title with coverage info if "method_validations" in validation[scope][metric] and validation[scope][metric]["method_validations"]: coverages = [] for method_name in method_order: if method_name in validation[scope][metric]["method_validations"]: cov = validation[scope][metric]["method_validations"][method_name]["empirical_coverage"] if not np.isnan(cov): coverages.append(f"{method_name[0].upper()}:{cov:.1%}") coverage_str = ", ".join(coverages) if coverages else "" ax.set_title(f"{title} {metric.replace('_', ' ').title()} Rates\nCoverages: {coverage_str}", fontsize=11) else: cov = validation[scope][metric]["empirical_coverage"] ax.set_title(f"{title} {metric.replace('_', ' ').title()} Rates\nCoverage: {cov:.1%}", fontsize=11) ax.legend(loc="upper right", fontsize=8, framealpha=0.9) ax.grid(alpha=0.3) plt.tight_layout() # ===== Detailed method comparison plot ===== fig_detail = None if show_detail: if ( "method_validations" in validation["marginal"][metric] and validation["marginal"][metric]["method_validations"] ): fig_detail, axes2 = plt.subplots(3, 3, figsize=detail_figsize) for method_idx, method_name in enumerate(method_order): if method_name in validation["marginal"][metric]["method_validations"]: color, linestyle = method_colors.get(method_name, ("black", "solid")) for scope_idx, (scope, title) in enumerate(scopes): ax = 
axes2[scope_idx, method_idx] rates = validation[scope][metric]["rates"] expected = validation[scope][metric]["expected"] if method_name in validation[scope][metric]["method_validations"]: method_val = validation[scope][metric]["method_validations"][method_name] method_bounds = method_val["bounds"] method_coverage = method_val["empirical_coverage"] # Plot histogram ax.hist(rates, bins=bins, alpha=0.6, edgecolor="black", color="lightgray") # Plot method bounds ax.axvline( method_bounds[0], color=color, linestyle=linestyle, linewidth=2.5, label=f"Lower [{method_bounds[0]:.3f}]", alpha=0.9, ) ax.axvline( method_bounds[1], color=color, linestyle=linestyle, linewidth=2.5, label=f"Upper [{method_bounds[1]:.3f}]", alpha=0.9, ) # Plot expected and mean ax.axvline(expected, color="g", linestyle="-", linewidth=2, label="Expected", zorder=5) ax.axvline( np.mean(rates), color="darkorange", linestyle=":", linewidth=2, label=f"Mean [{np.mean(rates):.3f}]", zorder=5, ) # Plot empirical 2.5% and 97.5% quantiles in detail panels as well q = validation[scope][metric]["quantiles"] ax.axvline( q["q025"], color="#555555", linestyle="--", linewidth=1.2, label="2.5%", zorder=4 ) ax.axvline( q["q975"], color="#555555", linestyle="--", linewidth=1.2, label="97.5%", zorder=4 ) ax.set_xlabel(f"{metric.replace('_', ' ').title()} Rate", fontsize=10) ax.set_ylabel("Frequency", fontsize=10) cov_str = f"{method_coverage:.1%}" if not np.isnan(method_coverage) else "N/A" ax.set_title(f"{title} - {method_name.capitalize()}\nCoverage: {cov_str}", fontsize=10) ax.legend(loc="upper right", fontsize=8, framealpha=0.9) ax.grid(alpha=0.3) else: ax.text( 0.5, 0.5, f"{method_name.capitalize()}\nnot available", ha="center", va="center", transform=ax.transAxes, fontsize=12, ) ax.axis("off") plt.tight_layout() # Display or return figures if return_figs: return (fig_main, fig_detail) if show_detail else (fig_main, None) else: plt.show() return None
def validate_prediction_interval_calibration(
    simulator: Any,
    n_calibration: int,
    BigN: int,
    alpha_target: float | dict[int, float] = 0.10,
    delta: float | dict[int, float] = 0.10,
    test_size: int = 1000,
    n_trials: int = 1000,
    ci_level: float = 0.95,
    use_loo_correction: bool = True,
    prediction_method: str = "all",
    loo_inflation_factor: float | None = None,
    seed: int | None = None,
    n_jobs: int = -1,
    verbose: bool = False,
) -> dict[str, Any]:
    """Validate that prediction interval confidence level holds across calibration datasets.

    This meta-validation checks if the nominal confidence level (e.g., 95%)
    actually holds when repeating the entire calibration+validation process
    many times with different calibration datasets.

    For each of BigN calibration datasets:
    1. Generate random calibration data
    2. Compute prediction interval bounds
    3. Validate bounds with many test sets
    4. Record empirical coverage

    Then aggregates results to see if ~95% of calibrations achieve ≥95% coverage.

    Parameters
    ----------
    simulator : DataGenerator
        Simulator for generating calibration and test data (e.g.,
        BinaryClassifierSimulator). When `seed` is given, a fresh simulator of
        the same type is built per calibration from its `p_class1`, `a0`, `b0`,
        `a1`, `b1` attributes.
    n_calibration : int
        Size of each calibration dataset
    BigN : int
        Number of different calibration datasets to test
    alpha_target : float or dict[int, float], default=0.10
        Target miscoverage rate per class
    delta : float or dict[int, float], default=0.10
        PAC risk tolerance for threshold calibration
    test_size : int, default=1000
        Size of each test set in validation
    n_trials : int, default=1000
        Number of test sets per calibration dataset (for validation)
    ci_level : float, default=0.95
        Nominal confidence level for prediction intervals (target to validate)
    use_loo_correction : bool, default=True
        Use LOO-corrected bounds
    prediction_method : str, default="all"
        Method for bounds computation ("all" to compare all methods)
    loo_inflation_factor : float, optional
        Manual override for LOO inflation factor. When None, 2.0 is passed to
        the report generator.
    seed : int, optional
        Random seed for reproducibility. Calibration `i` uses
        ``seed + i * 10000``; its validation run uses that value plus one.
    n_jobs : int, default=-1
        Number of parallel jobs (-1 = all cores)
    verbose : bool, default=False
        If True, print a start/finish progress message

    Returns
    -------
    dict
        Meta-validation results with keys:
        - 'n_calibrations': BigN
        - 'n_calibration': Calibration dataset size
        - 'n_trials_per_calibration': n_trials
        - 'ci_level': Target confidence level
        - 'marginal': Dict with coverage statistics per metric and method
        - 'class_0': Dict with coverage statistics per metric and method
        - 'class_1': Dict with coverage statistics per metric and method
        - '_raw_calibration_data': List with one entry per calibration holding
          the per-scope/per-metric bounds, coverages and observed 5%/95%
          quantiles (consumed by get_calibration_bounds_dataframe())

        The 'class_0' / 'class_1' scopes contain 'singleton', 'doublet',
        'abstention' and 'singleton_error'. The 'marginal' scope contains
        'singleton', 'doublet', 'abstention', the class-specific
        '*_rate_class{0,1}' metrics and the 'singleton_error_*' /
        'singleton_correct_*' variants (but no plain 'singleton_error').

        Each metric maps to:
        - 'selected': Coverage stats for selected bounds
        - 'analytical': Coverage stats if available
        - 'exact': Coverage stats if available
        - 'hoeffding': Coverage stats if available

        Each method has:
        - 'coverages': Array of empirical coverages across BigN calibrations
        - 'mean': Mean coverage
        - 'median': Median coverage
        - 'quantiles': {q05, q25, q50, q75, q95}
        - 'fraction_above_target': Fraction achieving ≥ci_level
        - 'fraction_above_95pct': Fraction achieving ≥95% (for comparison)

    Warns
    -----
    RuntimeWarning
        If the parallel run fails and the function falls back to serial
        execution.

    Examples
    --------
    >>> from ssbc import BinaryClassifierSimulator, validate_prediction_interval_calibration
    >>> sim = BinaryClassifierSimulator(p_class1=0.2, seed=42)
    >>> results = validate_prediction_interval_calibration(
    ...     simulator=sim,
    ...     n_calibration=100,
    ...     BigN=50,
    ...     n_trials=500,
    ...     verbose=False
    ... )
    >>> print(f"Fraction achieving ≥95%: {results['marginal']['singleton']['selected']['fraction_above_target']:.1%}")
    """
    import warnings

    from ssbc import generate_rigorous_pac_report

    if seed is not None:
        np.random.seed(seed)

    method_names = ("analytical", "exact", "hoeffding")
    class_metrics = ("singleton", "doublet", "abstention", "singleton_error")
    # Note: plain singleton_error is NOT included for marginal because it mixes
    # two different distributions (class 0 and class 1) which cannot be
    # justified statistically.
    marginal_metrics = (
        "singleton",
        "doublet",
        "abstention",
        # Class-specific rates normalized by total (for scope 'marginal')
        "singleton_rate_class0",
        "singleton_rate_class1",
        "doublet_rate_class0",
        "doublet_rate_class1",
        "abstention_rate_class0",
        "abstention_rate_class1",
        "singleton_error_class0",
        "singleton_error_class1",
        "singleton_correct_class0",
        "singleton_correct_class1",
        "singleton_error_pred_class0",
        "singleton_error_pred_class1",
        "singleton_correct_pred_class0",
        "singleton_correct_pred_class1",
    )
    # Single source of truth for both per-calibration extraction and aggregation.
    metrics_by_scope: dict[str, tuple[str, ...]] = {
        "marginal": marginal_metrics,
        "class_0": class_metrics,
        "class_1": class_metrics,
    }

    def _single_calibration_validation(cal_idx: int) -> dict[str, Any]:
        """Run full validation for one calibration dataset."""
        # Unique seed for this calibration
        cal_seed = (seed + cal_idx * 10000) if seed is not None else None

        if cal_seed is not None:
            np.random.seed(cal_seed)
            cal_simulator = type(simulator)(
                p_class1=simulator.p_class1,
                beta_params_class0=(simulator.a0, simulator.b0),
                beta_params_class1=(simulator.a1, simulator.b1),
                seed=cal_seed,
            )
        else:
            # NOTE(review): with seed=None the caller's simulator is reused
            # as-is; if it carries its own seeded RNG, copies shipped to worker
            # processes may start from the same state - confirm against the
            # simulator implementation.
            cal_simulator = simulator

        cal_labels, cal_probs = cal_simulator.generate(n_calibration)

        # Important for performance: avoid nested parallelism. We parallelize at
        # the calibration level (outer loop) and keep inner routines
        # single-threaded.
        report = generate_rigorous_pac_report(
            labels=cal_labels,
            probs=cal_probs,
            alpha_target=alpha_target,
            delta=delta,
            test_size=test_size,
            ci_level=ci_level,
            use_union_bound=False,
            n_jobs=1,
            verbose=False,  # Always suppress report printing
            prediction_method=prediction_method,
            use_loo_correction=use_loo_correction,
            loo_inflation_factor=(float(loo_inflation_factor) if loo_inflation_factor is not None else 2.0),
        )

        # Ensure independent RNG across calibrations even when seed is None
        base_seed = cal_seed if cal_seed is not None else int(np.random.randint(0, 2**31 - 1))
        validation = validate_pac_bounds(
            report=report,
            simulator=cal_simulator,
            test_size=test_size,
            n_trials=n_trials,
            seed=base_seed + 1,
            verbose=False,  # Suppress validation progress
            n_jobs=1,
        )

        # Extract coverages, bounds and observed quantiles for all methods
        result: dict[str, Any] = {}
        for scope, scope_metrics in metrics_by_scope.items():
            result[scope] = {}
            for metric in scope_metrics:
                m = validation[scope][metric]

                # Drop NaN trials before computing quantiles (especially
                # relevant for singleton_error).
                rates = np.asarray(m["rates"], dtype=float)
                valid_rates = rates[~np.isnan(rates)]
                if len(valid_rates) > 0:
                    observed_q05 = float(np.percentile(valid_rates, 5))
                    observed_q95 = float(np.percentile(valid_rates, 95))
                else:
                    observed_q05 = np.nan
                    observed_q95 = np.nan

                selected_bounds = m["bounds"]
                entry: dict[str, Any] = {
                    "selected": {
                        "coverage": m["empirical_coverage"],
                        "lower": float(selected_bounds[0]),
                        "upper": float(selected_bounds[1]),
                    },
                    "observed_q05": observed_q05,
                    "observed_q95": observed_q95,
                }

                # Method-specific coverages and bounds, if available
                method_validations = m.get("method_validations") or {}
                for method_name in method_names:
                    if method_name in method_validations:
                        method_val = method_validations[method_name]
                        method_bounds = method_val["bounds"]
                        entry[method_name] = {
                            "coverage": method_val["empirical_coverage"],
                            "lower": float(method_bounds[0]),
                            "upper": float(method_bounds[1]),
                        }

                result[scope][metric] = entry

        return result

    if verbose:
        print(f"Running meta-validation: {BigN} calibration datasets, {n_trials} trials each...")
        print("Progress: ", end="", flush=True)

    def _safe_parallel_map_cal(n_jobs_local: int) -> list[dict[str, Any]]:
        """Run all calibrations in parallel, falling back to serial on failure."""
        try:
            # Use process-based parallelism at the outer level to avoid GIL
            # contention and prevent nested joblib pools from oversubscribing
            # the machine.
            return Parallel(n_jobs=n_jobs_local, backend="loky")(
                delayed(_single_calibration_validation)(cal_idx) for cal_idx in range(BigN)
            )
        except Exception as exc:  # noqa: BLE001 - deliberate best-effort fallback
            # Keep the best-effort behavior, but make the fallback visible: a
            # silent retry hides both the slowdown and the underlying cause.
            warnings.warn(
                f"Parallel calibration validation failed ({exc!r}); falling back to serial execution.",
                RuntimeWarning,
                stacklevel=3,
            )
            return [_single_calibration_validation(cal_idx) for cal_idx in range(BigN)]

    all_results = _safe_parallel_map_cal(n_jobs)

    if verbose:
        print("Done!")

    def compute_coverage_stats(coverages: np.ndarray, target_level: float) -> dict[str, Any]:
        """Compute summary statistics for a coverage array, ignoring NaNs."""
        valid = coverages[~np.isnan(coverages)]
        if len(valid) == 0:
            return {
                "coverages": coverages,
                "mean": np.nan,
                "median": np.nan,
                "quantiles": {
                    "q05": np.nan,
                    "q25": np.nan,
                    "q50": np.nan,
                    "q75": np.nan,
                    "q95": np.nan,
                },
                "fraction_above_target": np.nan,
                "fraction_above_95pct": np.nan,
            }

        return {
            "coverages": coverages,
            "mean": float(np.mean(valid)),
            "median": float(np.median(valid)),
            "quantiles": {
                "q05": float(np.percentile(valid, 5)),
                "q25": float(np.percentile(valid, 25)),
                "q50": float(np.median(valid)),
                "q75": float(np.percentile(valid, 75)),
                "q95": float(np.percentile(valid, 95)),
            },
            "fraction_above_target": float(np.mean(valid >= target_level)),
            "fraction_above_95pct": float(np.mean(valid >= 0.95)),
        }

    # Aggregate coverages across all calibrations
    aggregated: dict[str, Any] = {
        "n_calibrations": BigN,
        "n_calibration": n_calibration,
        "n_trials_per_calibration": n_trials,
        "ci_level": ci_level,
    }

    for scope, scope_metrics in metrics_by_scope.items():
        scope_dict: dict[str, Any] = {}
        aggregated[scope] = scope_dict

        for metric in scope_metrics:
            per_calibration = [cal_result[scope][metric] for cal_result in all_results]

            metric_dict: dict[str, Any] = {}
            scope_dict[metric] = metric_dict

            selected_coverages = np.array(
                [entry["selected"]["coverage"] for entry in per_calibration], dtype=float
            )
            metric_dict["selected"] = compute_coverage_stats(selected_coverages, ci_level)

            for method_name in method_names:
                # NaN marks calibrations where this method was not reported.
                method_coverages = np.array(
                    [
                        entry[method_name]["coverage"] if method_name in entry else np.nan
                        for entry in per_calibration
                    ],
                    dtype=float,
                )
                if not np.all(np.isnan(method_coverages)):
                    metric_dict[method_name] = compute_coverage_stats(method_coverages, ci_level)

    # Store raw calibration data for quantile analysis
    aggregated["_raw_calibration_data"] = all_results

    return aggregated
[docs] def get_calibration_bounds_dataframe( results: dict[str, Any], scope: str | None = None, metric: str | None = None, ) -> Any: """Extract calibration bounds and observed quantiles as DataFrame. Converts the raw calibration data from validate_prediction_interval_calibration() into a pandas DataFrame format for easy plotting and analysis. Parameters ---------- results : dict Output from validate_prediction_interval_calibration() scope : str, optional Filter to specific scope: "marginal", "class_0", or "class_1". If None, includes all scopes. metric : str, optional Filter to specific metric: "singleton", "doublet", "abstention", "singleton_error". If None, includes all metrics. Returns ------- DataFrame Pandas DataFrame with columns: - calibration_idx: Index of calibration dataset (0 to BigN-1) - scope: marginal, class_0, or class_1 - metric: singleton, doublet, abstention, singleton_error - observed_q05: 5th percentile of test set rates - observed_q95: 95th percentile of test set rates - selected_lower: Lower bound from selected method - selected_upper: Upper bound from selected method - analytical_lower: Lower bound from analytical method (NaN if not available) - analytical_upper: Upper bound from analytical method (NaN if not available) - exact_lower: Lower bound from exact method (NaN if not available) - exact_upper: Upper bound from exact method (NaN if not available) - hoeffding_lower: Lower bound from hoeffding method (NaN if not available) - hoeffding_upper: Upper bound from hoeffding method (NaN if not available) Examples -------- >>> import pandas as pd >>> from ssbc import get_calibration_bounds_dataframe >>> results = validate_prediction_interval_calibration(...) 
>>> df = get_calibration_bounds_dataframe(results) >>> # Filter to singleton marginal >>> df_single = df[(df['scope'] == 'marginal') & (df['metric'] == 'singleton')] >>> # Plot lower bounds >>> import matplotlib.pyplot as plt >>> plt.scatter(df_single['analytical_lower'], df_single['observed_q05']) """ try: import pandas as pd except ImportError as err: raise ImportError("pandas is required for DataFrame conversion. Install with: pip install pandas") from err if "_raw_calibration_data" not in results: raise ValueError( "Results dict does not contain raw calibration data. " "Make sure to use validate_prediction_interval_calibration() output." ) raw_data = results["_raw_calibration_data"] BigN = len(raw_data) # Build DataFrame rows rows = [] scopes = ["marginal", "class_0", "class_1"] if scope is None else [scope] # Include new marginal metrics by default; per-class scopes will be skipped when missing default_metrics = [ "singleton", "doublet", "abstention", "singleton_error", # Class-specific rates for scope 'marginal' "singleton_rate_class0", "singleton_rate_class1", "doublet_rate_class0", "doublet_rate_class1", "abstention_rate_class0", "abstention_rate_class1", # Class-specific error rates "singleton_error_class0", "singleton_error_class1", # Class-specific correct rates "singleton_correct_class0", "singleton_correct_class1", # Error/correct rates when singleton is assigned to a specific class "singleton_error_pred_class0", "singleton_error_pred_class1", "singleton_correct_pred_class0", "singleton_correct_pred_class1", ] metrics = default_metrics if metric is None else [metric] for cal_idx in range(BigN): for scope_name in scopes: for metric_name in metrics: if scope_name not in raw_data[cal_idx] or metric_name not in raw_data[cal_idx][scope_name]: continue m = raw_data[cal_idx][scope_name][metric_name] row = { "calibration_idx": cal_idx, "scope": scope_name, "metric": metric_name, "observed_q05": m.get("observed_q05", np.nan), "observed_q95": 
m.get("observed_q95", np.nan), } # Add selected bounds if "selected" in m: row["selected_lower"] = m["selected"].get("lower", np.nan) row["selected_upper"] = m["selected"].get("upper", np.nan) else: row["selected_lower"] = np.nan row["selected_upper"] = np.nan # Add method-specific bounds for method_name in ["analytical", "exact", "hoeffding"]: if method_name in m: row[f"{method_name}_lower"] = m[method_name].get("lower", np.nan) row[f"{method_name}_upper"] = m[method_name].get("upper", np.nan) else: row[f"{method_name}_lower"] = np.nan row[f"{method_name}_upper"] = np.nan rows.append(row) df = pd.DataFrame(rows) return df
[docs] def tabulate_calibration_excess( df: Any, scope: str | None = None, metric: str | None = None, methods: list[str] | None = None, ) -> Any: """Tabulate excess values for all methods as a DataFrame. Computes excess values (difference between observed and predicted quantiles) for each method and returns a structured DataFrame with statistics. Parameters ---------- df : DataFrame Output from get_calibration_bounds_dataframe() scope : str, optional Filter to specific scope: "marginal", "class_0", or "class_1". If None, includes all scopes. metric : str, optional Filter to specific metric: "singleton", "doublet", "abstention", "singleton_error". If None, includes all metrics. methods : list[str], optional Methods to include: ["analytical", "exact", "hoeffding", "selected"]. If None, includes all available methods. Returns ------- DataFrame DataFrame with columns: - method: Method name (analytical, exact, hoeffding, selected) - bound_type: "lower" or "upper" - excess_mean: Mean excess value - excess_std: Standard deviation of excess - excess_min: Minimum excess value - excess_max: Maximum excess value - excess_q05: 5th percentile of excess - excess_q25: 25th percentile of excess - excess_q50: 50th percentile (median) of excess - excess_q75: 75th percentile of excess - excess_q95: 95th percentile of excess - n_negative: Number of negative excess values (risky) - n_positive: Number of positive excess values (conservative) - pct_negative: Percentage of negative excess values - n_valid: Number of valid (non-NaN) excess values - scope: Scope name (if filtered) - metric: Metric name (if filtered) Examples -------- >>> from ssbc import get_calibration_bounds_dataframe, tabulate_calibration_excess >>> results = validate_prediction_interval_calibration(...) 
>>> df = get_calibration_bounds_dataframe(results) >>> # Tabulate excess for singleton marginal >>> df_single = df[(df['scope'] == 'marginal') & (df['metric'] == 'singleton')] >>> excess_table = tabulate_calibration_excess(df_single, scope='marginal', metric='singleton') >>> print(excess_table) """ try: import pandas as pd except ImportError as err: raise ImportError("pandas is required for tabulation. Install with: pip install pandas") from err # Filter DataFrame df_filtered = df.copy() # Apply scope filter if scope is not None: df_filtered = df_filtered[df_filtered["scope"] == scope] # Apply metric filter if metric is not None: unique_metrics = ( df_filtered["metric"].unique() if "metric" in df_filtered.columns and len(df_filtered) > 0 else [] ) if len(unique_metrics) == 1 and unique_metrics[0] != metric: actual_metric = unique_metrics[0] df_filtered = df_filtered[df_filtered["metric"] == actual_metric] else: df_filtered = df_filtered[df_filtered["metric"] == metric] if len(df_filtered) == 0: available_metrics = df["metric"].unique().tolist() if "metric" in df.columns else [] available_scopes = df["scope"].unique().tolist() if "scope" in df.columns else [] raise ValueError( f"No data matching the specified scope/metric filters.\n" f" Requested: scope='{scope}', metric='{metric}'\n" f" Available metrics: {available_metrics}\n" f" Available scopes: {available_scopes}" ) # Determine which methods are available available_methods = [] if methods is None: for method_name in ["selected", "analytical", "exact", "hoeffding"]: lower_col = f"{method_name}_lower" if method_name != "selected" else "selected_lower" if lower_col in df_filtered.columns: if df_filtered[lower_col].notna().any(): available_methods.append(method_name) else: for m in methods: if m == "selected" and "selected_lower" in df_filtered.columns: available_methods.append(m) elif f"{m}_lower" in df_filtered.columns: available_methods.append(m) if len(available_methods) == 0: raise ValueError("No method data 
available in DataFrame") # Build table rows rows = [] for method_name in available_methods: # Lower excess: observed_q05 - predicted_lower lower_col = f"{method_name}_lower" if method_name != "selected" else "selected_lower" lower_pred = df_filtered[lower_col] lower_obs = df_filtered["observed_q05"] valid_mask_lower = lower_pred.notna() & lower_obs.notna() if valid_mask_lower.sum() > 0: excess_lower = lower_obs[valid_mask_lower] - lower_pred[valid_mask_lower] excess_array = excess_lower.values rows.append( { "method": method_name, "bound_type": "lower", "excess_mean": float(np.mean(excess_array)), "excess_std": float(np.std(excess_array)), "excess_min": float(np.min(excess_array)), "excess_max": float(np.max(excess_array)), "excess_q05": float(np.percentile(excess_array, 5)), "excess_q25": float(np.percentile(excess_array, 25)), "excess_q50": float(np.percentile(excess_array, 50)), "excess_q75": float(np.percentile(excess_array, 75)), "excess_q95": float(np.percentile(excess_array, 95)), "n_negative": int((excess_array < 0).sum()), "n_positive": int((excess_array >= 0).sum()), "pct_negative": float((excess_array < 0).sum() / len(excess_array) * 100), "n_valid": int(len(excess_array)), "scope": scope if scope else df_filtered["scope"].iloc[0] if len(df_filtered) > 0 else None, "metric": metric if metric else df_filtered["metric"].iloc[0] if len(df_filtered) > 0 else None, } ) # Upper excess: predicted_upper - observed_q95 upper_col = f"{method_name}_upper" if method_name != "selected" else "selected_upper" upper_pred = df_filtered[upper_col] upper_obs = df_filtered["observed_q95"] valid_mask_upper = upper_pred.notna() & upper_obs.notna() if valid_mask_upper.sum() > 0: excess_upper = upper_pred[valid_mask_upper] - upper_obs[valid_mask_upper] excess_array = excess_upper.values rows.append( { "method": method_name, "bound_type": "upper", "excess_mean": float(np.mean(excess_array)), "excess_std": float(np.std(excess_array)), "excess_min": float(np.min(excess_array)), 
"excess_max": float(np.max(excess_array)), "excess_q05": float(np.percentile(excess_array, 5)), "excess_q25": float(np.percentile(excess_array, 25)), "excess_q50": float(np.percentile(excess_array, 50)), "excess_q75": float(np.percentile(excess_array, 75)), "excess_q95": float(np.percentile(excess_array, 95)), "n_negative": int((excess_array < 0).sum()), "n_positive": int((excess_array >= 0).sum()), "pct_negative": float((excess_array < 0).sum() / len(excess_array) * 100), "n_valid": int(len(excess_array)), "scope": scope if scope else df_filtered["scope"].iloc[0] if len(df_filtered) > 0 else None, "metric": metric if metric else df_filtered["metric"].iloc[0] if len(df_filtered) > 0 else None, } ) result_df = pd.DataFrame(rows) return result_df
[docs] def plot_calibration_excess( df: Any, scope: str | None = None, metric: str | None = None, methods: list[str] | None = None, figsize: tuple[int, int] = (14, 6), bins: int = 30, return_fig: bool = False, filename: str | None = None, ) -> Any: """Plot excess (difference between observed and predicted quantiles). Creates histograms showing: - Lower excess: observed_q05 - predicted_lower (positive = predicted too high) - Upper excess: predicted_upper - observed_q95 (positive = predicted too high) Parameters ---------- df : DataFrame Output from get_calibration_bounds_dataframe() scope : str, optional Filter to specific scope: "marginal", "class_0", or "class_1". If None, uses all scopes (creates separate subplots). metric : str, optional Filter to specific metric: "singleton", "doublet", "abstention", "singleton_error". If None, uses all metrics (creates separate subplots). methods : list[str], optional Methods to plot: ["analytical", "exact", "hoeffding"]. If None, plots all available methods. figsize : tuple[int, int], default=(14, 6) Figure size (width, height in inches) bins : int, default=30 Number of histogram bins return_fig : bool, default=False If True, returns matplotlib Figure object. If False, calls plt.show() filename : str, optional If provided, saves the plot to this filename (e.g., "excess_plot.png"). Supports common formats: .png, .pdf, .svg, .jpg, etc. Returns ------- Figure or None If return_fig=True, returns Figure object. Otherwise None. Examples -------- >>> from ssbc import get_calibration_bounds_dataframe, plot_calibration_excess >>> results = validate_prediction_interval_calibration(...) 
>>> df = get_calibration_bounds_dataframe(results) >>> # Plot for singleton marginal >>> df_single = df[(df['scope'] == 'marginal') & (df['metric'] == 'singleton')] >>> plot_calibration_excess(df_single, scope='marginal', metric='singleton') """ try: import matplotlib.pyplot as plt except ImportError as err: raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib") from err # We intentionally do not import pandas here to avoid unused import warnings. # The provided DataFrame `df` is treated duck-typed for plotting. # Filter DataFrame df_filtered = df.copy() # Apply scope filter first if scope is not None: df_filtered = df_filtered[df_filtered["scope"] == scope] # Check if dataframe is already filtered to a single metric (after scope filtering) unique_metrics = df_filtered["metric"].unique() if "metric" in df_filtered.columns and len(df_filtered) > 0 else [] if metric is not None: # If dataframe is already filtered to a single metric, use that metric # (ignore the passed parameter if it doesn't match - this handles pre-filtered dataframes) if len(unique_metrics) == 1 and unique_metrics[0] != metric: # Use the metric that's already in the dataframe actual_metric = unique_metrics[0] df_filtered = df_filtered[df_filtered["metric"] == actual_metric] else: df_filtered = df_filtered[df_filtered["metric"] == metric] elif len(unique_metrics) == 1: # If no metric specified but dataframe has only one metric, use it df_filtered = df_filtered[df_filtered["metric"] == unique_metrics[0]] if len(df_filtered) == 0: # Provide helpful error message with available options available_metrics = df["metric"].unique().tolist() if "metric" in df.columns else [] available_scopes = df["scope"].unique().tolist() if "scope" in df.columns else [] raise ValueError( f"No data matching the specified scope/metric filters.\n" f" Requested: scope='{scope}', metric='{metric}'\n" f" Available metrics: {available_metrics}\n" f" Available scopes: {available_scopes}" ) 
# Determine which methods are available available_methods = [] method_colors = { "selected": "#4C956C", # Green "analytical": "#2E86AB", # Blue "exact": "#A23B72", # Purple "hoeffding": "#F18F01", # Orange } if methods is None: # Check which methods have non-NaN data (include selected) for method_name in ["selected", "analytical", "exact", "hoeffding"]: lower_col = f"{method_name}_lower" if method_name != "selected" else "selected_lower" if lower_col in df_filtered.columns: if df_filtered[lower_col].notna().any(): available_methods.append(method_name) else: # Respect requested methods; support 'selected' tmp = [] for m in methods: if m == "selected" and "selected_lower" in df_filtered.columns: tmp.append(m) elif f"{m}_lower" in df_filtered.columns: tmp.append(m) available_methods = tmp if len(available_methods) == 0: raise ValueError("No method data available in DataFrame") # Create figure with two subplots (lower and upper excess) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) # Plot lower excess: observed_q05 - predicted_lower # Positive excess = predicted lower bound is too high (conservative) # Negative excess = predicted lower bound is too low (underestimates) for method_name in available_methods: lower_col = f"{method_name}_lower" if method_name != "selected" else "selected_lower" lower_pred = df_filtered[lower_col] lower_obs = df_filtered["observed_q05"] # Compute excess only for non-NaN pairs valid_mask = lower_pred.notna() & lower_obs.notna() if valid_mask.sum() > 0: excess_lower = lower_obs[valid_mask] - lower_pred[valid_mask] color = method_colors.get(method_name, "gray") # Compute percentage of excess values below zero (risky/underestimated) # This tells us: what percentage of excess values are negative (risky) vs positive (conservative) excess_array = excess_lower.values if len(excess_array) > 0: # Compute percentage of negative values (mass below zero) n_negative = (excess_array < 0).sum() n_total = len(excess_array) excess_percent = 
(n_negative / n_total) * 100 # Format as percentage: "excess 3.2%" means 3.2% are negative (risky) if excess_percent < 0.1: excess_str = "excess 0.0%" elif excess_percent < 1: excess_str = f"excess {excess_percent:.2f}%" else: excess_str = f"excess {excess_percent:.1f}%" label_str = f"{method_name.capitalize()} ({excess_str})" else: label_str = f"{method_name.capitalize()} (n={len(excess_lower)})" ax1.hist( excess_lower, bins=bins, alpha=0.6, label=label_str, color=color, edgecolor="black", ) ax1.axvline(0, color="k", linestyle="--", linewidth=2, label="Perfect (0 excess)", zorder=10) ax1.set_xlabel("Lower Excess: Observed Q05 - Predicted Lower", fontsize=11) ax1.set_ylabel("Frequency", fontsize=11) ax1.set_title("Lower Bound Calibration", fontsize=12, fontweight="bold") ax1.legend(loc="upper right", fontsize=9) ax1.grid(alpha=0.3) # Add interpretation text ax1.text( 0.02, 0.98, "Positive = Predicted too high (conservative)\nNegative = Predicted too low (risky)", transform=ax1.transAxes, fontsize=8, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), ) # Plot upper excess: predicted_upper - observed_q95 # Positive excess = predicted upper bound is higher than needed (conservative) # Negative excess = predicted upper bound is too low (underestimates) for method_name in available_methods: upper_col = f"{method_name}_upper" if method_name != "selected" else "selected_upper" upper_pred = df_filtered[upper_col] upper_obs = df_filtered["observed_q95"] # Compute excess only for non-NaN pairs valid_mask = upper_pred.notna() & upper_obs.notna() if valid_mask.sum() > 0: excess_upper = upper_pred[valid_mask] - upper_obs[valid_mask] color = method_colors.get(method_name, "gray") # Compute percentage of excess values below zero (risky/underestimated) excess_array = excess_upper.values if len(excess_array) > 0: # Compute percentage of negative values (mass below zero) n_negative = (excess_array < 0).sum() n_total = len(excess_array) excess_percent 
= (n_negative / n_total) * 100 # Format as percentage: "excess 3.2%" means 3.2% are negative (risky) if excess_percent < 0.1: excess_str = "excess 0.0%" elif excess_percent < 1: excess_str = f"excess {excess_percent:.2f}%" else: excess_str = f"excess {excess_percent:.1f}%" label_str = f"{method_name.capitalize()} ({excess_str})" else: label_str = f"{method_name.capitalize()} (n={len(excess_upper)})" ax2.hist( excess_upper, bins=bins, alpha=0.6, label=label_str, color=color, edgecolor="black", ) ax2.axvline(0, color="k", linestyle="--", linewidth=2, label="Perfect (0 excess)", zorder=10) ax2.set_xlabel("Upper Excess: Predicted Upper - Observed Q95", fontsize=11) ax2.set_ylabel("Frequency", fontsize=11) ax2.set_title("Upper Bound Calibration", fontsize=12, fontweight="bold") ax2.legend(loc="upper right", fontsize=9) ax2.grid(alpha=0.3) # Add interpretation text ax2.text( 0.02, 0.98, "Positive = Predicted too high (conservative)\nNegative = Predicted too low (risky)", transform=ax2.transAxes, fontsize=8, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), ) # Add overall title with scope/metric info scope_str = scope if scope else "All Scopes" metric_str = metric if metric else "All Metrics" fig.suptitle(f"Calibration Excess Analysis: {scope_str} - {metric_str}", fontsize=14, fontweight="bold", y=1.02) plt.tight_layout() # Save to file if filename provided if filename is not None: fig.savefig(filename, dpi=300, bbox_inches="tight") print(f"Plot saved to: {filename}") if return_fig: return fig elif filename is not None: # If saved to file, don't show (close figure) plt.close(fig) return None else: plt.show() return None