"""Hyperparameter sweep and optimization for Mondrian conformal prediction."""
import itertools
from collections.abc import Callable
from typing import Any, Literal
import numpy as np
import pandas as pd
from ssbc.calibration import mondrian_conformal_calibrate
from ssbc.reporting import plot_parallel_coordinates_plotly, report_prediction_stats
[docs]
def sweep_hyperparams_and_collect(
class_data: dict[int, dict[str, Any]],
alpha_0: np.ndarray,
delta_0: np.ndarray,
alpha_1: np.ndarray,
delta_1: np.ndarray,
mode: Literal["beta", "beta-binomial"] = "beta",
extra_metrics: dict[str, Callable] | None = None,
quiet: bool = True,
) -> pd.DataFrame:
"""Sweep (a0,d0,a1,d1), run mondrian_conformal_calibrate + report_prediction_stats,
and return a tidy DataFrame with hyperparams + selected metrics.
This function performs a grid search over hyperparameter combinations and
evaluates the resulting conformal prediction performance.
Parameters
----------
class_data : dict
Output from split_by_class()
alpha_0 : array-like
Grid of alpha values for class 0
delta_0 : array-like
Grid of delta values for class 0
alpha_1 : array-like
Grid of alpha values for class 1
delta_1 : array-like
Grid of delta values for class 1
mode : str, default="beta"
"beta" or "beta-binomial" mode for SSBC
extra_metrics : dict of {name: function}, optional
Additional metrics to compute. Each function takes the summary dict
and returns a scalar value.
quiet : bool, default=True
If True, suppress progress output
Returns
-------
pd.DataFrame
Tidy dataframe with one row per hyperparameter combination.
Columns include:
- a0, d0, a1, d1: hyperparameters
- cov: overall coverage rate
- sing_rate: singleton prediction rate
- err_all: overall singleton error rate
- err_pred0, err_pred1: errors by predicted class
- err_y0, err_y1: errors by true class
- esc_rate: escalation rate (doublets + abstentions)
- n_total, sing_count, m_abst, m_doublets: counts
- Any additional metrics from extra_metrics
Examples
--------
>>> import numpy as np
>>> from ssbc import BinaryClassifierSimulator, split_by_class
>>>
>>> # Generate data
>>> sim = BinaryClassifierSimulator(0.1, (2, 8), (8, 2), seed=42)
>>> labels, probs = sim.generate(1000)
>>> class_data = split_by_class(labels, probs)
>>>
>>> # Define grid
>>> alpha_grid = np.arange(0.05, 0.20, 0.05)
>>> delta_grid = np.arange(0.05, 0.20, 0.05)
>>>
>>> # Run sweep
>>> df = sweep_hyperparams_and_collect(
... class_data,
... alpha_0=alpha_grid, delta_0=delta_grid,
... alpha_1=alpha_grid, delta_1=delta_grid,
... )
>>>
>>> # Analyze results
>>> print(df[['a0', 'a1', 'cov', 'sing_rate', 'err_all']].head())
Notes
-----
The function performs a complete grid search, so the total number of
evaluations is len(alpha_0) × len(delta_0) × len(alpha_1) × len(delta_1).
For large grids, this can be computationally expensive.
"""
rows = []
combos = list(itertools.product(alpha_0, delta_0, alpha_1, delta_1))
for a0, d0, a1, d1 in combos:
if not quiet:
print(f"a0={a0:.3f}, d0={d0:.3f}, a1={a1:.3f}, d1={d1:.3f}")
cal_result, pred_stats = mondrian_conformal_calibrate(
class_data=class_data,
alpha_target={0: float(a0), 1: float(a1)},
delta={0: float(d0), 1: float(d1)},
mode=mode,
)
summary = report_prediction_stats(pred_stats, cal_result, verbose=False)
# Robust getter
def g(d, *keys, default=None):
"""Navigate nested dict safely."""
cur = d
for k in keys:
if not isinstance(cur, dict) or k not in cur:
return default
cur = cur[k]
return cur
n_total = int(g(summary, "marginal", "n_total", default=0) or 0)
cov = float(g(summary, "marginal", "coverage", "rate", default=0.0) or 0.0)
sing_rate = float(g(summary, "marginal", "singletons", "rate", default=0.0) or 0.0)
sing_cnt = int(g(summary, "marginal", "singletons", "count", default=0) or 0)
abst_cnt = int(g(summary, "marginal", "abstentions", "count", default=0) or 0)
doub_cnt = int(g(summary, "marginal", "doublets", "count", default=0) or 0)
esc_rate = (abst_cnt + doub_cnt) / float(n_total if n_total else 1)
err_all = float(g(summary, "marginal", "singletons", "errors", "rate", default=0.0) or 0.0)
err_p0 = float(g(summary, "marginal", "singletons", "errors_by_pred", "pred_0", "rate", default=0.0) or 0.0)
err_p1 = float(g(summary, "marginal", "singletons", "errors_by_pred", "pred_1", "rate", default=0.0) or 0.0)
err_y0 = float(g(summary, 0, "singletons", "error_given_singleton", "rate", default=0.0) or 0.0)
err_y1 = float(g(summary, 1, "singletons", "error_given_singleton", "rate", default=0.0) or 0.0)
row = {
"a0": float(a0),
"d0": float(d0),
"a1": float(a1),
"d1": float(d1),
"cov": cov,
"sing_rate": sing_rate,
"err_all": err_all,
"err_pred0": err_p0,
"err_pred1": err_p1,
"err_y0": err_y0,
"err_y1": err_y1,
"esc_rate": esc_rate,
"n_total": int(n_total),
"sing_count": int(sing_cnt),
"m_abst": abst_cnt,
"m_doublets": doub_cnt,
}
if extra_metrics:
for name, fn in extra_metrics.items():
try:
row[name] = fn(summary)
except Exception:
row[name] = np.nan
rows.append(row)
df = pd.DataFrame(rows)
return df.sort_values(["a0", "d0", "a1", "d1"], kind="mergesort").reset_index(drop=True)
[docs]
def sweep_and_plot_parallel_plotly(
class_data: dict[int, dict[str, Any]],
delta_0: np.ndarray,
delta_1: np.ndarray,
alpha_0: np.ndarray,
alpha_1: np.ndarray,
mode: Literal["beta", "beta-binomial"] = "beta",
extra_metrics: dict[str, Callable] | None = None,
color: str = "err_all",
color_continuous_scale=None,
title: str | None = None,
height: int = 600,
):
"""Convenience wrapper: run sweep + show plotly parallel coordinates figure.
This function combines hyperparameter sweep and visualization in one call.
Parameters
----------
class_data : dict
Output from split_by_class()
delta_0, delta_1 : array-like
Grid of delta values for classes 0 and 1
alpha_0, alpha_1 : array-like
Grid of alpha values for classes 0 and 1
mode : str, default="beta"
"beta" or "beta-binomial" mode for SSBC
extra_metrics : dict of {name: function}, optional
Additional metrics to compute
color : str, default='err_all'
Column to use for coloring the parallel coordinates
color_continuous_scale : plotly colorscale, optional
Color scale for the plot
title : str, optional
Plot title (defaults to auto-generated title)
height : int, default=600
Plot height in pixels
Returns
-------
df : pd.DataFrame
Results dataframe
fig : plotly.graph_objects.Figure
Interactive parallel coordinates plot
Examples
--------
>>> import numpy as np
>>> from ssbc import BinaryClassifierSimulator, split_by_class
>>>
>>> # Generate data
>>> sim = BinaryClassifierSimulator(0.1, (2, 8), (8, 2), seed=42)
>>> labels, probs = sim.generate(1000)
>>> class_data = split_by_class(labels, probs)
>>>
>>> # Run sweep and plot
>>> df, fig = sweep_and_plot_parallel_plotly(
... class_data,
... delta_0=np.arange(0.05, 0.20, 0.05),
... delta_1=np.arange(0.05, 0.20, 0.05),
... alpha_0=np.arange(0.05, 0.20, 0.05),
... alpha_1=np.arange(0.05, 0.20, 0.05),
... color='err_all'
... )
>>> fig.show() # Display in notebook
>>> # Or save: fig.write_html("sweep_results.html")
Notes
-----
The parallel coordinates plot allows interactive exploration of the
hyperparameter space. You can brush (select) ranges on any axis to
filter configurations and see their impact on other metrics.
"""
df = sweep_hyperparams_and_collect(
class_data=class_data,
alpha_0=alpha_0,
delta_0=delta_0,
alpha_1=alpha_1,
delta_1=delta_1,
mode=mode,
extra_metrics=extra_metrics,
quiet=True,
)
if title is None:
title = f"Mondrian Hyperparameter Sweep (n={len(df)} configs)"
fig = plot_parallel_coordinates_plotly(
df, color=color, color_continuous_scale=color_continuous_scale, title=title, height=height
)
return df, fig