"""Unified rigorous reporting with full PAC guarantees.
This module provides a single comprehensive report that properly accounts for
coverage volatility across all operational metrics.
"""
from datetime import datetime
from typing import Any, cast
import numpy as np
import scipy
from ssbc import __version__
from ssbc._logging import get_logger
from ssbc.calibration import mondrian_conformal_calibrate, split_by_class
from ssbc.core_pkg import ssbc_correct
from ssbc.metrics import (
compute_pac_operational_bounds_marginal_loo_corrected,
compute_pac_operational_bounds_perclass_loo_corrected,
)
# Module-level logger from the package logging helper; used by
# generate_rigorous_pac_report for progress and validation messages.
logger = get_logger(__name__)
[docs]
def generate_rigorous_pac_report(
    labels: np.ndarray,
    probs: np.ndarray,
    alpha_target: float | dict[int, float] = 0.10,
    delta: float | dict[int, float] = 0.10,
    test_size: int | None = None,
    ci_level: float = 0.95,
    use_union_bound: bool = False,
    n_jobs: int = -1,
    verbose: bool = True,
    prediction_method: str = "exact",
    use_loo_correction: bool = True,
    loo_inflation_factor: float | None = None,
) -> dict[str, Any]:
    """Generate complete rigorous PAC report with coverage volatility.

    This is the UNIFIED function that gives you everything properly:
    - SSBC-corrected thresholds
    - Coverage guarantees
    - PAC-controlled operational bounds (marginal + per-class)
    - Singleton error rates with PAC guarantees
    - All bounds account for coverage volatility via BetaBinomial

    Parameters
    ----------
    labels : np.ndarray, shape (n,)
        True labels (0 or 1). Must be a 1-D integer array with at least 2 entries.
    probs : np.ndarray, shape (n, 2)
        Predicted probabilities [P(class=0), P(class=1)]. All entries must be
        finite, lie in [0, 1], and each row must sum to 1 (within 0.01).
    alpha_target : float or dict[int, float], default=0.10
        Target miscoverage per class. A scalar (Python or NumPy) is applied to
        both classes; a dict must provide entries for classes 0 and 1.
    delta : float or dict[int, float], default=0.10
        PAC risk tolerance. Used for both:
        - Coverage guarantee (via SSBC)
        - Operational bounds (pac_level = 1 - delta)
        A scalar is applied to both classes; a dict must provide entries for
        classes 0 and 1.
    test_size : int, optional
        Expected test set size. If None, uses calibration size.
    ci_level : float, default=0.95
        Confidence level for prediction bounds.
    use_union_bound : bool, default=False
        Apply Bonferroni for simultaneous guarantees.
    n_jobs : int, default=-1
        Number of parallel jobs for LOO-CV computation.
        -1 = use all cores (default), 1 = single-threaded, N = use N cores.
    verbose : bool, default=True
        Print comprehensive report.
    prediction_method : str, default="exact"
        Method for LOO uncertainty quantification, forwarded unchanged to the
        LOO-corrected operational-bound functions:
        - "auto": Automatically select best method
        - "analytical": Method 1 (recommended for n>=40)
        - "exact": Method 2 (recommended for n=20-40, default)
        - "hoeffding": Method 3 (ultra-conservative)
        - "all": Compare all methods
    use_loo_correction : bool, default=True
        Retained for backward compatibility. The operational bounds computed
        here are ALWAYS LOO-corrected; this flag does not switch that off.
        Passing False logs a warning, and the value is only recorded in
        ``report["metadata"]``.

        The LOO-CV correction is intended to account for these sources of
        uncertainty in small samples (n=20-40):
        1. LOO-CV correlation structure (variance inflation ≈2×)
        2. Threshold calibration uncertainty
        3. Parameter estimation uncertainty
        4. Test sampling uncertainty

        **LOO-CV Correlation Issue**: The N LOO predictions are not
        independent. The training sets of folds i and j (D_{-i} and D_{-j})
        differ by only two examples out of N−1. Because each fold's threshold
        is computed from nearly identical data, the resulting predictions are
        strongly positively correlated. The LOO-corrected methods account for
        this dependency when computing the bounds.
    loo_inflation_factor : float, optional
        Manual override for LOO variance inflation factor. If None (default),
        automatically estimated from the data using empirical variance.

        **Empirical Correction Factor Estimation**: The inflation factor is
        estimated by comparing the empirical variance of LOO predictions to
        the theoretical IID variance:
        inflation = (Var_empirical / Var_IID) × (n / (n-1)), where
        Var_empirical is the sample variance of the binary LOO predictions
        (with Bessel's correction), Var_IID = p̂(1-p̂) is the expected variance
        under independence, and the n/(n-1) factor is the finite-sample bias
        correction. For large n this approaches the theoretical value of 2.0,
        but for small samples (n=20-40) the actual inflation can vary. The
        estimated factor is clipped to [1.0, 6.0] to prevent extreme values
        from outliers or numerical instability.

        Typical values:
        - 1.0: No inflation (assumes independent samples - usually wrong for LOO)
        - 2.0: Standard LOO inflation (theoretical value for n→∞)
        - 1.5-2.5: Empirical range for small samples
        - >2.5: High correlation scenarios
        - Up to 6.0: Extended range for very high correlation scenarios

        Note: This parameter can be used as a phenomenological control knob to
        correct for issues not modeled properly in the statistical framework.
        If validation suggests the default estimation is too optimistic or too
        conservative, adjusting this factor manually can help achieve the
        desired coverage behavior. Use with caution and validate empirically.

    Returns
    -------
    dict
        Complete report with keys:
        - 'ssbc_class_0': SSBCResult for class 0
        - 'ssbc_class_1': SSBCResult for class 1
        - 'pac_bounds_marginal': PAC operational bounds (marginal)
        - 'pac_bounds_class_0': PAC operational bounds (class 0)
        - 'pac_bounds_class_1': PAC operational bounds (class 1)
        - 'calibration_result': From mondrian_conformal_calibrate
        - 'prediction_stats': From mondrian_conformal_calibrate
        - 'parameters': Normalised inputs and derived PAC levels
        - 'metadata': Library versions, timestamp, sample sizes and LOO settings

    Raises
    ------
    TypeError
        If ``labels``/``probs`` are not NumPy arrays, if ``alpha_target``/``delta``
        are neither numbers nor dicts, or if ``use_union_bound``/``verbose``
        are not bools.
    ValueError
        If any input has an invalid shape, dtype, value, or range.

    Examples
    --------
    >>> from ssbc import BinaryClassifierSimulator
    >>> from ssbc.rigorous_report import generate_rigorous_pac_report
    >>>
    >>> sim = BinaryClassifierSimulator(p_class1=0.5, seed=42)
    >>> labels, probs = sim.generate(n_samples=1000)
    >>>
    >>> report = generate_rigorous_pac_report(
    ...     labels, probs,
    ...     alpha_target=0.10,
    ...     delta=0.10,
    ...     verbose=True
    ... )

    Notes
    -----
    **This replaces the old workflow (removed in v1.1.0):**

    OLD (removed - these functions no longer exist):

    ```python
    # These functions were removed in v1.1.0:
    # op_bounds = compute_mondrian_operational_bounds(...)  # Removed
    # marginal_bounds = compute_marginal_operational_bounds(...)  # Removed
    # report_prediction_stats(...)  # Removed
    ```

    NEW (rigorous):

    ```python
    report = generate_rigorous_pac_report(labels, probs, alpha_target, delta)
    # Done! All bounds account for coverage volatility.
    ```
    """
    logger.info("Starting rigorous PAC report generation")

    # --- Input validation -------------------------------------------------
    _validate_labels_and_probs(labels, probs)

    # Scalars are broadcast to both classes; dicts are checked and copied.
    alpha_dict = _to_per_class_dict(alpha_target, "alpha_target")
    delta_dict = _to_per_class_dict(delta, "delta")

    if test_size is not None and test_size < 1:
        raise ValueError(f"test_size must be >= 1, got {test_size}")
    if not (0.0 < ci_level < 1.0):
        raise ValueError(f"ci_level must be in (0,1), got {ci_level}")
    if not isinstance(use_union_bound, bool):
        raise TypeError(f"use_union_bound must be bool, got {type(use_union_bound).__name__}")
    if not isinstance(verbose, bool):
        raise TypeError(f"verbose must be bool, got {type(verbose).__name__}")

    # The bound computations below are unconditionally LOO-corrected; make it
    # visible when a caller asks for something this function does not do.
    if not use_loo_correction:
        logger.warning(
            "use_loo_correction=False has no effect: operational bounds are always "
            "LOO-corrected. The flag is only recorded in the report metadata."
        )

    logger.debug(
        "Input validation passed: n=%d, alpha_target=%s, delta=%s",
        len(labels),
        alpha_dict,
        delta_dict,
    )

    # --- Split by class ---------------------------------------------------
    class_data = split_by_class(labels, probs)
    n_0 = class_data[0]["n"]
    n_1 = class_data[1]["n"]
    n_total = len(labels)

    if test_size is None:
        test_size = n_total

    # Derive PAC levels from delta values.
    # For marginal: use independence since split (n₀, n₁) is observed,
    # so Pr(both coverage guarantees hold) = (1-δ₀)(1-δ₁).
    pac_level_marginal = (1 - delta_dict[0]) * (1 - delta_dict[1])
    pac_level_0 = 1 - delta_dict[0]
    pac_level_1 = 1 - delta_dict[1]

    # Step 1: Run SSBC for each class
    ssbc_result_0 = ssbc_correct(alpha_target=alpha_dict[0], n=n_0, delta=delta_dict[0], mode="beta")
    ssbc_result_1 = ssbc_correct(alpha_target=alpha_dict[1], n=n_1, delta=delta_dict[1], mode="beta")

    # Step 2: Get calibration results (for thresholds and basic stats)
    cal_result, pred_stats = mondrian_conformal_calibrate(
        class_data=class_data, alpha_target=alpha_dict, delta=delta_dict, mode="beta"
    )

    # Step 3: Compute PAC operational bounds - MARGINAL (always LOO-corrected)
    pac_bounds_marginal = compute_pac_operational_bounds_marginal_loo_corrected(
        ssbc_result_0=ssbc_result_0,
        ssbc_result_1=ssbc_result_1,
        labels=labels,
        probs=probs,
        test_size=test_size,
        ci_level=ci_level,
        pac_level=pac_level_marginal,
        use_union_bound=use_union_bound,
        n_jobs=n_jobs,
        prediction_method=prediction_method,
        loo_inflation_factor=loo_inflation_factor,
        verbose=verbose,
    )

    # Step 4: Compute PAC operational bounds - PER-CLASS (always LOO-corrected)
    pac_bounds_class_0 = compute_pac_operational_bounds_perclass_loo_corrected(
        ssbc_result_0=ssbc_result_0,
        ssbc_result_1=ssbc_result_1,
        labels=labels,
        probs=probs,
        class_label=0,
        test_size=test_size,
        ci_level=ci_level,
        pac_level=pac_level_0,
        use_union_bound=use_union_bound,
        n_jobs=n_jobs,
        prediction_method=prediction_method,
        loo_inflation_factor=loo_inflation_factor,
        verbose=verbose,
    )
    pac_bounds_class_1 = compute_pac_operational_bounds_perclass_loo_corrected(
        ssbc_result_0=ssbc_result_0,
        ssbc_result_1=ssbc_result_1,
        labels=labels,
        probs=probs,
        class_label=1,
        test_size=test_size,
        ci_level=ci_level,
        pac_level=pac_level_1,
        use_union_bound=use_union_bound,
        n_jobs=n_jobs,
        prediction_method=prediction_method,
        loo_inflation_factor=loo_inflation_factor,
        verbose=verbose,
    )

    # Build the report with only the essential information.
    report = {
        # Essential SSBC results (dataclasses returned as-is for tests)
        "ssbc_class_0": ssbc_result_0,
        "ssbc_class_1": ssbc_result_1,
        "pac_bounds_marginal": pac_bounds_marginal,
        "pac_bounds_class_0": pac_bounds_class_0,
        "pac_bounds_class_1": pac_bounds_class_1,
        # Calibration result as returned by mondrian_conformal_calibrate (keys 0 and 1)
        "calibration_result": cal_result,
        "prediction_stats": pred_stats,
        "parameters": {
            "alpha_target": alpha_dict,
            "delta": delta_dict,
            "test_size": test_size,
            "ci_level": ci_level,
            "pac_level_marginal": pac_level_marginal,
            "pac_level_0": pac_level_0,
            "pac_level_1": pac_level_1,
            "use_union_bound": use_union_bound,
        },
        # Metadata for reproducibility
        "metadata": {
            "ssbc_version": __version__,
            "numpy_version": np.__version__,
            "scipy_version": scipy.__version__,
            # Local time WITH its UTC offset so the timestamp is unambiguous
            # when reports from different machines are compared.
            "timestamp": datetime.now().astimezone().isoformat(),
            "n_calibration": n_total,
            "n_class_0": n_0,
            "n_class_1": n_1,
            "prediction_method": prediction_method,
            "use_loo_correction": use_loo_correction,
            "loo_inflation_factor": loo_inflation_factor,
        },
    }

    logger.info("Report generated successfully: n=%d, n_0=%d, n_1=%d", n_total, n_0, n_1)

    if verbose:
        _print_rigorous_report(report)

    return report


