Source code for ssbc.core_pkg.core

"""Core SSBC (Small-Sample Beta Correction) algorithm."""

import math
from dataclasses import dataclass
from typing import Any, Literal

import numpy as np
from scipy.stats import beta as beta_dist
from scipy.stats import betabinom, norm

from ssbc._logging import get_logger

# Names exported by ``from ssbc.core_pkg.core import *``.
__all__ = [
    "SSBCResult",
    "ssbc_correct",
]

# Module-level logger obtained from the package's logging helper.
logger = get_logger(__name__)


@dataclass
class SSBCResult:
    """Outcome of a Small-Sample Beta Correction (SSBC) run.

    Attributes
    ----------
    alpha_target : float
        Miscoverage rate requested by the caller.
    alpha_corrected : float
        Miscoverage rate to calibrate at, equal to ``u_star / (n + 1)``.
    u_star : int
        Optimal ``u`` selected by the search.
    n : int
        Size of the calibration set.
    satisfied_mass : float
        Probability that coverage is at least the target coverage.
    mode : {"beta", "beta-binomial"}
        ``"beta"`` models an infinite test window, ``"beta-binomial"`` a
        finite one.
    details : dict[str, Any]
        Additional diagnostic information about the search.
    """

    alpha_target: float
    alpha_corrected: float
    u_star: int
    n: int
    satisfied_mass: float
    mode: Literal["beta", "beta-binomial"]
    details: dict[str, Any]
