Source code for ssbc.calibration.bootstrap

"""Bootstrap analysis of calibration uncertainty for operational rates.

This models: "If I recalibrate many times on similar datasets, how do rates vary?"
This differs from LOO-CV, which models: "Given ONE fixed calibration, how do test sets vary?"
"""

from typing import Protocol

import numpy as np
from joblib import Parallel, delayed

from ssbc.core_pkg import ssbc_correct

from .conformal import split_by_class

# Optional plotting dependencies
try:
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


class DataGenerator(Protocol):
    """Structural interface for sample generators.

    Any object exposing a compatible ``generate`` method satisfies this
    protocol (e.g., BinaryClassifierSimulator).
    """

    def generate(self, n_samples: int) -> tuple[np.ndarray, np.ndarray]:
        """Produce a batch of samples.

        Parameters
        ----------
        n_samples : int
            Number of samples requested.

        Returns
        -------
        tuple
            (labels, probabilities)
        """
        ...
def _nan_trial_result() -> dict[str, float]:
    """Return the all-NaN record used when a trial cannot be calibrated or evaluated."""
    metrics = ("singleton", "doublet", "abstention", "singleton_error")
    # Marginal keys first, then the "_0" and "_1" per-class keys.
    return {f"{metric}{suffix}": np.nan for suffix in ("", "_0", "_1") for metric in metrics}