def _validate_labels_and_probs(labels: np.ndarray, probs: np.ndarray) -> None:
    """Raise TypeError/ValueError unless labels/probs are a valid binary calibration set."""
    # --- labels -----------------------------------------------------------
    if not isinstance(labels, np.ndarray):
        raise TypeError(f"labels must be a numpy array, got {type(labels).__name__}")
    if labels.ndim != 1:
        raise ValueError(f"labels must be 1-dimensional with shape (n,), got shape {labels.shape}")
    if len(labels) < 2:
        raise ValueError(f"Need at least 2 calibration samples, got {len(labels)}")
    if labels.dtype.kind not in ("i", "u"):
        raise ValueError(f"labels must be integer array, got dtype {labels.dtype}")
    unique_labels = np.unique(labels)
    if not np.all(np.isin(unique_labels, [0, 1])):
        raise ValueError(
            f"labels must contain only 0 and 1, found {unique_labels.tolist()}. "
            "This function is for binary classification only."
        )

    # --- probs ------------------------------------------------------------
    if not isinstance(probs, np.ndarray):
        raise TypeError(f"probs must be a numpy array, got {type(probs).__name__}")
    if probs.shape != (len(labels), 2):
        raise ValueError(
            f"probs must have shape ({len(labels)}, 2), got {probs.shape}. "
            "Each row should contain [P(class=0), P(class=1)]."
        )
    # Non-finite values must be rejected BEFORE the range check: ±inf fails the
    # [0, 1] comparison, so checking the range first made the Inf message unreachable.
    if not np.all(np.isfinite(probs)):
        nan_count = np.sum(np.isnan(probs))
        if nan_count:
            raise ValueError(f"probs must not contain NaN values, found {nan_count} NaNs")
        inf_count = np.sum(np.isinf(probs))
        raise ValueError(f"probs must not contain Inf values, found {inf_count} Infs")
    invalid_mask = (probs < 0) | (probs > 1)
    if np.any(invalid_mask):
        invalid_count = np.sum(invalid_mask)
        raise ValueError(
            f"All probabilities must be in [0,1], found {invalid_count} invalid values. "
            f"Invalid range: [{np.min(probs[invalid_mask]):.4f}, {np.max(probs[invalid_mask]):.4f}]"
        )
    # Rows should sum to ~1; allow small numerical error.
    prob_sums = np.sum(probs, axis=1)
    bad_indices = np.flatnonzero(np.abs(prob_sums - 1.0) > 0.01)
    if bad_indices.size:
        raise ValueError(
            "Probabilities must sum to 1.0 for each sample, "
            f"found {len(bad_indices)} samples with sums outside [0.99, 1.01]. "
            f"Example sums: {prob_sums[bad_indices[:5]].tolist()}"
        )


