Source code for ssbc.utils

"""Utility functions for conformal prediction."""

from typing import Any, Literal

import numpy as np


[docs] def build_mondrian_prediction_sets( probs: np.ndarray, threshold_0: float, threshold_1: float, return_lists: bool = False, ) -> list[set[int] | list[int]]: """Build prediction sets using Mondrian conformal prediction thresholds. This function implements the standard Mondrian conformal prediction approach: - For each sample, include class 0 if score_0 <= threshold_0 - For each sample, include class 1 if score_1 <= threshold_1 - Where score_k = 1 - P(class=k) Parameters ---------- probs : np.ndarray, shape (n, 2) Probability predictions for each sample. probs[i, 0] = P(class=0), probs[i, 1] = P(class=1) threshold_0 : float Conformal prediction threshold for class 0 threshold_1 : float Conformal prediction threshold for class 1 return_lists : bool, default=False If True, returns lists instead of sets Returns ------- list[set[int]] or list[list[int]] List of prediction sets, where each set/list contains the classes included in the prediction set for that sample. Examples -------- >>> import numpy as np >>> from ssbc.utils import build_mondrian_prediction_sets >>> >>> probs = np.array([ ... [0.8, 0.2], # High confidence class 0 ... [0.5, 0.5], # Uncertain ... [0.2, 0.8], # High confidence class 1 ... ]) >>> threshold_0, threshold_1 = 0.3, 0.3 >>> pred_sets = build_mondrian_prediction_sets(probs, threshold_0, threshold_1) >>> print(pred_sets) # [{0}, {0, 1}, {1}] Notes ----- This function is used throughout the codebase for building Mondrian conformal prediction sets. It centralizes the logic to ensure consistency across all modules that perform conformal prediction evaluation. 
""" n = len(probs) if probs.shape != (n, 2): raise ValueError(f"probs must have shape (n, 2), got {probs.shape}") # Vectorize score computation scores_0 = 1.0 - probs[:, 0] scores_1 = 1.0 - probs[:, 1] prediction_sets = [] for score_0, score_1 in zip(scores_0, scores_1, strict=False): if return_lists: pred_set = [] if score_0 <= threshold_0: pred_set.append(0) if score_1 <= threshold_1: pred_set.append(1) else: pred_set = set() if score_0 <= threshold_0: pred_set.add(0) if score_1 <= threshold_1: pred_set.add(1) prediction_sets.append(pred_set) return prediction_sets
[docs] def build_conditional_prediction_sets( probs: np.ndarray, threshold: float, return_lists: bool = False, ) -> list[set[int] | list[int]]: """Build prediction sets using a SINGLE threshold for conditional analysis. Unlike Mondrian CP which uses separate thresholds per class, this uses ONE threshold for BOTH classes - as in standard (non-Mondrian) conformal prediction. This is used for conditional analysis where we want to evaluate predictions conditioned on the true class label, using the threshold calibrated for that class. Parameters ---------- probs : np.ndarray, shape (n, 2) Probability predictions [P(class=0), P(class=1)] Note: The data should be filtered by true class label before calling this function. For conditional analysis, only samples with the same true label should be included. threshold : float Single conformal prediction threshold for both classes This should be the threshold calibrated for the class of the samples in probs. return_lists : bool, default=False If True, returns lists instead of sets Returns ------- list[set[int]] or list[list[int]] Prediction sets where: - {0, 1} if both P(0) >= 1-threshold AND P(1) >= 1-threshold (doublet) - {0} if P(0) >= 1-threshold AND P(1) < 1-threshold (singleton) - {1} if P(1) >= 1-threshold AND P(0) < 1-threshold (singleton) - {} if both P(0) < 1-threshold AND P(1) < 1-threshold (abstention) Examples -------- >>> import numpy as np >>> from ssbc.utils import build_conditional_prediction_sets >>> >>> probs = np.array([ ... [0.8, 0.2], # High confidence class 0: score_0=0.2, score_1=0.8 ... [0.75, 0.75], # Uncertain, both above threshold: score_0=0.25, score_1=0.25 ... [0.2, 0.8], # High confidence class 1: score_0=0.8, score_1=0.2 ... 
]) >>> threshold = 0.3 >>> pred_sets = build_conditional_prediction_sets(probs, threshold) >>> print(pred_sets) # [{0}, {0, 1}, {1}] Notes ----- This function is used for conditional analysis in Mondrian conformal prediction, where we evaluate prediction sets conditioned on the true class label. For each class, we use the threshold calibrated for that class and apply it to BOTH classes in the prediction set, providing conditional coverage guarantees. The data is filtered by true class label BEFORE calling this function (e.g., via split_by_class). This ensures that when evaluating conditional coverage P(Y ∈ C(X) | Y = y), we only analyze samples where the true label Y equals the class y for which the threshold was calibrated. """ n = len(probs) if probs.shape != (n, 2): raise ValueError(f"probs must have shape (n, 2), got {probs.shape}") # Compute nonconformity scores scores_0 = 1.0 - probs[:, 0] scores_1 = 1.0 - probs[:, 1] prediction_sets = [] for score_0, score_1 in zip(scores_0, scores_1, strict=False): if return_lists: pred_set = [] if score_0 <= threshold: # Same threshold for both classes! pred_set.append(0) if score_1 <= threshold: # Same threshold for both classes! pred_set.append(1) else: pred_set = set() if score_0 <= threshold: pred_set.add(0) if score_1 <= threshold: pred_set.add(1) prediction_sets.append(pred_set) return prediction_sets
[docs] def compute_operational_rate( prediction_sets: list[set[int] | list[int]], true_labels: np.ndarray, rate_type: Literal["singleton", "doublet", "abstention", "error_in_singleton", "correct_in_singleton"], ) -> np.ndarray: """Compute operational rate indicators for prediction sets. For each prediction set, compute a binary indicator showing whether a specific operational event occurred (singleton, doublet, abstention, error in singleton, or correct in singleton). Parameters ---------- prediction_sets : list[set | list] Prediction sets for each sample. Each set contains predicted labels. true_labels : np.ndarray True labels for each sample rate_type : {"singleton", "doublet", "abstention", "error_in_singleton", "correct_in_singleton"} Type of operational rate to compute: - "singleton": prediction set contains exactly one label - "doublet": prediction set contains exactly two labels - "abstention": prediction set is empty - "error_in_singleton": singleton prediction that doesn't contain true label - "correct_in_singleton": singleton prediction that contains true label Returns ------- np.ndarray Binary indicators (0 or 1) for whether the event holds for each sample Examples -------- >>> pred_sets = [{0}, {0, 1}, set(), {1}] >>> true_labels = np.array([0, 0, 1, 0]) >>> indicators = compute_operational_rate(pred_sets, true_labels, "singleton") >>> print(indicators) # [1, 0, 0, 1] >>> indicators = compute_operational_rate(pred_sets, true_labels, "correct_in_singleton") >>> print(indicators) # [1, 0, 0, 0] - first and last are singletons, first is correct Notes ----- This function is useful for computing operational statistics on conformal prediction sets, such as singleton rates, escalation rates, and error rates. 
""" n = len(prediction_sets) indicators = np.zeros(n, dtype=int) for i in range(n): pred_set = prediction_sets[i] y_true = true_labels[i] if rate_type == "singleton": indicators[i] = int(len(pred_set) == 1) elif rate_type == "doublet": indicators[i] = int(len(pred_set) == 2) elif rate_type == "abstention": indicators[i] = int(len(pred_set) == 0) elif rate_type == "error_in_singleton": indicators[i] = int(len(pred_set) == 1 and y_true not in pred_set) elif rate_type == "correct_in_singleton": indicators[i] = int(len(pred_set) == 1 and y_true in pred_set) else: raise ValueError(f"Unknown rate_type: {rate_type}") return indicators
def evaluate_test_dataset(
    test_labels: np.ndarray,
    test_probs: np.ndarray,
    threshold_0: float,
    threshold_1: float,
) -> dict[str, Any]:
    """Evaluate a test dataset and compute empirical operational rates.

    Builds Mondrian prediction sets from the probability predictions and the
    two per-class thresholds, then reports how often each operational event
    (singleton, doublet, abstention, singleton error) occurs, both marginally
    and restricted to samples of each true class.

    Parameters
    ----------
    test_labels : np.ndarray
        True labels for test samples (0 or 1).
    test_probs : np.ndarray
        Probability predictions for test samples, shape (n_samples, 2).
        ``test_probs[i, 0] = P(class=0)``, ``test_probs[i, 1] = P(class=1)``.
    threshold_0 : float
        Conformal prediction threshold for class 0.
    threshold_1 : float
        Conformal prediction threshold for class 1.

    Returns
    -------
    dict
        Dictionary with the following entries:

        - 'marginal': rates across all samples
        - 'class_0': rates for samples whose true label is 0
        - 'class_1': rates for samples whose true label is 1
        - 'thresholds': ``{"threshold_0": ..., "threshold_1": ...}`` as passed in
        - 'n_test': number of test samples

        Each of 'marginal', 'class_0' and 'class_1' contains:

        - 'singleton_rate': fraction of samples with singleton predictions
        - 'doublet_rate': fraction of samples with doublet predictions
        - 'abstention_rate': fraction of samples with abstention (empty set)
        - 'singleton_error_rate': fraction of singleton predictions that are
          incorrect (NaN when the group has no singletons)
        - 'n_samples': number of samples in this group
        - 'n_singletons': number of singleton predictions
        - 'n_doublets': number of doublet predictions
        - 'n_abstentions': number of abstentions

        For a group with no samples, all rates are NaN and all counts are 0.

    Raises
    ------
    ValueError
        If ``test_labels`` is empty, or if ``test_probs`` does not have shape
        ``(len(test_labels), 2)``.

    Examples
    --------
    >>> import numpy as np
    >>> from ssbc import evaluate_test_dataset
    >>>
    >>> test_labels = np.array([0, 0, 1, 1, 0])
    >>> test_probs = np.array([
    ...     [0.8, 0.2],  # High confidence class 0
    ...     [0.6, 0.4],  # Medium confidence class 0
    ...     [0.3, 0.7],  # High confidence class 1
    ...     [0.4, 0.6],  # Medium confidence class 1
    ...     [0.5, 0.5],  # Uncertain
    ... ])
    >>> results = evaluate_test_dataset(test_labels, test_probs, 0.3, 0.3)
    >>> print(f"Marginal singleton rate: {results['marginal']['singleton_rate']:.3f}")
    >>> print(f"Class 0 singleton rate: {results['class_0']['singleton_rate']:.3f}")

    Notes
    -----
    Prediction sets follow the Mondrian rule, with ``score_k = 1 - P(class=k)``:

    - include class 0 if ``score_0 <= threshold_0``
    - include class 1 if ``score_1 <= threshold_1``
    """
    n_test = len(test_labels)
    if n_test == 0:
        raise ValueError("Test dataset cannot be empty")
    if test_probs.shape != (n_test, 2):
        raise ValueError(f"test_probs must have shape ({n_test}, 2), got {test_probs.shape}")

    sets_per_sample = build_mondrian_prediction_sets(test_probs, threshold_0, threshold_1)

    # One 0/1 indicator array per operational event, aligned with the samples.
    is_singleton, is_doublet, is_abstention, is_singleton_error = (
        compute_operational_rate(sets_per_sample, test_labels, event)
        for event in ("singleton", "doublet", "abstention", "error_in_singleton")
    )

    def summarize(selector: np.ndarray | None = None) -> dict[str, Any]:
        """Aggregate the indicator arrays over all samples or a boolean-masked subset."""
        if selector is None:
            singles, doubles, empties, errors = (
                is_singleton,
                is_doublet,
                is_abstention,
                is_singleton_error,
            )
            group_size = len(is_singleton)
        else:
            singles = is_singleton[selector]
            doubles = is_doublet[selector]
            empties = is_abstention[selector]
            errors = is_singleton_error[selector]
            group_size = np.sum(selector)

        # An empty group has no meaningful rates; report NaN rather than divide by zero.
        if group_size == 0:
            return {
                "singleton_rate": np.nan,
                "doublet_rate": np.nan,
                "abstention_rate": np.nan,
                "singleton_error_rate": np.nan,
                "n_samples": 0,
                "n_singletons": 0,
                "n_doublets": 0,
                "n_abstentions": 0,
            }

        n_singletons = int(np.sum(singles))

        # The error rate is conditional on the prediction being a singleton.
        if n_singletons > 0:
            singleton_error_rate = float(np.mean(errors[singles == 1]))
        else:
            singleton_error_rate = np.nan

        return {
            "singleton_rate": float(np.mean(singles)),
            "doublet_rate": float(np.mean(doubles)),
            "abstention_rate": float(np.mean(empties)),
            "singleton_error_rate": singleton_error_rate,
            "n_samples": int(group_size),
            "n_singletons": n_singletons,
            "n_doublets": int(np.sum(doubles)),
            "n_abstentions": int(np.sum(empties)),
        }

    return {
        "marginal": summarize(),
        "class_0": summarize(test_labels == 0),
        "class_1": summarize(test_labels == 1),
        "thresholds": {"threshold_0": threshold_0, "threshold_1": threshold_1},
        "n_test": n_test,
    }