def ssbc_correct(
    alpha_target: float,
    n: int,
    delta: float,
    *,
    mode: Literal["beta", "beta-binomial"] = "beta",
    m: int | None = None,
    bracket_width: int | None = None,
) -> SSBCResult:
    """Small-Sample Beta Correction (SSBC), corrected acceptance rule.

    Find the largest α' = u/(n+1) ≤ α_target such that:

        P(Coverage(α') ≥ 1 - α_target) ≥ 1 - δ

    where Coverage(α') ~ Beta(n+1-u, u) for an infinite test window. For a
    finite window of m test points, the number of covered points is modelled
    as BetaBinomial(m, n+1-u, u).

    Trivial regime: if α_target < 1/(n+1), return α_corrected=0.

    Parameters
    ----------
    alpha_target : float
        Target miscoverage rate (must be in (0,1)).
    n : int
        Calibration set size. Must be a Python or NumPy integer and at least
        10; smaller calibration sets are rejected.
    delta : float
        PAC risk tolerance (must be in (0,1)). This is the probability that
        the coverage guarantee fails. For example, delta=0.10 means we want a
        90% PAC confidence (1-delta) that coverage ≥ target.
    mode : {"beta", "beta-binomial"}, default="beta"
        "beta" for an infinite test window,
        "beta-binomial" for a finite test window (defaults to m=n).
    m : int, optional
        Test window size for beta-binomial mode (defaults to n). Ignored in
        "beta" mode.
    bracket_width : int, optional
        Search radius around the initial guess (default: adaptive based on
        n). It only controls how many candidates are evaluated up front; the
        search expands beyond the bracket whenever the optimum lies outside it.

    Returns
    -------
    SSBCResult
        Correction result. ``details`` holds the diagnostic keys ``u_max``,
        ``u_star_guess``, ``search_range`` (the initial bracket),
        ``bracket_width``, ``delta``, ``m``, ``acceptance_rule`` and
        ``search_log`` (one entry per candidate evaluated inside the initial
        bracket), plus ``note`` in the trivial regime.

    Raises
    ------
    TypeError
        If ``alpha_target`` or ``delta`` is not an int or float instance.
    ValueError
        If parameters are out of valid ranges: ``alpha_target`` or ``delta``
        outside (0,1), ``n`` not an integer or below 10, an unknown ``mode``,
        or ``m < 1`` in beta-binomial mode.

    Examples
    --------
    >>> result = ssbc_correct(alpha_target=0.10, n=50, delta=0.10)
    >>> print(f"Corrected alpha: {result.alpha_corrected:.4f}")
    Corrected alpha: 0.0392

    Notes
    -----
    The acceptance probability P(Coverage ≥ 1 - α_target) is non-increasing
    in u: raising u lowers the first Beta parameter and raises the second,
    shifting coverage downward. The admissible u values therefore form a
    prefix {1, ..., u*}.

    The search first evaluates a bracket around a normal-approximation guess,
    then expands outward:

    - If the top of the bracket passes, it walks upward while candidates keep
      passing.
    - If nothing in the bracket passes, it walks downward to the first
      passing value.

    Worst-case cost is O(n) evaluations.
    """
    # ---- Input validation (detailed error messages) -----------------------
    if not isinstance(alpha_target, int | float):
        raise TypeError(f"alpha_target must be numeric, got {type(alpha_target).__name__}")
    if not (0.0 < alpha_target < 1.0):
        raise ValueError(
            f"alpha_target must be in (0,1), got {alpha_target}. "
            "This represents the target miscoverage rate (e.g., 0.10 for 90% coverage)."
        )

    # Accept both Python int and numpy integer types
    if not isinstance(n, int | np.integer) or n < 1:
        raise ValueError(
            f"n must be a positive integer >= 1, got {n} (type: {type(n).__name__}). This is the calibration set size."
        )
    # Convert to Python int for consistency
    n = int(n)

    # Require minimum calibration size for reliable results
    MIN_REQUIRED_N = 10
    if n < MIN_REQUIRED_N:
        raise ValueError(
            f"Calibration set size n={n} is too small (required: n >= {MIN_REQUIRED_N}). "
            "SSBC requires at least 10 calibration samples for reliable PAC guarantees. "
            "Please collect more calibration data."
        )

    if not isinstance(delta, int | float):
        raise TypeError(f"delta must be numeric, got {type(delta).__name__}")
    if not (0.0 < delta < 1.0):
        raise ValueError(
            f"delta must be in (0,1), got {delta}. This is the PAC risk tolerance (e.g., 0.10 for 90% PAC confidence)."
        )

    if mode not in ("beta", "beta-binomial"):
        raise ValueError(
            f"mode must be 'beta' or 'beta-binomial', got {mode!r}. "
            "'beta' is for infinite test window, 'beta-binomial' for finite test window."
        )

    # Maximum u to search (α' = u/(n+1) must be ≤ α_target)
    u_max = min(n, math.floor(alpha_target * (n + 1)))

    # Test window size; only used (and reported) in beta-binomial mode.
    m_eval: int | None = None
    if mode == "beta-binomial":
        m_eval = m if m is not None else n
        if m_eval < 1:
            raise ValueError("m must be >= 1 for beta-binomial mode.")

    # Trivial regime: if α_target < 1/(n+1), no positive u is allowed.
    # Return u=0, α_corrected=0, with satisfied mass = 1.0 by construction.
    if u_max == 0:
        return SSBCResult(
            alpha_target=alpha_target,
            alpha_corrected=0.0,
            u_star=0,
            n=n,
            satisfied_mass=1.0,
            mode=mode,
            details=dict(
                u_max=u_max,
                u_star_guess=0,
                search_range=(0, 0),
                bracket_width=0,
                delta=delta,
                m=m_eval,
                acceptance_rule="P(Coverage >= target) >= 1-delta",
                search_log=[],
                note="α_target < 1/(n+1) ⇒ α_corrected=0",
            ),
        )

    target_coverage = 1 - alpha_target
    threshold = 1 - delta

    # Initial guess for u using a normal approximation to the Beta distribution:
    # u ≈ u_target - z_δ * sqrt(u_target), with u_target = (n+1)*α_target and
    # z_δ = Φ^(-1)(1-δ).
    u_target = (n + 1) * alpha_target
    z_delta = norm.ppf(1 - delta)  # quantile function (inverse CDF)
    u_star_guess = max(1, math.floor(u_target - z_delta * math.sqrt(max(u_target, 1e-12))))
    # Clamp to valid range
    u_star_guess = min(u_max, u_star_guess)

    # Bracket width (Δ in Algorithm 1)
    if bracket_width is None:
        # Adaptive bracket: uncertainty scales as √u_target ~ (n*α)^(1/2);
        # at least 5, at most n//10 (unless that is below 5), capped at 100.
        bracket_width = max(5, min(int(2 * z_delta * math.sqrt(u_target)), n // 10))
        bracket_width = min(bracket_width, 100)

    # Search bounds - keep the bracket inside [1, u_max]
    u_min = max(1, u_star_guess - bracket_width)
    u_search_max = min(u_max, u_star_guess + bracket_width)
    # Degenerate bracket (e.g. negative bracket_width): fall back to full search
    if u_min > u_search_max:
        u_min, u_search_max = 1, u_max

    # P(Coverage >= target) when calibrating at α' = u/(n+1).
    if mode == "beta":

        def tail_mass(u: int) -> float:
            """Return P(Beta(n+1-u, u) >= target_coverage)."""
            # Survival function for numerical stability near x ≈ 1
            return float(beta_dist.sf(target_coverage, n + 1 - u, u))

    else:
        # Coverage >= target on m points  <=>  X >= ceil(target * m).
        # The small tolerance keeps float round-off in the product (for
        # example 0.07 * 100 evaluates to 7.000000000000001) from raising
        # the integer threshold by one.
        k_thresh = math.ceil(target_coverage * m_eval - 1e-9)

        def tail_mass(u: int) -> float:
            """Return P(X >= k_thresh) for X ~ BetaBinomial(m, n+1-u, u)."""
            return float(betabinom.sf(k_thresh - 1, m_eval, n + 1 - u, u))

    # Scan the bracket, remembering the largest passing u and its tail mass.
    best: tuple[int, float] | None = None
    search_log: list[dict[str, Any]] = []
    for u in range(u_min, u_search_max + 1):
        ptail = tail_mass(u)
        passes = ptail >= threshold
        search_log.append(
            {
                "u": u,
                "alpha_prime": u / (n + 1),
                "a": n + 1 - u,
                "b": u,
                "ptail": ptail,
                "threshold": threshold,
                "passes": passes,
            }
        )
        if passes:
            best = (u, ptail)

    if best is None:
        # Nothing in the bracket passes. Because the tail mass decreases in u,
        # any admissible u lies below the bracket, and the first passing value
        # found walking downward is the largest one.
        for u in range(u_min - 1, 0, -1):
            ptail = tail_mass(u)
            if ptail >= threshold:
                best = (u, ptail)
                break
    elif best[0] == u_search_max:
        # The top of the bracket passes, so the optimum may lie above it.
        # Keep walking upward until the first failure; without this step an
        # optimum above the bracket would be missed and a needlessly
        # conservative u returned.
        for u in range(u_search_max + 1, u_max + 1):
            ptail = tail_mass(u)
            if ptail < threshold:
                break
            best = (u, ptail)

    # If nothing passes at all, choose the most conservative admissible u (0).
    u_star, mass_star = best if best is not None else (0, 1.0)
    alpha_corrected = u_star / (n + 1)

    return SSBCResult(
        alpha_target=alpha_target,
        alpha_corrected=alpha_corrected,
        u_star=u_star,
        n=n,
        satisfied_mass=mass_star,
        mode=mode,
        details=dict(
            u_max=u_max,
            u_star_guess=u_star_guess,
            search_range=(u_min, u_search_max),
            bracket_width=bracket_width,
            delta=delta,
            m=m_eval,
            acceptance_rule="P(Coverage >= target) >= 1-delta",
            search_log=search_log,
        ),
    )