def _to_per_class_dict(value: float | dict[int, float], name: str) -> dict[int, float]:
    """Return *value* as a per-class dict with entries for classes 0 and 1, each in (0, 1)."""
    # NumPy scalars (e.g. np.float32) are not ``float`` subclasses, so they are
    # listed explicitly; bool is an ``int`` and is rejected by the range check.
    if isinstance(value, (int, float, np.integer, np.floating)):
        scalar = float(value)
        if not (0.0 < scalar < 1.0):
            raise ValueError(f"{name} must be in (0,1), got {value}")
        return {0: scalar, 1: scalar}
    if not isinstance(value, dict):
        raise TypeError(
            f"{name} must be a float or a dict mapping class label to float, got {type(value).__name__}"
        )
    # Copy so the report does not alias the caller's dict.
    per_class = cast(dict[int, float], dict(value))
    if not all(0.0 < v < 1.0 for v in per_class.values()):
        raise ValueError(f"All {name} values must be in (0,1), got {per_class}")
    missing = [k for k in (0, 1) if k not in per_class]
    if missing:
        raise ValueError(f"{name} dict must provide values for classes 0 and 1, missing keys {missing}")
    return per_class
def _print_rigorous_report(report: dict) -> None:
"""Print comprehensive rigorous PAC report."""
cal_result = report["calibration_result"]
pred_stats = report["prediction_stats"]
params = report["parameters"]
print("=" * 80)
print("OPERATIONAL PAC-CONTROLLED CONFORMAL PREDICTION REPORT")
print("=" * 80)
print("\nParameters:")
print(f" Test size: {params['test_size']}")
print(f" CI level: {params['ci_level']:.0%} (Clopper-Pearson)")
pac_0 = params["pac_level_0"]
pac_1 = params["pac_level_1"]
delta_0 = 1.0 - pac_0
delta_1 = 1.0 - pac_1
print(" PAC guarantee levels:")
print(f" Class 0: δ = {delta_0:.2f} ({pac_0:.0%} confidence)")
print(f" Class 1: δ = {delta_1:.2f} ({pac_1:.0%} confidence)")
union_bound = params["use_union_bound"]
if union_bound:
print(" Union bound: applied across metrics (all metrics hold simultaneously)")
print(" Class guarantees: validated separately (no union bound across classes)")
else:
print(" Union bound: not applied (metrics validated independently)")
print(" Class guarantees: validated separately")
# Per-class reports
for class_label in [0, 1]:
ssbc = report[f"ssbc_class_{class_label}"]
pac = report[f"pac_bounds_class_{class_label}"]
cal = cal_result[class_label]
print("\n" + "=" * 80)
print(f"CLASS {class_label} (Conditioned on True Label = {class_label})")
print("=" * 80)
print(f" Calibration size: n = {ssbc.n}")
print(f" Target miscoverage: α = {params['alpha_target'][class_label]:.3f}")
print(f" SSBC-corrected α: α' = {ssbc.alpha_corrected:.4f}")
print(f" PAC risk: δ = {params['delta'][class_label]:.3f}")
print(f" Conformal threshold: {cal['threshold']:.4f}")
# Calibration data statistics
stats = pred_stats[class_label]
if "error" not in stats:
print(f"\n Calibration summary (n = {ssbc.n})")
print(" Empirical rates on calibration data. Intervals are 95% Clopper-Pearson.")
print(" These do not include PAC guarantees.")
# Abstentions
abst = stats["abstentions"]
print(
f" Abstentions: {abst['count']:4d} / {ssbc.n:4d} = "
f"{abst['proportion']:6.2%} 95% CI: [{abst['lower']:.3f}, {abst['upper']:.3f}]"
)
# Singletons
sing = stats["singletons"]
print(
f" Singletons: {sing['count']:4d} / {ssbc.n:4d} = "
f"{sing['proportion']:6.2%} 95% CI: [{sing['lower']:.3f}, {sing['upper']:.3f}]"
)
# Correct/incorrect singletons
sing_corr = stats["singletons_correct"]
print(
f" Correct: {sing_corr['count']:4d} / {ssbc.n:4d} = "
f"{sing_corr['proportion']:6.2%} 95% CI: [{sing_corr['lower']:.3f}, {sing_corr['upper']:.3f}]"
)
sing_incorr = stats["singletons_incorrect"]
print(
f" Incorrect: {sing_incorr['count']:4d} / {ssbc.n:4d} = "
f"{sing_incorr['proportion']:6.2%} 95% CI: [{sing_incorr['lower']:.3f}, {sing_incorr['upper']:.3f}]"
)
# Error | singleton
if sing["count"] > 0:
from ssbc.bounds import cp_interval
error_cond = cp_interval(sing_incorr["count"], sing["count"])
print(
f" Error | singleton: {sing_incorr['count']:4d} / {sing['count']:4d} = "
f"{error_cond['proportion']:6.2%} 95% CI: [{error_cond['lower']:.3f}, {error_cond['upper']:.3f}]"
)
# Doublets
doub = stats["doublets"]
print(
f" Doublets: {doub['count']:4d} / {ssbc.n:4d} = "
f"{doub['proportion']:6.2%} 95% CI: [{doub['lower']:.3f}, {doub['upper']:.3f}]"
)
print("\n Operational bounds for deployment")
pac_level_class = params[f"pac_level_{class_label}"]
if "loo_diagnostics" in pac:
print(
" Method: leave-one-out calibration at confidence 1-δ, plus binomial "
"predictive bounds for sampling variability."
)
else:
print(
" Method: leave-one-out calibration at confidence 1-δ, plus prediction "
"bounds for sampling uncertainty."
)
print(f" Threshold calibration level: {pac_level_class:.0%} (1-δ)")
print(f" Reported confidence level for bounds: {params['ci_level']:.0%}")
print(f" Grid points evaluated: {pac['n_grid_points']}")
# Helper to print bounds with method comparison
# Capture test_size in closure-safe way
test_size_for_methods = pac.get("test_size", params["test_size"])
def _print_rate_with_methods(rate_name: str, bounds: tuple, expected: float, diagnostics: dict | None = None):
"""Print rate bounds, showing method comparison if available."""
lower, upper = bounds
test_size = test_size_for_methods # noqa: B023 (captured in closure)
print(f"\n {rate_name}")
print(f" Point estimate: {expected:.3f}")
if diagnostics and "comparison" in diagnostics:
# Method comparison available
comp = diagnostics["comparison"]
selected = diagnostics.get("selected_method", "unknown")
print(f" Candidate bounds (95% predictive, n_test = {test_size}):")
for method_name, method_lower, method_upper, method_width in zip(
comp["method"], comp["lower"], comp["upper"], comp["width"], strict=False
):
# Replace method names for display
display_name = method_name.replace("Analytical", "Normal approximation")
display_name = display_name.replace("Exact Binomial", "Exact binomial predictive")
# Match selected method - handle both "exact" and "exact (auto-corrected)" cases
method_lower_name = method_name.lower().replace(" ", "_")
if "analytical" in method_lower_name and (
"analytical" in selected.lower() or selected.lower() == "analytical"
):
marker = "(retained)"
elif "exact" in method_lower_name and "exact" in selected.lower():
marker = "(retained)"
elif "hoeffding" in method_lower_name and "hoeffding" in selected.lower():
marker = "(retained)"
else:
marker = ""
print(
f" {display_name:25s} [{method_lower:.3f}, {method_upper:.3f}] "
f"width {method_width:.3f} {marker}"
)
print(f" Operational bounds: [{lower:.3f}, {upper:.3f}]")
else:
# Single method - show which method if available
method_info = diagnostics.get("selected_method", "") if diagnostics else ""
# Fallback to "method" key if "selected_method" not available
if not method_info and diagnostics and "method" in diagnostics:
method_name = diagnostics["method"]
# Convert internal method names to user-friendly names
method_map = {
"clopper_pearson_plus_sampling": "simple",
"beta_binomial_loo_corrected": "beta_binomial",
"hoeffding_distribution_free": "hoeffding",
}
method_info = method_map.get(method_name, method_name)
if method_info:
print(f" Method: {method_info}")
print(f" Operational bounds: [{lower:.3f}, {upper:.3f}]")
# Get diagnostics if available
loo_diag = pac.get("loo_diagnostics", {})
singleton_diag = loo_diag.get("singleton") if loo_diag else None
doublet_diag = loo_diag.get("doublet") if loo_diag else None
abstention_diag = loo_diag.get("abstention") if loo_diag else None
error_diag = loo_diag.get("singleton_error") if loo_diag else None
s_lower, s_upper = pac["singleton_rate_bounds"]
_print_rate_with_methods("Singleton rate", (s_lower, s_upper), pac["expected_singleton_rate"], singleton_diag)
d_lower, d_upper = pac["doublet_rate_bounds"]
_print_rate_with_methods("Doublet rate", (d_lower, d_upper), pac["expected_doublet_rate"], doublet_diag)
a_lower, a_upper = pac["abstention_rate_bounds"]
_print_rate_with_methods(
"Abstention rate", (a_lower, a_upper), pac["expected_abstention_rate"], abstention_diag
)
se_lower, se_upper = pac["singleton_error_rate_bounds"]
_print_rate_with_methods(
f"Conditional error rate given singleton (P(error | singleton, class = {class_label}))",
(se_lower, se_upper),
pac["expected_singleton_error_rate"],
error_diag,
)
# Singleton correct rate: P(correct | singleton, class) = 1 - P(error | singleton, class)
if "singleton_correct_rate_bounds" in pac:
sc_lower, sc_upper = pac["singleton_correct_rate_bounds"]
sc_expected = pac.get("expected_singleton_correct_rate", 1.0 - pac["expected_singleton_error_rate"])
# Use same diagnostics as error rate (they're complementary)
_print_rate_with_methods(
f"Conditional correct rate given singleton (P(correct | singleton, class = {class_label}))",
(sc_lower, sc_upper),
sc_expected,
error_diag,
)
# Note about per-class rates (all have random denominators)
print("\n Stability note:")
print(
f" All rates above (singleton, doublet, abstention, conditional error) "
f"are conditional on class {class_label}."
)
print(
f" Their denominators (number of class {class_label} samples in the test set) "
f"are random at deployment time."
)
print(" This induces extra variance and can bias the reported intervals.")
print(" For audit and Service Level Objective reporting, use the marginal rates")
print(" in the next section (normalized by total volume), which have a fixed denominator.")
# Marginal report
pac_marg = report["pac_bounds_marginal"]
marginal_stats = pred_stats["marginal"]
print("\n" + "=" * 80)
print("MARGINAL STATISTICS (deployment view; class labels not assumed known)")
print("=" * 80)
n_total = marginal_stats["n_total"]
print(f" Total samples: n = {n_total}")
# Calibration data statistics (marginal)
print(f"\n Calibration summary (n = {n_total})")
print(" Empirical rates on calibration data. Intervals are 95% Clopper-Pearson.")
print(" No PAC guarantees.")
# Coverage
cov = marginal_stats["coverage"]
print(
f" Coverage (prediction set contains true label): {cov['count']:4d} / {n_total:4d} = "
f"{cov['rate']:6.2%} 95% CI: [{cov['ci_95']['lower']:.3f}, {cov['ci_95']['upper']:.3f}]"
)
# Abstentions
abst = marginal_stats["abstentions"]
print(
f" Abstentions: {abst['count']:4d} / {n_total:4d} = "
f"{abst['proportion']:6.2%} 95% CI: [{abst['lower']:.3f}, {abst['upper']:.3f}]"
)
# Singletons
sing = marginal_stats["singletons"]
print(
f" Singletons: {sing['count']:4d} / {n_total:4d} = "
f"{sing['proportion']:6.2%} 95% CI: [{sing['lower']:.3f}, {sing['upper']:.3f}]"
)
# Singleton errors
if sing["count"] > 0:
from ssbc.bounds import cp_interval
error_cond_marg = cp_interval(sing["errors"], sing["count"])
err_prop = error_cond_marg["proportion"]
err_lower = error_cond_marg["lower"]
err_upper = error_cond_marg["upper"]
print(
f" Errors: {sing['errors']:4d} / {sing['count']:4d} = "
f"{err_prop:6.2%} 95% CI: [{err_lower:.3f}, {err_upper:.3f}]"
)
# Doublets
doub = marginal_stats["doublets"]
print(
f" Doublets: {doub['count']:4d} / {n_total:4d} = "
f"{doub['proportion']:6.2%} 95% CI: [{doub['lower']:.3f}, {doub['upper']:.3f}]"
)
print("\n Operational bounds for deployment")
print(" Class-specific rates (normalized by total test set size):")
print(" These are JOINT probabilities measuring operational events across the full test set.")
print(" ")
print(" Singleton rates:")
print(" - Definition: P(true_label=class, prediction_set=singleton)")
print(" - Count samples where TRUE label = class AND prediction set = singleton")
print(" - Example: 'Singleton rate (Class 0)' = P(Y=0, S=singleton)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • True label is class 0")
print(" • Prediction set is singleton (can be {{0}} or {{1}})")
print(" This includes BOTH correct singletons (predicted {{0}}) and")
print(" incorrect singletons (predicted {{1}} when true label is 0).")
print(" ")
print(" Doublet rates:")
print(" - Definition: P(true_label=class, prediction_set=doublet)")
print(" - Count samples where TRUE label = class AND prediction set = {{0,1}}")
print(" - Example: 'Doublet rate (Class 0)' = P(Y=0, S=doublet)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • True label is class 0")
print(" • Prediction set is doublet (contains both {{0, 1}})")
print(" Doublets always contain the true label (by coverage guarantee).")
print(" ")
print(" Abstention rates:")
print(" - Definition: P(true_label=class, prediction_set=empty)")
print(" - Count samples where TRUE label = class AND prediction set = {{}}")
print(" - Example: 'Abstention rate (Class 0)' = P(Y=0, S=abstention)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • True label is class 0")
print(" • Prediction set is empty (abstention/rejection)")
print(" Abstentions indicate the model is uncertain and rejects the sample.")
print(" ")
print(" Error rates (normalized by total, conditioned on true label):")
print(" - Definition: P(true_label=class, prediction_set=singleton, error=1)")
print(" - Count samples where TRUE label = class AND singleton AND prediction is wrong")
print(" - Example: 'Error rate (Class 0 singletons)' = P(Y=0, S=singleton, E=1)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • True label is class 0")
print(" • Prediction set is singleton (single class predicted)")
print(" • Prediction is INCORRECT (predicted class ≠ true label)")
print(" ")
print(" Correct rates (normalized by total, conditioned on true label):")
print(" - Definition: P(true_label=class, prediction_set=singleton, error=0)")
print(" - Count samples where TRUE label = class AND singleton AND prediction is correct")
print(" - Example: 'Correct rate (Class 0 singletons)' = P(Y=0, S=singleton, E=0)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • True label is class 0")
print(" • Prediction set is singleton (single class predicted)")
print(" • Prediction is CORRECT (predicted class = true label)")
print(" ")
print(" Error rates (normalized by total, conditioned on predicted class):")
print(" - Definition: P(predicted_class=X, prediction_set=singleton, error=1)")
print(" - Count samples where PREDICTED class = X AND singleton AND prediction is wrong")
print(" - Example: 'Error rate (when singleton predicted as Class 0)' = P(predicted=0, S=singleton, E=1)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • Prediction set is singleton with predicted class = 0")
print(" • Prediction is INCORRECT (predicted class ≠ true label)")
print(" This answers: 'If I predict class 0, how often am I wrong?'")
print(" ")
print(" Correct rates (normalized by total, conditioned on predicted class):")
print(" - Definition: P(predicted_class=X, prediction_set=singleton, error=0)")
print(" - Count samples where PREDICTED class = X AND singleton AND prediction is correct")
print(" - Example: 'Correct rate (when singleton predicted as Class 0)' = P(predicted=0, S=singleton, E=0)")
print(" Meaning: Among all test samples, what fraction have:")
print(" • Prediction set is singleton with predicted class = 0")
print(" • Prediction is CORRECT (predicted class = true label)")
print(" This answers: 'If I predict class 0, how often am I correct?'")
print(" ")
print(" Relationship:")
print(" - singleton_rate = error_rate + correct_rate (for same class)")
print(" - All rates normalized by total test set size (fixed denominator)")
print(" ")
print(" All rates use fixed denominator (total test set size) for deployment planning.")
ci_lvl = params["ci_level"]
print(f" Reported confidence level for bounds: {ci_lvl:.0%}")
# Helper to print bounds with method comparison (reused for marginal)
def _print_rate_with_methods_marginal(
rate_name: str, bounds: tuple, expected: float, diagnostics: dict | None = None
):
"""Print rate bounds, showing method comparison if available."""
lower, upper = bounds
test_size = pac_marg.get("test_size", params["test_size"])
print(f"\n {rate_name}")
print(f" Point estimate: {expected:.3f}")
if diagnostics and "comparison" in diagnostics:
# Method comparison available
comp = diagnostics["comparison"]
selected = diagnostics.get("selected_method", "unknown")
print(f" Candidate bounds (95% predictive, n_test = {test_size}):")
for method_name, method_lower, method_upper, method_width in zip(
comp["method"], comp["lower"], comp["upper"], comp["width"], strict=False
):
# Replace method names for display
display_name = method_name.replace("Analytical", "Normal approximation")
display_name = display_name.replace("Exact Binomial", "Exact binomial predictive")
# Match selected method - handle both "exact" and "exact (auto-corrected)" cases
method_lower_name = method_name.lower().replace(" ", "_")
if "analytical" in method_lower_name and (
"analytical" in selected.lower() or selected.lower() == "analytical"
):
marker = "(retained)"
elif "exact" in method_lower_name and "exact" in selected.lower():
marker = "(retained)"
elif "hoeffding" in method_lower_name and "hoeffding" in selected.lower():
marker = "(retained)"
else:
marker = ""
print(
f" {display_name:25s} [{method_lower:.3f}, {method_upper:.3f}] "
f"width {method_width:.3f} {marker}"
)
print(f" Operational bounds: [{lower:.3f}, {upper:.3f}]")
else:
# Single method - show which method if available
method_info = diagnostics.get("selected_method", "") if diagnostics else ""
# Fallback to "method" key if "selected_method" not available
if not method_info and diagnostics and "method" in diagnostics:
method_name = diagnostics["method"]
# Convert internal method names to user-friendly names
method_map = {
"clopper_pearson_plus_sampling": "simple",
"beta_binomial_loo_corrected": "beta_binomial",
"hoeffding_distribution_free": "hoeffding",
}
method_info = method_map.get(method_name, method_name)
if method_info:
print(f" Method: {method_info}")
print(f" Operational bounds: [{lower:.3f}, {upper:.3f}]")
# Get diagnostics if available (for class-specific rates and error rates)
# Each *_diag_marg local is passed to _print_rate_with_methods_marginal below.
# There, an entry containing a "comparison" table produces a per-method listing.
# dict.get only falls back to {} when "loo_diagnostics" is absent. The
# `if loo_diag_marg else None` guards also cover a key that is present but None/empty.
# NOTE(review): these locals may also be read later in this function; keep them bound.
loo_diag_marg = pac_marg.get("loo_diagnostics", {})
# Class-specific singleton / doublet / abstention rate diagnostics
singleton_class0_diag_marg = loo_diag_marg.get("singleton_class0") if loo_diag_marg else None
singleton_class1_diag_marg = loo_diag_marg.get("singleton_class1") if loo_diag_marg else None
doublet_class0_diag_marg = loo_diag_marg.get("doublet_class0") if loo_diag_marg else None
doublet_class1_diag_marg = loo_diag_marg.get("doublet_class1") if loo_diag_marg else None
abstention_class0_diag_marg = loo_diag_marg.get("abstention_class0") if loo_diag_marg else None
abstention_class1_diag_marg = loo_diag_marg.get("abstention_class1") if loo_diag_marg else None
# Singleton error/correct diagnostics keyed by class (presumably the true label,
# mirroring the "True label is class 0" legend for the correct rate)
error_class0_diag_marg = loo_diag_marg.get("singleton_error_class0") if loo_diag_marg else None
error_class1_diag_marg = loo_diag_marg.get("singleton_error_class1") if loo_diag_marg else None
correct_class0_diag_marg = loo_diag_marg.get("singleton_correct_class0") if loo_diag_marg else None
correct_class1_diag_marg = loo_diag_marg.get("singleton_correct_class1") if loo_diag_marg else None
# Singleton error/correct diagnostics keyed by PREDICTED class ("_pred_" keys)
error_pred_class0_diag_marg = loo_diag_marg.get("singleton_error_pred_class0") if loo_diag_marg else None
error_pred_class1_diag_marg = loo_diag_marg.get("singleton_error_pred_class1") if loo_diag_marg else None
correct_pred_class0_diag_marg = loo_diag_marg.get("singleton_correct_pred_class0") if loo_diag_marg else None
correct_pred_class1_diag_marg = loo_diag_marg.get("singleton_correct_pred_class1") if loo_diag_marg else None
# Class-specific singleton rates (normalized against full dataset)
# These are operationally meaningful for deployment planning
# Pattern used by every rate section in this report:
# - print only when pac_marg carries "<rate>_bounds", a (lower, upper) pair;
# - the point estimate comes from "expected_<rate>", defaulting to 0.0 if missing;
# - the matching LOO diagnostics are attached when available.
# NOTE(review): these near-identical sections could be table-driven, but the
# locals they bind (e.g. s_class0_lower) may be read later in this function;
# confirm before collapsing them into a loop.
if "singleton_rate_class0_bounds" in pac_marg:
    s_class0_lower, s_class0_upper = pac_marg["singleton_rate_class0_bounds"]
    s_class0_expected = pac_marg.get("expected_singleton_rate_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Singleton rate (Class 0, normalized by total)",
        (s_class0_lower, s_class0_upper),
        s_class0_expected,
        singleton_class0_diag_marg,
    )
