# Source code for ssbc.reporting.visualization

"""Visualization and reporting utilities for conformal prediction results."""

from typing import Any

from ssbc.bounds import cp_interval


[docs] def report_prediction_stats( prediction_stats: dict[Any, Any], calibration_result: dict[Any, Any], operational_bounds_per_class: dict[int, Any] | None = None, marginal_operational_bounds: Any | None = None, verbose: bool = True, ) -> dict[str | int, Any]: """Report rigorous statistics for Mondrian conformal prediction with valid CIs. Only displays statistics with valid confidence intervals: - Per-class statistics from calibration data (valid within class) - Per-class operational bounds from cross-validation (rigorous PAC bounds) - Marginal operational bounds from cross-validated Mondrian (rigorous PAC bounds) Does NOT display marginal statistics from calibration data (invalid CIs for Mondrian). Parameters ---------- prediction_stats : dict Output from mondrian_conformal_calibrate (second return value) calibration_result : dict Output from mondrian_conformal_calibrate (first return value) operational_bounds_per_class : dict[int, OperationalRateBoundsResult], optional Per-class operational bounds (from generate_rigorous_pac_report) marginal_operational_bounds : OperationalRateBoundsResult, optional Marginal operational bounds (from generate_rigorous_pac_report) verbose : bool, default=True If True, print detailed statistics to stdout Returns ------- dict Structured summary with valid CIs: - Keys 0, 1 for per-class statistics - Key 'marginal_bounds' if marginal_operational_bounds provided Examples -------- >>> # Get operational bounds from rigorous PAC report >>> from ssbc import generate_rigorous_pac_report >>> report = generate_rigorous_pac_report(labels, probs, alpha_target=0.10, delta=0.10) >>> cal_result = report['calibration_result'] >>> pred_stats = report['prediction_stats'] >>> op_bounds = report['pac_bounds_class_0'] # Per-class bounds >>> marginal = report['pac_bounds_marginal'] # Marginal bounds >>> summary = report_prediction_stats(pred_stats, cal_result, op_bounds, marginal) """ summary: dict[str | int, Any] = {} if verbose: print("=" * 80) 
print("MONDRIAN CONFORMAL PREDICTION REPORT") print("=" * 80) # ==================== PER-CLASS STATISTICS ==================== for class_label in sorted([k for k in prediction_stats.keys() if isinstance(k, int)]): cls = prediction_stats[class_label] if isinstance(cls, dict) and "error" in cls: if verbose: print(f"\nCLASS {class_label}: {cls['error']}") summary[class_label] = {"error": cls["error"]} continue n = int(cls.get("n", cls.get("n_class", 0))) if n == 0: continue # Get calibration info cal = calibration_result.get(class_label, {}) alpha_target = cal.get("alpha_target") alpha_corrected = cal.get("alpha_corrected") delta = cal.get("delta") threshold = cal.get("threshold") if verbose: print(f"\n{'=' * 80}") print(f"CLASS {class_label} (Conditioned on True Label = {class_label})") print(f"{'=' * 80}") print(f" Calibration size: n = {n}") if alpha_target is not None: print(f" Target miscoverage: α = {alpha_target:.3f}") if alpha_corrected is not None: print(f" SSBC-corrected α: α' = {alpha_corrected:.4f}") if delta is not None: print(f" PAC risk: δ = {delta:.3f}") if threshold is not None: print(f" Conformal threshold: {threshold:.4f}") # Per-class stats from calibration data (VALID - exchangeable within class) if verbose: print(f"\n 📊 Statistics from Calibration Data (n={n}):") print(" [Basic CP CIs without PAC guarantee - evaluated on calibration data]") # Abstentions abstentions = cls.get("abstentions", {}) if isinstance(abstentions, dict): abst_count = abstentions.get("count", 0) abst_ci = cp_interval(abst_count, n) if verbose: print( f" Abstentions: {abst_count:4d} / {n:4d} = {abst_ci['proportion']:6.2%} " f"95% CI: [{abst_ci['lower']:.3f}, {abst_ci['upper']:.3f}]" ) # Singletons (note: singletons_correct/incorrect are at top level, not nested) singletons = cls.get("singletons", {}) singletons_correct = cls.get("singletons_correct", {}) singletons_incorrect = cls.get("singletons_incorrect", {}) if isinstance(singletons, dict): sing_count = 
singletons.get("count", 0) sing_correct = singletons_correct.get("count", 0) if isinstance(singletons_correct, dict) else 0 sing_incorrect = singletons_incorrect.get("count", 0) if isinstance(singletons_incorrect, dict) else 0 # Compute valid CIs (exchangeable within class) sing_ci = cp_interval(sing_count, n) sing_corr_ci = cp_interval(sing_correct, n) sing_inc_ci = cp_interval(sing_incorrect, n) if verbose: print( f" Singletons: {sing_count:4d} / {n:4d} = {sing_ci['proportion']:6.2%} " f"95% CI: [{sing_ci['lower']:.3f}, {sing_ci['upper']:.3f}]" ) print( f" Correct: {sing_correct:4d} / {n:4d} = {sing_corr_ci['proportion']:6.2%} " f"95% CI: [{sing_corr_ci['lower']:.3f}, {sing_corr_ci['upper']:.3f}]" ) print( f" Incorrect: {sing_incorrect:4d} / {n:4d} = {sing_inc_ci['proportion']:6.2%} " f"95% CI: [{sing_inc_ci['lower']:.3f}, {sing_inc_ci['upper']:.3f}]" ) # Error rate given singleton if sing_count > 0: err_given_sing = cp_interval(sing_incorrect, sing_count) print( f" Error | singleton: {sing_incorrect:4d} / {sing_count:4d} = " f"{err_given_sing['proportion']:6.2%} " f"95% CI: [{err_given_sing['lower']:.3f}, {err_given_sing['upper']:.3f}]" ) # Doublets doublets = cls.get("doublets", {}) if isinstance(doublets, dict): doub_count = doublets.get("count", 0) doub_ci = cp_interval(doub_count, n) if verbose: print( f" Doublets: {doub_count:4d} / {n:4d} = {doub_ci['proportion']:6.2%} " f"95% CI: [{doub_ci['lower']:.3f}, {doub_ci['upper']:.3f}]" ) # PAC bounds (ρ, κ, α'_bound) - important theoretical guarantees pac_bounds = cls.get("pac_bounds", {}) if isinstance(pac_bounds, dict) and pac_bounds.get("rho") is not None: if verbose: print(f"\n 📐 PAC Singleton Error Bound (δ={delta:.3f}):") print(f" ρ = {pac_bounds.get('rho', 0):.3f}, κ = {pac_bounds.get('kappa', 0):.3f}") if "alpha_singlet_bound" in pac_bounds and "alpha_singlet_observed" in pac_bounds: bound = float(pac_bounds["alpha_singlet_bound"]) observed = float(pac_bounds["alpha_singlet_observed"]) ok = "✓" if 
observed <= bound else "✗" print(f" α'_bound: {bound:.4f}") print(f" α'_observed: {observed:.4f} {ok}") # Operational bounds (RIGOROUS - cross-validated with PAC guarantees) if operational_bounds_per_class and class_label in operational_bounds_per_class: op_bounds = operational_bounds_per_class[class_label] if verbose: print("\n ✅ RIGOROUS Operational Bounds (LOO-CV)") print(f" CI width: {op_bounds.ci_width:.1%}") print(f" Calibration size: n = {op_bounds.n_calibration}") # Show main rates (singleton, doublet, abstention) for rate_name in ["abstention", "singleton", "doublet"]: if rate_name in op_bounds.rate_bounds: bounds = op_bounds.rate_bounds[rate_name] if verbose: print(f"\n {rate_name.upper()}:") print(f" Bounds: [{bounds.lower_bound:.3f}, {bounds.upper_bound:.3f}]") print(f" Count: {bounds.n_successes}/{bounds.n_evaluations}") # Show conditional singleton rates (conditional on having a singleton) has_correct = "correct_in_singleton" in op_bounds.rate_bounds has_error = "error_in_singleton" in op_bounds.rate_bounds has_singleton = "singleton" in op_bounds.rate_bounds if verbose and (has_correct or has_error) and has_singleton: print("\n CONDITIONAL RATES (conditioned on singleton, with CP+PAC bounds):") singleton_bounds = op_bounds.rate_bounds["singleton"] n_singletons = singleton_bounds.n_successes # P(correct | singleton) with rigorous CP bounds if has_correct and n_singletons > 0: correct_bounds = op_bounds.rate_bounds["correct_in_singleton"] n_correct = correct_bounds.n_successes # Conditional rate and CP interval rate = n_correct / n_singletons if n_singletons > 0 else 0.0 ci = cp_interval(n_correct, n_singletons) print(f" P(correct | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]") # P(error | singleton) with rigorous CP bounds if has_error and n_singletons > 0: error_bounds = op_bounds.rate_bounds["error_in_singleton"] n_error = error_bounds.n_successes # Conditional rate and CP interval rate = n_error / n_singletons if 
n_singletons > 0 else 0.0 ci = cp_interval(n_error, n_singletons) print(f" P(error | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]") # Store in summary summary[class_label] = { "n": n, "alpha_target": alpha_target, "alpha_corrected": alpha_corrected, "threshold": threshold, "calibration_stats": { "abstentions": abstentions, "singletons": singletons, "doublets": doublets, }, "pac_bounds": pac_bounds, } if operational_bounds_per_class and class_label in operational_bounds_per_class: summary[class_label]["operational_bounds"] = operational_bounds_per_class[class_label] # ==================== MARGINAL STATISTICS ==================== if marginal_operational_bounds is not None: if verbose: print(f"\n{'=' * 80}") print("MARGINAL STATISTICS (Deployment View - Ignores True Labels)") print(f"{'=' * 80}") print(f" Total samples: n = {marginal_operational_bounds.n_calibration}") print("\n ✅ RIGOROUS Marginal Bounds (LOO-CV)") print(f" CI width: {marginal_operational_bounds.ci_width:.1%}") print(f" Total evaluations: n = {marginal_operational_bounds.n_calibration}") # Show main rates for rate_name in ["abstention", "singleton", "doublet"]: if rate_name in marginal_operational_bounds.rate_bounds: bounds = marginal_operational_bounds.rate_bounds[rate_name] if verbose: print(f"\n {rate_name.upper()}:") print(f" Bounds: [{bounds.lower_bound:.3f}, {bounds.upper_bound:.3f}]") print(f" Count: {bounds.n_successes}/{bounds.n_evaluations}") # Show conditional singleton rates (marginal) has_correct = "correct_in_singleton" in marginal_operational_bounds.rate_bounds has_error = "error_in_singleton" in marginal_operational_bounds.rate_bounds has_singleton = "singleton" in marginal_operational_bounds.rate_bounds if verbose and (has_correct or has_error) and has_singleton: print("\n CONDITIONAL RATES (conditioned on singleton, with CP+PAC bounds):") singleton_bounds = marginal_operational_bounds.rate_bounds["singleton"] n_singletons = singleton_bounds.n_successes if 
has_correct and n_singletons > 0: correct_bounds = marginal_operational_bounds.rate_bounds["correct_in_singleton"] n_correct = correct_bounds.n_successes # Conditional rate and CP interval rate = n_correct / n_singletons if n_singletons > 0 else 0.0 ci = cp_interval(n_correct, n_singletons) print(f" P(correct | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]") if has_error and n_singletons > 0: error_bounds = marginal_operational_bounds.rate_bounds["error_in_singleton"] n_error = error_bounds.n_successes # Conditional rate and CP interval rate = n_error / n_singletons if n_singletons > 0 else 0.0 ci = cp_interval(n_error, n_singletons) print(f" P(error | singleton) = {rate:.3f} 95% CI: [{ci['lower']:.3f}, {ci['upper']:.3f}]") summary["marginal_bounds"] = marginal_operational_bounds if verbose: # Deployment interpretation sing_bounds = marginal_operational_bounds.rate_bounds.get("singleton") doub_bounds = marginal_operational_bounds.rate_bounds.get("doublet") abst_bounds = marginal_operational_bounds.rate_bounds.get("abstention") if sing_bounds: print("\n 📈 Deployment Expectations:") print( f" Automation (singletons): " f"{sing_bounds.lower_bound:.1%} - {sing_bounds.upper_bound:.1%} of cases" ) # Escalation = doublets + abstentions if doub_bounds and abst_bounds: esc_lower = doub_bounds.lower_bound + abst_bounds.lower_bound esc_upper = doub_bounds.upper_bound + abst_bounds.upper_bound print(f" Escalation (doublets+abstentions): {esc_lower:.1%} - {esc_upper:.1%} of cases") elif doub_bounds: print( f" Escalation (doublets): " f"{doub_bounds.lower_bound:.1%} - {doub_bounds.upper_bound:.1%} of cases" ) # ==================== WARNINGS ==================== if verbose: print(f"\n{'=' * 80}") print("NOTES") print(f"{'=' * 80}") print("\n✓ Per-class CIs are valid (Clopper-Pearson, exchangeable within class)") if operational_bounds_per_class or marginal_operational_bounds: print("✓ Operational bounds have PAC guarantees via cross-validation") else: 
print("\n⚠️ For rigorous deployment bounds, use:") print(" - generate_rigorous_pac_report() which provides all bounds") print( " - Access via report['pac_bounds_class_0']," " report['pac_bounds_class_1'], report['pac_bounds_marginal']" ) if prediction_stats.get("marginal") and marginal_operational_bounds is None: print("\n⚠️ Marginal stats from calibration data NOT shown (invalid CIs for Mondrian)") print( " Use generate_rigorous_pac_report() and access" " report['pac_bounds_marginal'] for valid marginal bounds" ) return summary
def plot_parallel_coordinates_plotly(
    df,
    columns: list[str] | None = None,
    color: str = "err_all",
    color_continuous_scale=None,
    title: str = "Mondrian sweep – interactive parallel coordinates",
    height: int = 600,
    base_opacity: float = 0.9,
    unselected_opacity: float = 0.06,
):
    """Build an interactive parallel-coordinates figure from sweep results.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with hyperparameter sweep results
    columns : list of str, optional
        Columns to use as axes. When omitted, the following names are used,
        restricted to those actually present in ``df``:
        ['a0','d0','a1','d1','cov','sing_rate','err_all','err_pred0',
        'err_pred1','err_y0','err_y1','esc_rate']
    color : str, default='err_all'
        Column used to color the lines; ignored when it is not a column of ``df``
    color_continuous_scale : plotly colorscale, optional
        Color scale for the lines; falls back to ``px.colors.sequential.Blugrn``
    title : str, default="Mondrian sweep – interactive parallel coordinates"
        Plot title
    height : int, default=600
        Plot height in pixels
    base_opacity : float, default=0.9
        Intended opacity of selected lines.
        NOTE(review): this argument is accepted but never referenced by the
        function body, so it currently has no effect.
    unselected_opacity : float, default=0.06
        Alpha channel of the color assigned to unselected lines

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive plotly figure

    Examples
    --------
    >>> import pandas as pd
    >>> df = sweep_hyperparams_and_collect(...)
    >>> fig = plot_parallel_coordinates_plotly(df, color='err_all')
    >>> fig.show()  # In notebook
    >>> # Or save: fig.write_html("sweep_results.html")
    """
    import plotly.express as px

    if columns is None:
        candidates = (
            "a0",
            "d0",
            "a1",
            "d1",
            "cov",
            "sing_rate",
            "err_all",
            "err_pred0",
            "err_pred1",
            "err_y0",
            "err_y1",
            "esc_rate",
        )
        columns = [name for name in candidates if name in df.columns]

    color_in_frame = color in df.columns
    axis_labels = dict(zip(columns, columns))

    fig = px.parallel_coordinates(
        df,
        dimensions=columns,
        color=color if color_in_frame else None,
        color_continuous_scale=color_continuous_scale or px.colors.sequential.Blugrn,
        labels=axis_labels,
    )

    # Fade unselected lines so the brushed selection stands out.
    if fig.data:
        faded = f"rgba(1,1,1,{float(unselected_opacity)})"
        fig.data[0].unselected.update(line={"color": faded})

    fig.update_layout(
        title=title,
        height=height,
        margin={"l": 40, "r": 40, "t": 60, "b": 40},
        plot_bgcolor="white",
        paper_bgcolor="white",
        font={"size": 14},
        uirevision=True,  # keep user brushing across updates
    )

    # Larger fonts for axis labels, ranges and ticks.
    fig.update_traces(
        labelfont={"size": 14},
        rangefont={"size": 12},
        tickfont={"size": 12},
    )

    # Title the colorbar only when lines are actually colored by a column.
    if not (color_in_frame and fig.data):
        return fig
    first_trace = fig.data[0]
    if getattr(first_trace, "line", None) and getattr(first_trace.line, "colorbar", None) is not None:
        first_trace.line.colorbar.title = color

    return fig