def _rate(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator > 0 else np.nan


def _bootstrap_single_trial(
    labels: np.ndarray,
    probs: np.ndarray,
    alpha_target: float,
    delta: float,
    test_size: int,
    bootstrap_seed: int,
    simulator: DataGenerator,
) -> dict[str, float]:
    """Single bootstrap trial: resample calibration → calibrate → evaluate on fresh test set.

    Parameters
    ----------
    labels : np.ndarray
        Calibration labels. Samples equal to 0 / 1 form the two calibration classes.
    probs : np.ndarray
        Calibration probabilities, indexed as ``probs[:, 0]`` and ``probs[:, 1]``.
    alpha_target : float
        Target miscoverage
    delta : float
        PAC risk
    test_size : int
        Test set size requested from ``simulator``
    bootstrap_seed : int
        Random seed for this trial
    simulator : DataGenerator
        Simulator to generate fresh test sets

    Returns
    -------
    dict
        Operational rates for this bootstrap sample with keys ``singleton``,
        ``doublet``, ``abstention``, ``singleton_error`` (marginal) plus the same
        names suffixed ``_0`` and ``_1`` (conditioned on the true test label).
        Every value is NaN when calibration fails, when a class is missing from
        the resample, or when the simulator returns an empty test set. Individual
        rates are NaN when their denominator is zero.
    """
    # Seeds the *global* NumPy RNG so the resample below is reproducible.
    # NOTE(review): the simulator presumably draws from the same global state,
    # which would make its test set reproducible too - confirm against the simulator.
    np.random.seed(bootstrap_seed)

    n = len(labels)

    # Bootstrap resample calibration data (with replacement)
    bootstrap_idx = np.random.choice(n, size=n, replace=True)
    labels_boot = labels[bootstrap_idx]
    probs_boot = probs[bootstrap_idx]

    # Split by class
    class_data_boot = split_by_class(labels_boot, probs_boot)

    # Calibrate on bootstrap sample. The broad catch is a deliberate best-effort:
    # a degenerate resample (e.g., all samples from one class) yields a NaN record
    # instead of aborting the whole bootstrap run.
    try:
        n_0 = class_data_boot[0]["n"]
        n_1 = class_data_boot[1]["n"]
        ssbc_0 = ssbc_correct(alpha_target=alpha_target, n=n_0, delta=delta)
        ssbc_1 = ssbc_correct(alpha_target=alpha_target, n=n_1, delta=delta)
    except Exception:
        return _nan_trial_result()

    # Nonconformity scores of the resampled calibration points, per true class
    scores_0 = 1.0 - probs_boot[labels_boot == 0, 0]
    scores_1 = 1.0 - probs_boot[labels_boot == 1, 1]
    if scores_0.size == 0 or scores_1.size == 0:
        # No scores to take a quantile from; indexing below would raise IndexError.
        return _nan_trial_result()

    # Compute thresholds: k-th smallest score, capped at the largest available score
    k_0 = int(np.ceil((n_0 + 1) * (1 - ssbc_0.alpha_corrected)))
    k_1 = int(np.ceil((n_1 + 1) * (1 - ssbc_1.alpha_corrected)))
    sorted_0 = np.sort(scores_0)
    sorted_1 = np.sort(scores_1)
    threshold_0 = sorted_0[min(k_0 - 1, len(sorted_0) - 1)]
    threshold_1 = sorted_1[min(k_1 - 1, len(sorted_1) - 1)]

    # Generate FRESH test set
    labels_test, probs_test = simulator.generate(test_size)
    labels_test = np.asarray(labels_test)
    n_test = len(labels_test)
    if n_test == 0:
        # Nothing to evaluate; every rate is undefined.
        return _nan_trial_result()
    probs_test = np.asarray(probs_test)[:n_test]

    # Prediction-set membership of each label for every test point
    in_0 = (1.0 - probs_test[:, 0]) <= threshold_0
    in_1 = (1.0 - probs_test[:, 1]) <= threshold_1

    doublet = in_0 & in_1
    singleton = in_0 ^ in_1
    abstention = ~(in_0 | in_1)

    # Any label other than 0 is counted as class 1 in the per-class breakdown.
    is_0 = labels_test == 0
    is_1 = ~is_0

    # A singleton is correct when the only label it contains is the true one.
    correct = singleton & ((in_0 & is_0) | (in_1 & (labels_test == 1)))
    correct_0 = singleton & is_0 & in_0
    correct_1 = singleton & is_1 & in_1

    def count(mask: np.ndarray) -> int:
        return int(np.count_nonzero(mask))

    n_singletons = count(singleton)
    n_class_0 = count(is_0)
    n_class_1 = n_test - n_class_0
    n_singletons_0 = count(singleton & is_0)
    n_singletons_1 = count(singleton & is_1)

    return {
        "singleton": n_singletons / n_test,
        "doublet": count(doublet) / n_test,
        "abstention": count(abstention) / n_test,
        "singleton_error": _rate(n_singletons - count(correct), n_singletons),
        "singleton_0": _rate(n_singletons_0, n_class_0),
        "doublet_0": _rate(count(doublet & is_0), n_class_0),
        "abstention_0": _rate(count(abstention & is_0), n_class_0),
        "singleton_error_0": _rate(n_singletons_0 - count(correct_0), n_singletons_0),
        "singleton_1": _rate(n_singletons_1, n_class_1),
        "doublet_1": _rate(count(doublet & is_1), n_class_1),
        "abstention_1": _rate(count(abstention & is_1), n_class_1),
        "singleton_error_1": _rate(n_singletons_1 - count(correct_1), n_singletons_1),
    }
def bootstrap_calibration_uncertainty(
    labels: np.ndarray,
    probs: np.ndarray,
    simulator: DataGenerator,
    alpha_target: float = 0.10,
    delta: float = 0.10,
    test_size: int = 1000,
    n_bootstrap: int = 1000,
    n_jobs: int = -1,
    seed: int | None = None,
) -> dict:
    """Estimate how operational rates vary under repeated recalibration.

    Every bootstrap trial:

    1. Resamples the calibration data with replacement
    2. Calibrates (computes SSBC thresholds)
    3. Evaluates on a fresh test set drawn from ``simulator``
    4. Records the operational rates

    This models: "If I recalibrate on similar datasets, how do rates vary?"

    Parameters
    ----------
    labels : np.ndarray
        Calibration labels
    probs : np.ndarray
        Calibration probabilities
    simulator : DataGenerator
        Simulator to generate independent test sets
    alpha_target : float, default=0.10
        Target miscoverage
    delta : float, default=0.10
        PAC risk
    test_size : int, default=1000
        Size of test sets for evaluation
    n_bootstrap : int, default=1000
        Number of bootstrap iterations
    n_jobs : int, default=-1
        Parallel jobs (-1 for all cores)
    seed : int, optional
        Random seed. When given, the global NumPy RNG is seeded with it before
        the per-trial seeds are drawn.

    Returns
    -------
    dict
        Bootstrap distributions with keys:

        - 'n_bootstrap', 'n_calibration', 'test_size': run metadata
        - 'marginal': dict with 'singleton', 'doublet', 'abstention', 'singleton_error'
        - 'class_0': dict with same metrics
        - 'class_1': dict with same metrics

        Each metric contains:

        - 'samples': array of rates across bootstrap trials (NaNs included)
        - 'mean': mean rate over non-NaN samples
        - 'std': standard deviation over non-NaN samples
        - 'quantiles': dict with q05, q25, q50, q75, q95 over non-NaN samples

    Examples
    --------
    >>> from ssbc import BinaryClassifierSimulator, bootstrap_calibration_uncertainty
    >>> sim = BinaryClassifierSimulator(p_class1=0.2, beta_params_class0=(1,7), beta_params_class1=(5,2))
    >>> labels, probs = sim.generate(100)
    >>> results = bootstrap_calibration_uncertainty(labels, probs, sim, n_bootstrap=100)
    >>> print(results['marginal']['singleton']['mean'])
    """
    if seed is not None:
        np.random.seed(seed)

    # One independent seed per trial
    trial_seeds = np.random.randint(0, 2**31, size=n_bootstrap)

    # Arguments shared by every trial, in the positional order the worker expects
    fixed_args = (labels, probs, alpha_target, delta, test_size)

    # Run trials in parallel; if that fails for any reason, redo them serially
    try:
        trials = Parallel(n_jobs=n_jobs)(
            delayed(_bootstrap_single_trial)(*fixed_args, trial_seed, simulator) for trial_seed in trial_seeds
        )
    except Exception:
        trials = [_bootstrap_single_trial(*fixed_args, trial_seed, simulator) for trial_seed in trial_seeds]

    metric_names = ["singleton", "doublet", "abstention", "singleton_error"]
    quantile_levels = {"q05": 5, "q25": 25, "q50": 50, "q75": 75, "q95": 95}

    def summarize(values):
        """Summarize one metric across trials, ignoring NaN entries."""
        samples = np.array(values)
        finite = samples[~np.isnan(samples)]

        if len(finite) == 0:
            return {
                "samples": samples,
                "mean": np.nan,
                "std": np.nan,
                "quantiles": dict.fromkeys(quantile_levels, np.nan),
            }

        return {
            "samples": samples,
            "mean": np.mean(finite),
            "std": np.std(finite),
            "quantiles": {key: np.percentile(finite, level) for key, level in quantile_levels.items()},
        }

    # Assemble the report: run metadata first, then one section per group
    report = {
        "n_bootstrap": n_bootstrap,
        "n_calibration": len(labels),
        "test_size": test_size,
    }
    for group, suffix in (("marginal", ""), ("class_0", "_0"), ("class_1", "_1")):
        report[group] = {
            name: summarize([trial[f"{name}{suffix}"] for trial in trials]) for name in metric_names
        }
    return report
def plot_bootstrap_distributions(
    bootstrap_results: dict,
    figsize: tuple[int, int] = (16, 12),
    save_path: str | None = None,
) -> None:
    """Plot bootstrap distributions.

    Draws a 3x4 grid of histograms: one row each for the marginal, class 0 and
    class 1 results, one column per metric. Each panel marks the median, the
    5%/95% quantiles and the mean.

    Parameters
    ----------
    bootstrap_results : dict
        Results from bootstrap_calibration_uncertainty()
    figsize : tuple, default=(16, 12)
        Figure size
    save_path : str, optional
        Path to save figure. If None, nothing is written to disk.

    Raises
    ------
    ImportError
        If matplotlib is not installed

    Notes
    -----
    This function never calls ``plt.show()`` and does not close the figure, so
    callers who want an interactive window can call ``plt.show()`` afterwards.

    Examples
    --------
    >>> from ssbc import bootstrap_calibration_uncertainty, plot_bootstrap_distributions
    >>> results = bootstrap_calibration_uncertainty(...)
    >>> plot_bootstrap_distributions(results, save_path='bootstrap_results.png')
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")

    fig, axes = plt.subplots(3, 4, figsize=figsize)
    fig.suptitle(
        f"Bootstrap Calibration Uncertainty ({bootstrap_results['n_bootstrap']} trials)\n"
        f"Calibration n={bootstrap_results['n_calibration']}, Test size={bootstrap_results['test_size']}",
        fontsize=14,
        fontweight="bold",
    )

    metrics = ["singleton", "doublet", "abstention", "singleton_error"]
    metric_names = ["Singleton Rate", "Doublet Rate", "Abstention Rate", "Singleton Error Rate"]

    colors = ["steelblue", "coral", "mediumpurple"]
    row_names = ["MARGINAL", "CLASS 0", "CLASS 1"]
    data_keys = ["marginal", "class_0", "class_1"]

    for row, (row_name, data_key, color) in enumerate(zip(row_names, data_keys, colors, strict=False)):
        for col, (metric, name) in enumerate(zip(metrics, metric_names, strict=False)):
            ax = axes[row, col]
            m = bootstrap_results[data_key][metric]

            # Title first, so panels without data can still be identified
            ax.set_title(f"{row_name}: {name}", fontweight="bold")

            # Filter NaNs
            samples = np.asarray(m["samples"], dtype=float)
            samples = samples[~np.isnan(samples)]

            if len(samples) == 0:
                # Axes coordinates keep the note centred whatever the data limits are
                ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
                continue

            # Histogram
            ax.hist(samples, bins=50, alpha=0.7, color=color, edgecolor="black")

            # Quantiles
            q = m["quantiles"]
            ax.axvline(q["q50"], color="green", linestyle="-", linewidth=2, label=f"Median: {q['q50']:.3f}")
            ax.axvline(q["q05"], color="red", linestyle="--", linewidth=2, label=f"5%: {q['q05']:.3f}")
            ax.axvline(q["q95"], color="red", linestyle="--", linewidth=2, label=f"95%: {q['q95']:.3f}")
            ax.axvline(m["mean"], color="orange", linestyle=":", linewidth=2, label=f"Mean: {m['mean']:.3f}")

            ax.set_xlabel("Rate")
            ax.set_ylabel("Count")
            ax.legend(loc="best", fontsize=8)
            ax.grid(True, alpha=0.3)

    # Operate on the figure created above rather than pyplot's "current figure"
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"✅ Saved bootstrap visualization to: {save_path}")

    # In non-interactive/test environments, avoid plt.show() to prevent warnings.
    # The figure is left open, so callers can call plt.show() themselves if needed.