if "singleton_rate_class1_bounds" in pac_marg:
    s_class1_lower, s_class1_upper = pac_marg["singleton_rate_class1_bounds"]
    s_class1_expected = pac_marg.get("expected_singleton_rate_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Singleton rate (Class 1, normalized by total)",
        (s_class1_lower, s_class1_upper),
        s_class1_expected,
        singleton_class1_diag_marg,
    )
# Class-specific doublet rates (normalized against full dataset)
# Each class is printed only when its "<rate>_bounds" entry exists in pac_marg;
# a missing "expected_<rate>" point estimate falls back to 0.0.
if "doublet_rate_class0_bounds" in pac_marg:
    d_class0_lower, d_class0_upper = pac_marg["doublet_rate_class0_bounds"]
    d_class0_expected = pac_marg.get("expected_doublet_rate_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Doublet rate (Class 0, normalized by total)",
        (d_class0_lower, d_class0_upper),
        d_class0_expected,
        doublet_class0_diag_marg,
    )
if "doublet_rate_class1_bounds" in pac_marg:
    d_class1_lower, d_class1_upper = pac_marg["doublet_rate_class1_bounds"]
    d_class1_expected = pac_marg.get("expected_doublet_rate_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Doublet rate (Class 1, normalized by total)",
        (d_class1_lower, d_class1_upper),
        d_class1_expected,
        doublet_class1_diag_marg,
    )
# Class-specific abstention rates (normalized against full dataset)
# Each class is printed only when its "<rate>_bounds" entry exists in pac_marg;
# a missing "expected_<rate>" point estimate falls back to 0.0.
if "abstention_rate_class0_bounds" in pac_marg:
    a_class0_lower, a_class0_upper = pac_marg["abstention_rate_class0_bounds"]
    a_class0_expected = pac_marg.get("expected_abstention_rate_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Abstention rate (Class 0, normalized by total)",
        (a_class0_lower, a_class0_upper),
        a_class0_expected,
        abstention_class0_diag_marg,
    )
if "abstention_rate_class1_bounds" in pac_marg:
    a_class1_lower, a_class1_upper = pac_marg["abstention_rate_class1_bounds"]
    a_class1_expected = pac_marg.get("expected_abstention_rate_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Abstention rate (Class 1, normalized by total)",
        (a_class1_lower, a_class1_upper),
        a_class1_expected,
        abstention_class1_diag_marg,
    )
# Class-specific error rates (normalized against full dataset)
# Note: We do NOT report marginal singleton_error because it mixes two different
# distributions (class 0 and class 1) which cannot be justified statistically.
# "Class 0/1" here is presumably the TRUE label (contrast with the "_pred_"
# sections further down, which split by predicted class). Each rate is a joint
# probability over the whole test set.
if "singleton_error_rate_class0_bounds" in pac_marg:
    se_class0_lower, se_class0_upper = pac_marg["singleton_error_rate_class0_bounds"]
    se_class0_expected = pac_marg.get("expected_singleton_error_rate_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Error rate (Class 0 singletons, normalized by total)",
        (se_class0_lower, se_class0_upper),
        se_class0_expected,
        error_class0_diag_marg,
    )
if "singleton_error_rate_class1_bounds" in pac_marg:
    se_class1_lower, se_class1_upper = pac_marg["singleton_error_rate_class1_bounds"]
    se_class1_expected = pac_marg.get("expected_singleton_error_rate_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Error rate (Class 1 singletons, normalized by total)",
        (se_class1_lower, se_class1_upper),
        se_class1_expected,
        error_class1_diag_marg,
    )
# Class-specific singleton correct rates (normalized against full dataset)
# Per the legend printed earlier, e.g. "Correct rate (Class 0 singletons)" is
# P(Y=0, S=singleton, E=0): class = TRUE label, normalized by total test size.
if "singleton_correct_rate_class0_bounds" in pac_marg:
    sc_class0_lower, sc_class0_upper = pac_marg["singleton_correct_rate_class0_bounds"]
    sc_class0_expected = pac_marg.get("expected_singleton_correct_rate_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Correct rate (Class 0 singletons, normalized by total)",
        (sc_class0_lower, sc_class0_upper),
        sc_class0_expected,
        correct_class0_diag_marg,
    )
if "singleton_correct_rate_class1_bounds" in pac_marg:
    sc_class1_lower, sc_class1_upper = pac_marg["singleton_correct_rate_class1_bounds"]
    sc_class1_expected = pac_marg.get("expected_singleton_correct_rate_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Correct rate (Class 1 singletons, normalized by total)",
        (sc_class1_lower, sc_class1_upper),
        sc_class1_expected,
        correct_class1_diag_marg,
    )
# Error rates when singleton is assigned to a specific class (normalized against full dataset)
# Joint rate P(predicted=X, S=singleton, E=1) over the whole test set, NOT the
# conditional P(error | predicted=X). That conditional equals
# error_pred / (error_pred + correct_pred) for the same predicted class.
if "singleton_error_rate_pred_class0_bounds" in pac_marg:
    se_pred_class0_lower, se_pred_class0_upper = pac_marg["singleton_error_rate_pred_class0_bounds"]
    se_pred_class0_expected = pac_marg.get("expected_singleton_error_rate_pred_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Error rate (when singleton predicted as Class 0, normalized by total)",
        (se_pred_class0_lower, se_pred_class0_upper),
        se_pred_class0_expected,
        error_pred_class0_diag_marg,
    )
if "singleton_error_rate_pred_class1_bounds" in pac_marg:
    se_pred_class1_lower, se_pred_class1_upper = pac_marg["singleton_error_rate_pred_class1_bounds"]
    se_pred_class1_expected = pac_marg.get("expected_singleton_error_rate_pred_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Error rate (when singleton predicted as Class 1, normalized by total)",
        (se_pred_class1_lower, se_pred_class1_upper),
        se_pred_class1_expected,
        error_pred_class1_diag_marg,
    )
# Correct rates when singleton is assigned to a specific class (normalized against full dataset)
# Joint rate P(predicted=X, S=singleton, E=0) over the whole test set, NOT the
# conditional P(correct | predicted=X). That conditional equals
# correct_pred / (error_pred + correct_pred) for the same predicted class.
if "singleton_correct_rate_pred_class0_bounds" in pac_marg:
    sc_pred_class0_lower, sc_pred_class0_upper = pac_marg["singleton_correct_rate_pred_class0_bounds"]
    sc_pred_class0_expected = pac_marg.get("expected_singleton_correct_rate_pred_class0", 0.0)
    _print_rate_with_methods_marginal(
        "Correct rate (when singleton predicted as Class 0, normalized by total)",
        (sc_pred_class0_lower, sc_pred_class0_upper),
        sc_pred_class0_expected,
        correct_pred_class0_diag_marg,
    )
if "singleton_correct_rate_pred_class1_bounds" in pac_marg:
    sc_pred_class1_lower, sc_pred_class1_upper = pac_marg["singleton_correct_rate_pred_class1_bounds"]
    sc_pred_class1_expected = pac_marg.get("expected_singleton_correct_rate_pred_class1", 0.0)
    _print_rate_with_methods_marginal(
        "Correct rate (when singleton predicted as Class 1, normalized by total)",
        (sc_pred_class1_lower, sc_pred_class1_upper),
        sc_pred_class1_expected,
        correct_pred_class1_diag_marg,
    )
# Deployment expectations are not reported at marginal level since singleton/doublet
# rates are derived from class-specific rates (already reported in CLASS 0/1 sections).
# Horizontal rule of 80 '=' characters; presumably separates the marginal section
# from the output that follows.
print("\n" + "=" * 80)