from __future__ import annotations
import os
import warnings
from typing import Union, Optional, List, Sequence, Tuple
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.ticker as mticker
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from matplotlib.lines import Line2D
from matplotlib.patches import FancyArrowPatch, Patch, Rectangle
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from scipy.stats import gaussian_kde
from shap import Explanation
from shap.plots._labels import labels
from robustipy.utils import get_selection_key, get_colormap_colors
# Module-wide matplotlib setting, applied as a side effect of importing this
# file: draw negative numbers with an ASCII hyphen-minus instead of the
# Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
def _legend_side_from_hist(ax, *, tau: float = 0.6) -> str:
"""
Decide whether 'upper left' or 'upper right' is safer, given the
rectangular patches in *ax* produced by seaborn.histplot.
Parameters
----------
ax : matplotlib.axes.Axes
Axes containing seaborn.histplot patches.
tau : float, default=0.6
Safety threshold in (0,1). A bar exceeding `tau * ylim_max` on one side
forces the legend to the opposite.
Returns
-------
str
Either 'upper left' or 'upper right'.
"""
bars = [p for p in ax.patches if isinstance(p, Rectangle) and p.get_height() > 0]
if not bars: # fall-back if no histogram rendered
return 'upper left'
# Split bars at the sample median
median_x = np.median([p.get_x() + p.get_width() / 2 for p in bars])
left_max = max((p.get_height() for p in bars if (p.get_x() + p.get_width() / 2) < median_x), default=0.0)
right_max = max((p.get_height() for p in bars if (p.get_x() + p.get_width() / 2) >= median_x), default=0.0)
f_max = max(left_max, right_max)
ylim_top = ax.get_ylim()[1]
# ‘Legend altitude’ test
left_hits = left_max > tau * ylim_top
right_hits = right_max > tau * ylim_top
if left_hits and (left_max >= right_max):
return 'upper right'
if right_hits and (right_max > left_max):
return 'upper left'
return 'upper left' # default / tie
def _right_align_axes_to(reference_ax: plt.Axes, *axes: plt.Axes) -> None:
"""
Right-align one or more axes to the plotting box of ``reference_ax``.
"""
ref_pos = reference_ax.get_position()
for ax in axes:
pos = ax.get_position()
ax.set_position([ref_pos.x1 - pos.width, pos.y0, pos.width, pos.height])
def _set_axes_horizontal_span(x0: float, x1: float, *axes: plt.Axes) -> None:
"""
Set a shared horizontal span for one or more axes while preserving their
current vertical placement.
"""
if x1 <= x0:
raise ValueError("x1 must be greater than x0.")
width = x1 - x0
for ax in axes:
pos = ax.get_position()
ax.set_position([x0, pos.y0, width, pos.height])
def _blue_palette(num_colors: int = 1) -> List[str]:
"""
Return a consistent medium-blue tone for visual accents.
"""
if num_colors < 1:
raise ValueError("num_colors must be >= 1")
cmap = matplotlib.colormaps['Blues']
mid = 0.72
return [matplotlib.colors.to_hex(cmap(mid), keep_alpha=False)] * num_colors
def _spec_order_idx(results_object, oddsratio: bool) -> np.ndarray:
"""
Return the spec ordering index used by the specification curve (sorted by median).
"""
if oddsratio and hasattr(results_object, "estimates_exp"):
medians = results_object.estimates_exp.quantile(q=0.5, axis=1)
else:
medians = results_object.estimates.quantile(q=0.5, axis=1)
return medians.sort_values().index.to_numpy()
[docs]
def plot_hexbin_r2(
    results_object,
    ax: plt.Axes,
    fig: plt.Figure,
    oddsratio: bool,
    colormap: Union[str, cm.Colormap],
    title: str = "",
    side: str = "left",
) -> None:
    """
    Hex-bin density plot of bootstrapped coefficient estimates versus in-sample
    :math:`R^2`, together with a marginal colour-bar of observation counts.

    Parameters
    ----------
    results_object : Any
        Must expose ``results_object.estimates`` (or ``estimates_exp`` when
        `oddsratio` is True), ``results_object.r2_values`` -- each supporting
        ``.stack()`` to obtain 1-d views -- and ``results_object.model_name``.
    ax : matplotlib.axes.Axes
        Target axes.
    fig : matplotlib.figure.Figure
        Parent figure, needed for colour-bar geometry.
    oddsratio : bool
        If True, use exponentiated estimates for plotting.
    colormap : str | matplotlib.colors.Colormap
        Matplotlib-compatible colormap.
    title : str, optional
        Axes title.
    side : {'left', 'right'}, optional
        * ``'left'``  - conventional layout: y-axis on the left, colour-bar on
          the right.
        * ``'right'`` - mirror layout: y-axis (ticks, label, spine) on the
          right, colour-bar on the left; the left spine is removed.

    Returns
    -------
    None
        Draws in place on `ax`.

    Raises
    ------
    ValueError
        If `side` is not 'left' or 'right'. The check runs before anything is
        drawn, so an invalid value leaves `ax` and `fig` untouched.

    Notes
    -----
    Only the presentation layer is mirrored; the data are not transformed.
    """
    # Validate first: checking `side` after drawing would leave a half
    # configured figure (with a mirrored colour-bar) behind on bad input.
    if side not in ("left", "right"):
        raise ValueError("`side` must be 'left' or 'right'.")

    # ------------------------------------------------------------------ #
    # 1. Hex-bin and colour-bar                                          #
    # ------------------------------------------------------------------ #
    if oddsratio is True:
        estimates = results_object.estimates_exp
    else:
        estimates = results_object.estimates
    image = ax.hexbin(
        estimates.stack(),
        results_object.r2_values.stack(),
        cmap=colormap,
        gridsize=20,
        mincnt=1,
        edgecolor="k",
    )

    # Place the colour-bar opposite the y-axis
    cb_location = "right" if side == "left" else "left"
    cb = fig.colorbar(
        image,
        ax=ax,
        spacing="uniform",
        pad=0.05,
        extend="max",
        location=cb_location,  # Matplotlib >= 3.3
    )

    # ------------------------------------------------------------------ #
    # 2. Colour-bar tick formatting                                      #
    # ------------------------------------------------------------------ #
    data = image.get_array()
    ticks = np.linspace(data.min(), data.max(), num=6)
    cb.set_ticks(ticks)
    peak = data.max()
    if peak >= 10_000:
        tick_text = [f"{t / 1_000:.0f}k" for t in ticks]
    elif peak >= 1_000:
        tick_text = [f"{t / 1_000:.1f}k" for t in ticks]
    else:
        tick_text = [f"{t:.0f}" for t in ticks]
    cb.set_ticklabels(tick_text)
    cb.ax.set_title('Count')

    # ------------------------------------------------------------------ #
    # 2a. Trimming & baseline alignment (only when the bar is on the left)#
    # ------------------------------------------------------------------ #
    if cb_location == "left":  # i.e. side == "right"
        frac = 0.1  # fraction to trim (from the top)
        ax_box = ax.get_position()  # main axes box (for baseline)
        cb_box = cb.ax.get_position()
        cb.ax.set_position([
            cb_box.x0,                      # keep x-position
            ax_box.y0,                      # align bottom with axes baseline
            cb_box.width,
            cb_box.height * (1 - frac),     # shorten from the top only
        ])

    # ------------------------------------------------------------------ #
    # 3. Axes labels, title, tick locators                               #
    # ------------------------------------------------------------------ #
    if results_object.model_name == 'Logistic Regression Robust':
        r2_label = r"R$_{\mathrm{McF}}^2$"
    else:
        r2_label = r"In-Sample $\bar{\mathrm{R}}^2$"
    axis_formatter(ax, r2_label, r"Bootstrapped Estimand", title, side)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(4))
    ax.yaxis.set_major_locator(mticker.MaxNLocator(4))

    # ------------------------------------------------------------------ #
    # 4. Side-dependent spines, ticks, labels                            #
    # ------------------------------------------------------------------ #
    if side == "right":
        # shift ticks and label to the right
        ax.yaxis.set_ticks_position("right")
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")
        # ensure both marks *and* labels appear on the right
        ax.tick_params(axis="y",
                       which="both",
                       right=True, labelright=True,
                       left=False, labelleft=False)
        # keep right spine, hide left spine
        ax.spines["right"].set_visible(True)
        ax.spines["left"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.spines["bottom"].set_visible(True)
    else:  # side == "left" (validated above)
        ax.yaxis.set_ticks_position("left")
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position("left")
        ax.tick_params(axis="y",
                       which="both",
                       right=False, labelright=False,
                       left=True, labelleft=True)
        ax.spines["left"].set_visible(True)
        ax.spines["right"].set_visible(False)
        sns.despine(ax=ax, left=False, right=True)
    # ------------------------------------------------------------------ #
    # 5. Void function - all graphics modified in-place                  #
    # ------------------------------------------------------------------ #
[docs]
def plot_hexbin_log(
        results_object,
        ax: plt.Axes,
        fig: plt.Figure,
        oddsratio: bool,
        colormap: Union[str, cm.Colormap],
        title: str = ''
) -> None:
    """
    Draw a hex-bin density of full-sample coefficient estimates against a
    log-likelihood metric, with a colour-bar of bin counts.

    Parameters
    ----------
    results_object : object
        Must expose:
        - `all_b` (non-odds-ratio mode; the ``[0][0]`` element of each entry
          is plotted) or `all_b_exp` (odds-ratio mode; passed through as is)
        - `summary_df` with column ``'ll_gain_per_obs'`` (preferred when
          present) or ``'ll'``
    ax : matplotlib.axes.Axes
        The axes on which to draw the hex-bin.
    fig : matplotlib.figure.Figure
        Parent figure (needed to place the colorbar).
    oddsratio : bool
        If True, use exponentiated estimates for plotting.
    colormap : str or Colormap
        Name or object of a Matplotlib colormap.
    title : str, optional
        Title displayed above the plot (default: '').

    Returns
    -------
    None
    """
    # Pick the likelihood metric and the matching axis label.
    has_gain = 'll_gain_per_obs' in results_object.summary_df.columns
    if has_gain:
        metric_column, y_label = 'll_gain_per_obs', r'Null-Relative Log-Likelihood Gain'
    else:
        metric_column, y_label = 'll', r'Full Model Log Likelihood'
    ll_values = results_object.summary_df[metric_column]

    # Pick the x-values, then issue a single hexbin call.
    if oddsratio is True:
        x_values = results_object.all_b_exp
    else:
        x_values = [arr[0][0] for arr in results_object.all_b.copy()]
    image = ax.hexbin(
        x_values,
        ll_values,
        cmap=colormap,
        gridsize=20,
        mincnt=1,
        edgecolor='k',
    )

    cb = fig.colorbar(image, ax=ax, spacing='uniform', extend='max', pad=0.05)
    data = image.get_array()
    ticks = np.linspace(data.min(), data.max(), num=6)
    cb.set_ticks(ticks)

    # Counts from 1000 upwards are abbreviated in thousands.
    peak = data.max()
    if peak >= 10000:
        tick_text = [f'{tick / 1000:.0f}k' for tick in ticks]
    elif peak >= 1000:
        tick_text = [f'{tick / 1000:.1f}k' for tick in ticks]
    else:
        tick_text = [f'{tick:.0f}' for tick in ticks]
    cb.set_ticklabels(tick_text)
    cb.ax.set_title('Count')

    axis_formatter(ax, y_label, r'Full-Sample Estimand', title)
    for axis in (ax.yaxis, ax.xaxis):
        axis.set_major_locator(mticker.MaxNLocator(4))
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.3g'))
    sns.despine(ax=ax)
[docs]
def shap_violin(
        ax: plt.Axes,
        shap_values: Union[np.ndarray, List[np.ndarray], Explanation],
        features: Optional[Union[np.ndarray, pd.DataFrame, List[str]]] = None,
        feature_names: Optional[List[str]] = None,
        max_display: int = 10,
        color: Optional[Union[str, Sequence]] = None,
        alpha: float = 1.0,
        cmap: str = 'viridis',
        use_log_scale: bool = False,
        title: str = '',
        clear_yticklabels: bool = False,
        cbar_ax: Optional[plt.Axes] = None,
        cbar_width: float = 0.04,
        cbar_width_fig: Optional[float] = None
) -> List[str]:
    """
    Create a SHAP violin summary plot, colored by feature values when they
    are provided.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes on which to draw the plot.
    shap_values : array-like or Explanation
        SHAP value matrix (#samples x #features) or a SHAP Explanation object.
        Multiclass input (a list of per-class matrices, or an Explanation that
        unpacks to one) is not supported and raises ``ValueError``.
    features : array-like, DataFrame, or list of str, optional
        Feature value matrix (#samples x #features), or just a feature_names
        list. Default: None (plain violins, no per-point coloring).
    feature_names : list of str, optional
        Names of each feature. Default: None (will infer or auto-label).
    max_display : int, default=10
        Maximum number of top features (by summed absolute SHAP value) to
        show. ``None`` is treated as 20.
    color : str or sequence, optional
        Face color of the violin bodies when no feature values are given.
    alpha : float, default=1.0
        Opacity for scatter points / violin bodies.
    cmap : str, default='viridis'
        Colormap for coloring points.
    use_log_scale : bool, default=False
        If True, use symlog x-axis scaling.
    title : str, optional
        Title text for the axes.
    clear_yticklabels : bool, default=False
        If True, hide the y-tick labels.
    cbar_ax : matplotlib.axes.Axes, optional
        If provided, draw the colorbar inside this axes instead of
        creating an inset colorbar on the right.
    cbar_width : float, default=0.04
        Colorbar width as a fraction of the axes width (or the cbar axis width
        when `cbar_ax` is provided).
    cbar_width_fig : float, optional
        Absolute colorbar width as a fraction of the figure width. If provided,
        this overrides `cbar_width` when `cbar_ax` is given, ensuring consistent
        absolute thickness across panels.

    Returns
    -------
    List[str]
        Names of the plotted features, in the order of their y-positions
        (position 0 first).

    Raises
    ------
    ValueError
        If multiclass SHAP values are supplied, or if the column counts of
        `shap_values` and `features` do not match.
    """
    # If SHAP Explanation object is passed, extract values and related info
    if str(type(shap_values)).endswith("Explanation'>"):
        shap_exp = shap_values
        shap_values = shap_exp.values
        if features is None:
            features = shap_exp.data
        if feature_names is None:
            feature_names = shap_exp.feature_names
        if len(shap_exp.base_values.shape) == 2 and shap_exp.base_values.shape[1] > 2:
            shap_values = [shap_values[:, :, i] for i in range(shap_exp.base_values.shape[1])]

    # Everything below indexes a single 2-d matrix; a list would otherwise
    # fail later with an opaque "'list' object has no attribute 'shape'".
    if isinstance(shap_values, list):
        raise ValueError(
            "shap_violin does not support multiclass SHAP values (a list of "
            "per-class matrices); pass the SHAP matrix of a single class."
        )

    if isinstance(features, pd.DataFrame):
        if feature_names is None:
            feature_names = features.columns
        features = features.values
    elif isinstance(features, list):
        if feature_names is None:
            feature_names = features
        features = None
    elif (features is not None) and len(features.shape) == 1 and feature_names is None:
        feature_names = features
        features = None

    num_features = shap_values.shape[1]

    # Drop SHAP bias column if present
    if features is not None:
        if shap_values.shape[1] == features.shape[1] + 1:
            shap_values = shap_values[:, :-1]  # drop bias
            num_features -= 1
        elif shap_values.shape[1] != features.shape[1]:
            raise ValueError(
                f"'shap_values' has {shap_values.shape[1]} columns but "
                f"'features' has {features.shape[1]} – shapes don’t match."
            )

    # Set default feature names if still missing
    if feature_names is None:
        feature_names = np.array([labels["FEATURE"] % str(i) for i in range(num_features)])

    # Set x-axis to symmetric log scale if requested.
    # (Axes exposes set_xscale; `xscale` exists only on pyplot.)
    if use_log_scale:
        ax.set_xscale("symlog")

    if max_display is None:
        max_display = 20

    # Order features by summed absolute SHAP value and keep the top ones
    feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0))
    feature_order = feature_order[-min(max_display, len(feature_order)):]

    # Add horizontal lines for each feature row in the plot
    for pos in range(len(feature_order)):
        ax.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)

    # Main path: coloring by feature values is possible
    if features is not None:
        # Compute global SHAP value range for noise scaling
        global_low = np.nanpercentile(shap_values[:, : len(feature_names)].flatten(), 1)
        global_high = np.nanpercentile(shap_values[:, : len(feature_names)].flatten(), 99)

        for pos, i in enumerate(feature_order):
            shaps = shap_values[:, i]
            shap_min, shap_max = np.min(shaps), np.max(shaps)
            rng = shap_max - shap_min
            xs = np.linspace(shap_min - rng * 0.2, shap_max + rng * 0.2, 100)

            # Estimate density; add noise if nearly constant
            if np.std(shaps) < (global_high - global_low) / 100:
                ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs)
            else:
                ds = gaussian_kde(shaps)(xs)
            ds /= np.max(ds) * 3  # Normalize for plot size

            values = features[:, i]

            # Smooth feature values for color gradients: a trailing window of
            # at most 20 SHAP-sorted samples is averaged for each x segment;
            # leading segments without samples are back-filled.
            smooth_values = np.zeros(len(xs) - 1)
            sort_inds = np.argsort(shaps)
            trailing_pos, leading_pos, running_sum, back_fill = 0, 0, 0, 0
            for j in range(len(xs) - 1):
                while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]:
                    running_sum += values[sort_inds[leading_pos]]
                    leading_pos += 1
                    if leading_pos - trailing_pos > 20:
                        running_sum -= values[sort_inds[trailing_pos]]
                        trailing_pos += 1
                if leading_pos - trailing_pos > 0:
                    smooth_values[j] = running_sum / (leading_pos - trailing_pos)
                    for k in range(back_fill):
                        smooth_values[j - k - 1] = smooth_values[j]
                else:
                    back_fill += 1

            # Clip and normalize feature values for coloring
            vmin = np.nanpercentile(values, 5)
            vmax = np.nanpercentile(values, 95)
            if vmin == vmax:
                vmin = np.nanpercentile(values, 1)
                vmax = np.nanpercentile(values, 99)
                if vmin == vmax:
                    vmin = np.min(values)
                    vmax = np.max(values)

            nan_mask = np.isnan(values)
            # Plot SHAP values where feature values are NaN in gray
            ax.scatter(
                shaps[nan_mask],
                np.ones(shap_values[nan_mask].shape[0]) * pos,
                color="#777777",
                s=9,
                alpha=alpha,
                linewidth=0,
                zorder=1,
                rasterized=True
            )

            # Clip and prepare color values
            cvals = values[np.invert(nan_mask)].astype(np.float64)
            cvals_imp = cvals.copy()
            cvals_imp[np.isnan(cvals)] = (vmin + vmax) / 2.0
            cvals[cvals_imp > vmax] = vmax
            cvals[cvals_imp < vmin] = vmin

            # Plot SHAP points colored by feature values
            ax.scatter(
                shaps[np.invert(nan_mask)],
                np.ones(shap_values[np.invert(nan_mask)].shape[0]) * pos,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                s=9,
                c=cvals,
                alpha=alpha,
                linewidth=0,
                zorder=1,
                rasterized=True,
            )

            # Normalize and color density envelope
            smooth_values -= vmin
            if vmax - vmin > 0:
                smooth_values /= vmax - vmin
            # `seg` (not `i`) so the feature index of the outer loop is not
            # overwritten.
            for seg in range(len(xs) - 1):
                if ds[seg] > 0.05 or ds[seg + 1] > 0.05:
                    ax.fill_between(
                        [xs[seg], xs[seg + 1]],
                        [pos + ds[seg], pos + ds[seg + 1]],
                        [pos - ds[seg], pos - ds[seg + 1]],
                        color=plt.get_cmap(cmap)(smooth_values[seg]),
                        zorder=2,
                    )
    else:
        # If no feature values, just plot plain violin plots.
        # Axes.violinplot accepts no artist keywords, so rasterisation is
        # applied to the returned bodies instead of being passed to the call.
        parts = ax.violinplot(
            shap_values[:, feature_order],
            range(len(feature_order)),
            points=200,
            vert=False,
            widths=0.7,
            showmeans=False,
            showextrema=False,
            showmedians=False,
        )
        for pc in parts["bodies"]:
            pc.set_rasterized(True)
            pc.set_facecolor(color)
            pc.set_edgecolor("none")
            pc.set_alpha(alpha)

    # Low/High colorbar. NOTE: it is drawn on both paths, i.e. also when no
    # feature values were supplied.
    m = cm.ScalarMappable(cmap=cmap)
    m.set_array([0, 1])

    if cbar_ax is None:
        cax = inset_axes(
            ax,
            width=f"{cbar_width * 100:.1f}%",
            height="100%",
            loc="lower left",
            bbox_to_anchor=(1.02, 0.0, 1, 1),
            bbox_transform=ax.transAxes,
            borderpad=0,
        )
    else:
        cbar_ax.set_axis_off()
        cbar_pos = cbar_ax.get_position()
        # Express the requested width as a fraction of the host cbar axes,
        # capped at 100 %.
        if cbar_width_fig is not None and cbar_pos.width > 0:
            width_frac = min(1.0, cbar_width_fig / cbar_pos.width)
        else:
            ax_pos = ax.get_position()
            if cbar_pos.width > 0:
                width_frac = min(1.0, (cbar_width * ax_pos.width) / cbar_pos.width)
            else:
                width_frac = cbar_width
        cax = inset_axes(
            cbar_ax,
            width=f"{width_frac * 100:.1f}%",
            height="100%",
            loc="center",
            bbox_to_anchor=(0.0, 0.0, 1.0, 1.0),
            bbox_transform=cbar_ax.transAxes,
            borderpad=0,
        )

    cb = plt.colorbar(m, cax=cax, ticks=[0, 1])
    cb.set_ticklabels(['Low', 'High'])
    if cbar_ax is None:
        cb.set_label('Feature Value', size=12, labelpad=-20)
    else:
        cb.ax.yaxis.set_ticks_position('right')
        cb.ax.yaxis.set_label_position('right')
        cb.set_label('Feature Value', size=12, labelpad=-8)
    cb.outline.set_edgecolor('k')
    cb.ax.tick_params(labelsize=11, length=0)
    cb.set_alpha(1)
    cb.outline.set_visible(True)

    # Configure axis appearance
    ax.xaxis.set_ticks_position("bottom")
    ax.spines["left"].set_visible(True)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Set y-axis ticks and labels
    feature_name_order = [feature_names[i] for i in feature_order]
    ax.set_yticks(range(len(feature_order)), feature_name_order, fontsize=13)
    ax.set_ylim(-1, len(feature_order))

    # Set axis labels and title
    axis_formatter(ax, r'', r'SHAP Values', title)
    title_setter(ax, title)

    # Optionally clear y-tick labels
    if clear_yticklabels:
        ax.set_yticklabels([])

    return feature_name_order
[docs]
def plot_curve(
        results_object,
        loess: bool = True,
        ci: float = 1,
        oddsratio: bool = False,
        specs: Optional[List[List[str]]] = None,
        ax: Optional[plt.Axes] = None,
        highlights: bool = True,
        inset: bool = True,
        title: str = '',
        colormap: Union[str, matplotlib.colors.Colormap] = 'viridis'
) -> plt.Axes:
    """
    Draw the specification curve: per-specification median estimate plus an
    interval band, with optional highlighted specifications.

    Parameters
    ----------
    results_object : object
        Must expose `.estimates` (or `.estimates_exp` in odds-ratio mode),
        `.specs_names`, `.y_name`, `.draws`, `.kfold` and an `.inference`
        mapping with keys ``'median'`` and ``'Stouffers'``.
    loess : bool, default=True
        Whether to smooth the lower/upper bounds with LOWESS.
    ci : float, default=1
        Interval level; the band spans the ``(1 - ci) / 2`` and
        ``1 - (1 - ci) / 2`` row quantiles (nearest interpolation).
    oddsratio : bool, default=False
        If True, plot `.estimates_exp` instead of `.estimates`.
    specs : list of control-lists, optional
        Specifications to highlight. Default: None (no highlights).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Default: current axes.
    highlights : bool, default=True
        If True, mark the last and the first entry of `.specs_names`
        (labelled 'Full Model' / 'No Controls', or 'All Data Used' /
        'First y Only' depending on `.y_name`).
    inset : bool, default=True
        If True, add a text box summarising the number of specifications,
        bootstraps, folds and the median.
    title : str, optional
        Title text for the axes.
    colormap : str or Colormap, default='viridis'
        Colormap from which the highlight colours are sampled.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the plot.
    """
    ax = plt.gca() if ax is None else ax

    # --- quantiles of the bootstrap draws, one row per specification -------
    tail_mass = 1 - ci
    draws = results_object.estimates_exp if oddsratio is True else results_object.estimates
    bounds = draws.quantile(q=[tail_mass / 2, 1 - tail_mass / 2],
                            axis=1,
                            interpolation='nearest')
    df = pd.DataFrame()
    df['median'] = draws.quantile(q=0.5, axis=1)
    df['q_low'] = bounds.iloc[0].values
    df['q_high'] = bounds.iloc[1].values

    # --- boolean flags for the rows that will be highlighted ---------------
    if highlights:
        names = results_object.specs_names
        full_spec = list(names.iloc[-1])
        null_spec = list(names.iloc[0])
        df['full_spec_idx'] = names.isin(get_selection_key([full_spec]))
        df['null_spec_idx'] = names.isin(get_selection_key([null_spec]))
    if specs:
        df['idx'] = results_object.specs_names.isin(get_selection_key(specs))
    df['specs_names'] = results_object.specs_names

    # Sort by median (shared ordering for curve + spec matrix)
    df = df.loc[_spec_order_idx(results_object, oddsratio)].reset_index(drop=True)
    n = len(df)

    # Colours: first for null, middle ones for user highlights, last for full
    n_hl = len(specs) if specs else 0
    colourset = get_colormap_colors(n_hl + 2, colormap)
    null_color, spec_colors, full_color = colourset[0], colourset[1:-1], colourset[-1]

    # --- median line and interval band -------------------------------------
    df['median'].plot(ax=ax, color='k', linestyle='-')

    band_edge = 'gray'
    if loess:
        frac = max(2 / n, 0.3)
        lo_low = sm.nonparametric.lowess(df['q_low'], df.index, frac=frac)
        lo_high = sm.nonparametric.lowess(df['q_high'], df.index, frac=frac)
        lower, upper = lo_low[:, 1], lo_high[:, 1]
        ax.plot(lo_low[:, 0], lower, color=band_edge, linestyle='--')
        ax.plot(lo_high[:, 0], upper, color=band_edge, linestyle='--')
        ax.fill_between(df.index, lower, upper, facecolor='#fee08b', alpha=0.15)
    else:
        lower, upper = df['q_low'].to_numpy(), df['q_high'].to_numpy()
        ax.plot(df.index, df['q_low'], color=band_edge, linestyle='--', label='Lower CI')
        ax.plot(df.index, df['q_high'], color=band_edge, linestyle='--', label='Upper CI')
        ax.fill_between(df.index, df['q_low'], df['q_high'], facecolor='#fee08b', alpha=0.15)

    # Zero line, only if zero lies strictly inside the current y-range.
    # These limits are captured *before* the highlights are drawn and are
    # reused for the final padding.
    y0, y1 = ax.get_ylim()
    if y0 < 0 < y1:
        ax.axhline(0, color='k', ls='--')

    def _mark_interval(position, colour):
        """Draw the vertical interval, double-headed arrow and median dot."""
        low, high = lower[position], upper[position]
        ax.vlines(position, ymin=low, ymax=high, color=colour)
        ax.add_artist(
            FancyArrowPatch((position, low), (position, high), arrowstyle='<|-|>', color=colour,
                            mutation_scale=20, shrinkA=0, shrinkB=0)
        )
        ax.plot(position, df.at[position, 'median'], 'o',
                markeredgecolor='k', markerfacecolor='w', markersize=12)

    def _legend_marker(colour, text):
        """Build the legend proxy artist for one highlighted specification."""
        return Line2D([0], [0], marker='o', color=colour, markerfacecolor='w', markersize=10, label=text)

    handles = []

    # --- user-specified highlights -----------------------------------------
    if specs:
        for j, idx in enumerate(df.index[df['idx']].tolist()):
            col = spec_colors[j]
            _mark_interval(idx, col)
            handles.append(_legend_marker(col, ', '.join(df['specs_names'].iloc[idx])))

    # --- full model first, then null model ---------------------------------
    if highlights:
        highlight_plan = (
            ('full_spec_idx', full_color, 'Full Model', 'All Data Used'),
            ('null_spec_idx', null_color, 'No Controls', r'First y Only'),
        )
        for flag_column, colour, single_label, multi_label in highlight_plan:
            position = df.index[df[flag_column]].item()
            _mark_interval(position, colour)
            # Label wording depends on the longest entry of `y_name`
            if max(len(t) for t in results_object.y_name) == 1:
                handles.append(_legend_marker(colour, single_label))
            else:
                handles.append(_legend_marker(colour, multi_label))

    # Display the legend if both highlights and specs are present
    if highlights and (specs is not None):
        ax.legend(handles=handles, frameon=True, edgecolor='black', fontsize=10,
                  loc='lower right', ncols=2, framealpha=1, facecolor='w')

    # --- axes formatting and limits ----------------------------------------
    axis_formatter(ax, r'Estimand of Interest', 'Ordered Specifications', title)
    ax.set_xlim(-0.5, n - 0.5)
    pad = (y1 - y0) * 0.1
    ax.set_ylim(y0 - pad, y1 + pad)

    # --- summary inset ------------------------------------------------------
    if inset:
        median_inf = results_object.inference['median']
        z_score = results_object.inference['Stouffers'][0]
        median_line = f'Median: {median_inf:.3f}'
        if not np.isnan(z_score):
            median_line = f'Median: {median_inf:.3f} (Z: {z_score:.3f})'
        info_text = (
            f'Specifications: {n}\n'
            f'Bootstraps: {results_object.draws}\n'
            f'Folds: {results_object.kfold}\n'
            f'{median_line}'
        )
        ax.text(0.05, 0.95, info_text, transform=ax.transAxes, va='top', ha='left',
                fontsize=9, color='black',
                bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=1'))

    sns.despine(ax=ax)
    return ax
[docs]
def plot_spec_matrix(
        results_object,
        ax: Optional[plt.Axes] = None,
        order_idx: Optional[Sequence[int]] = None,
        controls: Optional[Sequence[str]] = None,
        oddsratio: bool = False,
        title: str = '',
        bins: Optional[int] = None,
        heatmap_threshold: int = 128,
        colormap: Union[str, matplotlib.colors.Colormap] = 'viridis',
        cbar_ax: Optional[plt.Axes] = None,
        cbar_width: float = 0.04,
        cbar_width_fig: Optional[float] = None
) -> plt.Axes:
    """
    Show which controls enter each specification, either as a dot matrix or,
    for many specifications, as a binned inclusion-rate heatmap.

    Parameters
    ----------
    results_object : object
        Must expose `.specs_names` and (optionally) `.controls`.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Default: current axes.
    order_idx : sequence of int, optional
        Ordering of specifications along the x-axis, applied through
        ``specs_names.iloc``. If None, the specification-curve ordering
        (sorted by median) is used.
    controls : sequence of str, optional
        Controls to show on the y-axis. Falls back to
        ``results_object.controls`` and then to the sorted union of all
        controls found in the specifications.
    oddsratio : bool, default=False
        If True, use exponentiated estimates for the default ordering.
    title : str, optional
        Title text for the axes.
    bins : int, optional
        Number of bins for the heatmap (capped at the number of
        specifications). Each cell holds the share (0-1) of specifications in
        the bin that include the control.
    heatmap_threshold : int, default=128
        The heatmap is drawn only when `bins` is positive and the number of
        specifications is greater than this value; otherwise dots are drawn.
    colormap : str or Colormap, default='viridis'
        Colormap used for heatmap shading (dot matrix uses a fixed blue tone).
    cbar_ax : matplotlib.axes.Axes, optional
        If provided, draw the heatmap colorbar inside this axes instead of
        creating an inset colorbar on the right. In dot mode this axes is
        switched off.
    cbar_width : float, default=0.04
        Colorbar width as a fraction of the heatmap axes width (or the cbar
        axis width when `cbar_ax` is provided).
    cbar_width_fig : float, optional
        Absolute colorbar width as a fraction of the figure width. If
        provided, this overrides `cbar_width` when `cbar_ax` is given.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the plot.

    Raises
    ------
    TypeError
        If `colormap` is neither a string nor a Matplotlib Colormap.
    """
    ax = plt.gca() if ax is None else ax

    if order_idx is None:
        order_idx = _spec_order_idx(results_object, oddsratio)
    specs_ordered = results_object.specs_names.iloc[order_idx].tolist()

    # Resolve the y-axis entries: argument -> results_object.controls -> union
    if controls is None:
        controls = getattr(results_object, "controls", None)
    if controls is not None:
        controls = list(controls)
    else:
        controls = sorted({c for spec in specs_ordered for c in spec})

    # Nothing to draw: leave a placeholder message and return early
    if not controls or not specs_ordered:
        axis_formatter(ax, '', 'Ordered Specifications', title)
        ax.text(0.5, 0.5, 'No controls', transform=ax.transAxes,
                ha='center', va='center', fontsize=11)
        ax.set_yticks([])
        ax.set_xlim(0, max(len(specs_ordered) - 1, 0))
        sns.despine(ax=ax)
        return ax

    # Membership matrix: rows are controls, columns are ordered specifications
    mask = np.array(
        [[name in spec for spec in specs_ordered] for name in controls],
        dtype=bool,
    )

    if not isinstance(colormap, (str, matplotlib.colors.Colormap)):
        raise TypeError("colormap must be a string name or a Matplotlib Colormap.")
    cmap = matplotlib.colormaps[colormap] if isinstance(colormap, str) else colormap

    # Dot-matrix uses the consistent blue tone used in panels c/g
    dot_color = _blue_palette(1)[0]

    n_specs = len(specs_ordered)
    use_heatmap = bins is not None and bins > 0 and n_specs > heatmap_threshold

    if not use_heatmap:
        if cbar_ax is not None:
            cbar_ax.axis('off')
        rows, cols = np.where(mask)
        ax.scatter(
            cols,
            rows,
            s=28,
            marker='o',
            facecolors=dot_color,
            edgecolors='w',
            linewidths=0.8,
            alpha=0.9
        )
    else:
        bins = min(bins, n_specs)
        idx_bins = np.array_split(np.arange(n_specs), bins)

        # Inclusion rate of every control within every bin
        heat = np.zeros((len(controls), bins), dtype=float)
        for column, members in enumerate(idx_bins):
            if members.size:
                heat[:, column] = mask[:, members].mean(axis=1)

        # Cell edges sit half a unit outside the first/last index of each bin
        x_edges = np.array(
            [idx_bins[0][0] - 0.5] + [members[-1] + 0.5 for members in idx_bins],
            dtype=float,
        )
        y_edges = np.arange(len(controls) + 1) - 0.5
        im = ax.pcolormesh(
            x_edges,
            y_edges,
            heat,
            cmap=cmap,
            vmin=0,
            vmax=1,
            shading='flat'
        )

        if cbar_ax is None:
            cax = inset_axes(
                ax,
                width=f"{cbar_width * 100:.1f}%",
                height="100%",
                loc="lower left",
                bbox_to_anchor=(1.02, 0.0, 1, 1),
                bbox_transform=ax.transAxes,
                borderpad=0,
            )
        else:
            cbar_ax.set_axis_off()
            host_width = cbar_ax.get_position().width
            # Requested width as a fraction of the host axes, capped at 100 %
            if host_width > 0 and cbar_width_fig is not None:
                width_frac = min(1.0, cbar_width_fig / host_width)
            elif host_width > 0:
                width_frac = min(1.0, (cbar_width * ax.get_position().width) / host_width)
            else:
                width_frac = cbar_width
            cax = inset_axes(
                cbar_ax,
                width=f"{width_frac * 100:.1f}%",
                height="100%",
                loc="center",
                bbox_to_anchor=(0.0, 0.0, 1.0, 1.0),
                bbox_transform=cbar_ax.transAxes,
                borderpad=0,
            )

        cbar = ax.figure.colorbar(im, cax=cax)
        if cbar_ax is None:
            cbar.set_label('Inclusion Rate', fontsize=11, labelpad=8)
        else:
            cbar.ax.set_title('')
            cbar.ax.yaxis.set_ticks_position('right')
            cbar.ax.yaxis.set_label_position('right')
            cbar.set_label('Inclusion Rate', fontsize=11, labelpad=4)
        cbar.ax.tick_params(labelsize=10, pad=2)

    # Shared axis formatting for both rendering modes
    axis_formatter(ax, '', 'Ordered Specifications', title)
    n_controls = len(controls)
    ax.set_yticks(range(n_controls))
    ax.set_yticklabels(controls, fontsize=11)
    ax.set_ylim(-0.5, n_controls - 0.5)
    ax.invert_yaxis()
    ax.set_xlim(-0.5, n_specs - 0.5)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(5))
    ax.grid(False)
    ax.xaxis.grid(True, linestyle='--', color='k', alpha=0.1)
    ax.yaxis.grid(True, linestyle='--', color='k', alpha=0.06)
    return ax
[docs]
def plot_ic(
        results_object,
        ic: str,
        specs: Optional[List[List[str]]] = None,
        ax: Optional[plt.Axes] = None,
        colormap: str = 'viridis',
        title: str = '',
        despine_left: bool = True
) -> plt.Axes:
    """
    Plot the sorted information-criterion (IC) curve and mark key models.

    The curve is drawn in black over the specifications ordered by `ic`.
    Markers (a vertical line plus an open circle) are added for:
      - each user-highlighted spec, using the 2nd, 3rd, ... colormap colours;
      - the "Full Model" (last entry of `specs_names`), last colormap colour;
      - the "No Controls" model (first entry of `specs_names`), first colour.

    Parameters
    ----------
    results_object : object
        Must expose `summary_df` (with a `spec_name` column and the `ic`
        column) and `specs_names` (indexable with `.iloc`).
    ic : str
        Name of the `summary_df` column to plot.
    specs : list of list of str, optional
        Specifications to highlight. Default: None (only full/null models).
    ax : matplotlib.axes.Axes, optional
        Axes on which to draw; if None the current axes (`plt.gca()`) is used.
    colormap : str, default='viridis'
        Colormap handed to `get_colormap_colors` for the marker colours.
    title : str, default=''
        Title to display above the plot.
    despine_left : bool, default=True
        If True, move y-axis ticks & label to the right spine; otherwise keep on the left.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the completed plot.

    Raises
    ------
    ValueError
        If `ic` is not a column of `summary_df`; also raised by `.item()`
        when the full or null model does not match exactly one row.
    """
    summary = results_object.summary_df
    if ic not in summary.columns:
        available_ics = [c for c in summary.columns
                         if c.lower() in {'aic', 'bic', 'hqic'}]
        raise ValueError(f"[plot_ic] '{ic}' not found. Available: {available_ics}")
    if ax is None:
        ax = plt.gca()

    # Exactly len(specs) + 2 colours: null model, highlights, full model.
    n_specs = len(specs) if specs else 0
    colorset = get_colormap_colors(n_specs + 2, colormap)

    # Work on a sorted copy so the caller's summary_df is never modified.
    df = summary.copy().sort_values(by=ic).reset_index(drop=True)
    axis_formatter(ax, f'{ic.upper()} curve', 'Ordered Specifications', title)

    full_spec_key = get_selection_key([list(results_object.specs_names.iloc[-1])])
    null_spec_key = get_selection_key([list(results_object.specs_names.iloc[0])])
    df['full_spec_idx'] = df.spec_name.isin(full_spec_key)
    df['null_spec_idx'] = df.spec_name.isin(null_spec_key)
    if specs:
        df['idx'] = df.spec_name.isin(get_selection_key(specs))

    ax.plot(df[ic], color='k')
    # Freeze the y-limits so the vertical marker lines start at the axis floor
    # without triggering a rescale.
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(ymin, ymax)

    markers = []

    def _mark(pos, colour, label):
        # Draw one highlighted model and register its legend handle.
        ax.vlines(pos, ymin, df.at[pos, ic], color=colour, label=label)
        markers.append(Line2D([0], [0], marker='o', color=colour,
                              markerfacecolor='w', markeredgecolor='k',
                              markersize=10, label=label))
        ax.plot(pos, df.at[pos, ic], 'o',
                markeredgecolor='k', markerfacecolor='w', markersize=15)

    if specs:
        highlighted = df.index[df['idx']].tolist()
        # Colours are shifted by one: colorset[0] is reserved for "No Controls".
        for i, pos in enumerate(highlighted[:n_specs]):
            _mark(pos, colorset[i + 1], ', '.join(df.spec_name.iloc[pos]))
    _mark(df.index[df['full_spec_idx']].item(), colorset[-1], 'Full Model')
    _mark(df.index[df['null_spec_idx']].item(), colorset[0], 'No Controls')

    # Opaque white legend box in both the highlighted and the plain case.
    ax.legend(handles=markers,
              frameon=True, edgecolor='black',
              fontsize=9, loc="upper left",
              ncols=1, framealpha=1, facecolor='w')
    if despine_left:
        ax.yaxis.set_label_position("right")
        sns.despine(ax=ax, right=False, left=True)
    else:
        sns.despine(ax=ax)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(5))
    return ax
[docs]
def plot_bdist(
        results_object,
        oddsratio: bool,
        specs: Optional[List[List[str]]] = None,
        ax: Optional[plt.Axes] = None,
        title: str = '',
        despine_left: bool = True,
        legend_bool: bool = False,
        bw_adjust: float = 0.5,
        highlights: bool = True,
        colormap: Union[str, matplotlib.colors.Colormap] = 'viridis'
) -> plt.Axes:
    """
    Plot filled KDEs of the bootstrapped coefficient distributions.

    One KDE is drawn per selected specification (each normalised on its own).
    When nothing is selected, a single KDE of all draws pooled together is
    drawn instead.

    Parameters
    ----------
    results_object : object
        Must expose `estimates` (or `estimates_exp` when `oddsratio` is True)
        and `specs_names`.
    oddsratio : bool
        If True, use the exponentiated estimates (`estimates_exp`).
    specs : list of control-lists, optional
        Specs to highlight; each is matched as a frozenset against
        `specs_names`. Default: None (no highlights).
    ax : matplotlib.axes.Axes, optional
        Axes on which to draw; if None the current axes (`plt.gca()`) is used.
    title : str, default=''
        Title to display above the plot.
    despine_left : bool, default=True
        If True, move y-axis ticks & label to the right spine; otherwise keep on the left.
    legend_bool : bool, default=False
        If True, draw a custom legend for the plotted specifications.
    bw_adjust : float, default=0.5
        Bandwidth adjustment factor for the KDE; larger values make the curve smoother.
    highlights : bool, default=True
        If True, also plot the first (null) and last (full) entries of `specs_names`.
    colormap : str or Colormap, default='viridis'
        Colormap used for the plotted specifications.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the completed plot.
    """
    # 1. One column of draws per specification.
    if oddsratio is True:
        draws_df = results_object.estimates_exp.T.copy()
    else:
        draws_df = results_object.estimates.T.copy()
    spec_labels = list(results_object.specs_names)
    draws_df.columns = spec_labels

    # 2. Decide which specifications are drawn, in a fixed order so the
    #    colours map consistently: null model, highlights, full model.
    highlight = []
    if specs:
        requested = {frozenset(sp) for sp in specs}
        highlight = [lab for lab in spec_labels if lab in requested]
    if highlights:
        order = [spec_labels[0]] + highlight + [spec_labels[-1]]
    else:
        order = highlight
    # A requested spec may coincide with the null/full model; keep each
    # label once so columns and hue levels are not duplicated.
    order = list(dict.fromkeys(order))

    if not order:
        if specs:
            warnings.warn(
                "[plot_bdist] None of the requested specs were found; "
                "plotting the pooled distribution instead.",
                UserWarning
            )
        df_long = draws_df.melt(var_name='spec', value_name='coef')
        cmap = matplotlib.colormaps[colormap] if isinstance(colormap, str) else colormap
        palette = [matplotlib.colors.to_hex(cmap(0.95), keep_alpha=False)]
        # seaborn ignores `palette` when no `hue` is assigned, so the single
        # colour has to be passed through `color`.
        colour_kwargs = {'color': palette[0]}
    else:
        df_long = draws_df[order].melt(var_name='spec', value_name='coef')
        palette = get_colormap_colors(len(order), colormap)
        colour_kwargs = {'hue': 'spec', 'palette': palette}

    if ax is None:
        ax = plt.gca()
    sns.kdeplot(
        data=df_long,
        x='coef',
        common_norm=False,    # each group integrates to 1
        bw_adjust=bw_adjust,  # controls smoothness
        linewidth=2,
        fill=True,            # fill under the curve
        alpha=0.3,            # light shading
        ax=ax,
        legend=False,
        **colour_kwargs
    )

    # 3. Optionally draw a custom legend
    if legend_bool:
        handles = [
            Line2D([0], [0],
                   marker='s',
                   color=col,
                   markerfacecolor=col,
                   markersize=10,
                   linestyle='',
                   label=lab)
            for col, lab in zip(palette, order)
        ]
        ax.legend(handles=handles, title='Specification',
                  frameon=True, loc='upper right')

    # 4. Final formatting
    axis_formatter(ax, 'Density', 'Bootstrapped Estimand', title)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(5))
    if despine_left:
        ax.yaxis.set_label_position("right")
        sns.despine(ax=ax, right=False, left=True)
    else:
        sns.despine(ax=ax)
    return ax
[docs]
def plot_kfolds(
        results_object,
        colormap: Union[str, matplotlib.colors.Colormap],
        ax: Optional[plt.Axes] = None,
        title: str = '',
        despine_left: bool = True,
        tau: float = 0.6
) -> plt.Axes:
    """
    Plot the cross-validation metric distribution (density + histogram),
    with an adaptive legend positioned safely around the tallest bars.

    Parameters
    ----------
    results_object : object
        Must expose:
          - summary_df : pandas.DataFrame containing column 'av_k_metric'
          - name_av_k_metric : str, the metric name (e.g. 'r-squared', 'rmse')
    colormap : str or Colormap
        Matplotlib colormap name or object used for the density line.
    ax : matplotlib.axes.Axes, optional
        Axes on which to draw; if None a new (4×3) figure and axes are created.
    title : str, default=''
        Title to display above the plot.
    despine_left : bool, default=True
        If True, move y-axis ticks & label to the right spine; otherwise keep on the left.
    tau : float in (0,1), default=0.6
        Safety factor for legend placement, forwarded to
        `_legend_side_from_hist`: bars taller than tau*ylim are considered
        "in the way" and flip the legend to the opposite side.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the completed plot.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 3))

    # KDE & histogram
    data = results_object.summary_df['av_k_metric']
    hist_color = _blue_palette(1)[0]
    density_color = get_colormap_colors(2, colormap)[-1]
    sns.kdeplot(data, ax=ax, alpha=1, color=density_color)
    sns.histplot(data, ax=ax, alpha=1, color=hist_color, bins=30, stat='density')

    # 10% padding on both sides, then a little extra room on the left.
    val_range = data.max() - data.min()
    ax.set_xlim(data.min() - 0.1 * val_range, data.max() + 0.1 * val_range)
    ax.set_xlim(ax.get_xlim()[0] - 0.066 * val_range, ax.get_xlim()[1])

    # Adaptive legend location
    legend_loc = _legend_side_from_hist(ax, tau=tau)
    legend_elements = [
        Line2D([0], [0], color=density_color, lw=2, label='Density'),
        Patch(facecolor=hist_color, edgecolor=(0, 0, 0, 1), label='Histogram')
    ]
    ax.legend(handles=legend_elements,
              loc=legend_loc,
              frameon=True,
              fontsize=9,
              title='Out-of-Sample',
              title_fontsize=10,
              framealpha=1,
              facecolor='w',
              edgecolor=(0, 0, 0, 1))

    # Cosmetic axes work
    ax.tick_params(axis='both', which='major', labelsize=11)
    ax.grid(linestyle='--', color='k', alpha=0.1, zorder=-1)
    ax.set_axisbelow(True)

    # Pretty-print the metric name; both special cases are matched
    # case-insensitively.
    name = results_object.name_av_k_metric
    lowered = name.lower()
    if lowered == 'r-squared':
        metric = r'R$^2$'
    elif lowered == 'rmse':
        metric = name.upper()
    else:
        metric = name.title()
    axis_formatter(ax, 'Density', f'OOS Metric: {metric}', title)
    ax.xaxis.set_major_locator(mticker.MaxNLocator(5))
    if despine_left:
        ax.yaxis.set_label_position("right")
        sns.despine(ax=ax, right=False, left=True)
    else:
        sns.despine(ax=ax)
    # The signature and docstring promise the Axes; previously nothing was returned.
    return ax
[docs]
def plot_bma(
        results_object,
        colormap: Union[str, matplotlib.colors.Colormap],
        ax: plt.Axes,
        feature_order: Sequence[str],
        title: str = ''
) -> plt.Axes:
    """
    Plot Bayesian Model Averaging (BMA) inclusion probabilities as a horizontal bar chart.

    Parameters
    ----------
    results_object : object
        Must implement `compute_bma()` returning a DataFrame with columns:
          - 'control_var'
          - 'probs'
    colormap : str or Colormap
        Accepted for a uniform panel signature; the bar colour currently
        comes from `_blue_palette`, so this argument is not used here.
    ax : matplotlib.axes.Axes
        Axes on which to draw the horizontal bar chart.
    feature_order : sequence of str
        Ordered list of control variable names to display on the y-axis.
        Names missing from the BMA table become NaN rows after reindexing.
    title : str, default=''
        Title to display above the plot.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the completed BMA plot.
    """
    bma = results_object.compute_bma()
    # Align the rows with the requested display order.
    bma = bma.set_index('control_var').reindex(feature_order)
    bma['probs'].plot(kind='barh',
                      ax=ax,
                      alpha=1,
                      color=_blue_palette(1)[0],
                      edgecolor='k',
                      )
    axis_formatter(ax, r'', 'BMA Probabilities', title)
    sns.despine(ax=ax)
    # The signature and docstring promise the Axes; previously nothing was returned.
    return ax
[docs]
def title_setter(
        ax: plt.Axes,
        title: str,
        side: str = 'left'
) -> matplotlib.text.Text:
    """
    Set a bold, left-aligned title on `ax`, positioned differently
    depending on whether the y-axis is on the left or right.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes whose title you wish to set.
    title : str
        The title text.
    side : {'left', 'right'}, default='left'
        - 'right': shifts the title to x=-0.26 (axes coordinates).
        - any other value: standard left positioning.

    Returns
    -------
    matplotlib.text.Text
        Whatever `ax.set_title` returns (the title text object).
    """
    title_kwargs = {'loc': 'left', 'fontsize': 16, 'y': 1, 'fontweight': 'bold'}
    if side == 'right':
        # Only the horizontal anchor differs between the two variants.
        title_kwargs['x'] = -.26
    return ax.set_title(title, **title_kwargs)
def _sanitize_specs(
        specs: Optional[List[List[str]]],
        max_len: int = 4
) -> Optional[List[List[str]]]:
    """
    Cap the number of highlighted specs at `max_len`.

    Parameters
    ----------
    specs : list of list of str, or None
        Control-name lists to highlight in the curve, IC, and distribution panels.
    max_len : int, default=4
        Largest number of specs that is kept.

    Returns
    -------
    list of list of str, or None
        `specs` itself when it is None or short enough; otherwise its first
        `max_len` entries (a `UserWarning` is issued in that case).
    """
    # Nothing to trim when no specs were given.
    if specs is None:
        return specs
    n_given = len(specs)
    if n_given <= max_len:
        return specs
    warnings.warn(
        f"Received {n_given} specs; only the first {max_len} will be used.",
        UserWarning
    )
    return specs[:max_len]
def _prepare_output_dir(figpath: Optional[Path], project: Optional[str]) -> Path:
    """
    Resolve and create the directory that figures are written to.

    Parameters
    ----------
    figpath : Path or None
        Base directory; the current working directory is used when falsy.
    project : str or None
        Optional sub-directory appended to the base directory.

    Returns
    -------
    Path
        The directory, created (with parents) if it did not exist yet.

    Raises
    ------
    RuntimeError
        If the directory cannot be created (chained from the `OSError`).
    """
    target = figpath if figpath else Path.cwd()
    if project:
        target = target / project
    try:
        target.mkdir(parents=True, exist_ok=True)
    except OSError as err:
        raise RuntimeError(f"Could not create output directory {target!r}") from err
    return target
[docs]
def plot_results(
        results_object,
        loess: bool = True,
        ci: float = 0.95,
        specs: Optional[List[List[str]]] = None,
        ic: Optional[str] = None,
        colormap: Union[str, matplotlib.colors.Colormap] = 'viridis',
        figsize: Tuple[int, int] = (16, 16),
        ext: str = 'pdf',
        figpath: Optional[Union[str, Path]] = None,
        highlights: bool = True,
        oddsratio: bool = False,
        project_name: Optional[str] = None,
        spec_matrix_bins: int = 128,
        spec_matrix_threshold: int = 128
) -> None:
    """
    Draw and save the combined results figure plus the individual panels.

    Parameters
    ----------
    results_object : object
        An OLSResult-like object (must expose attributes `y_name`, `draws`,
        `kfold`, `shap_return`, `summary_df`, `specs_names`, `controls`, etc.).
    loess : bool, default=True
        Whether to apply LOESS smoothing to the coefficient-specification curve.
    ci : float, default=0.95
        The confidence interval to use; must lie in [0, 1].
    specs : list of list of str, optional
        Specs (lists of control names) to highlight in the curve, IC, and
        distribution panels. At most six are used; extra ones are dropped
        with a warning.
    ic : str, optional
        Information criterion column to plot (e.g. 'aic', 'bic', 'hqic').
        The IC panel is always drawn for a univariate outcome, and `plot_ic`
        raises if `ic` is not a column of `summary_df`.
    colormap : str or Colormap, default='viridis'
        Colormap used consistently for all panels.
    figsize : (width, height), default=(16,16)
        Size of the combined figure in inches.
    ext : str, default='pdf'
        File extension for the saved figures (e.g. 'png', 'pdf');
        surrounding whitespace and a leading dot are ignored.
    figpath : str or Path, optional
        Directory in which to save outputs; if None, uses current working dir.
    highlights : bool, default=True
        Whether to highlight certain specifications.
    oddsratio : bool, default=False
        Whether to exponentiate the coefficients (e.g. for odds ratios).
        Only valid when `results_object.model_name` is
        'Logistic Regression Robust'; this stores `all_b_exp` and
        `estimates_exp` on `results_object`.
    project_name : str, default=None
        Sub-directory of the output directory and filename prefix. When
        None, files go directly into the output directory with the prefix
        'no_project_name'.
    spec_matrix_bins : int, default=128
        Number of bins to use for the binned spec matrix heatmap.
    spec_matrix_threshold : int, default=128
        Minimum number of specifications required to switch from dots to heatmap.

    Returns
    -------
    None
        Figures are written to disk. If a merged results object is passed
        (`draws` or `kfold` is a list/tuple), a `UserWarning` is issued and
        nothing is plotted.

    Raises
    ------
    ValueError
        If `oddsratio` is requested for a non-logistic model, if `ci` is
        outside [0, 1], or if the SHAP values do not line up with the
        SHAP feature matrix.
    RuntimeError
        If the output directory cannot be created.

    Notes
    -----
    - The outcome is treated as univariate when every element of `y_name`
      has length 1.
    - Always saved: `_all`, `_R2hexbin`, `_OOS`, `_curve`, `_bdist`.
    - Univariate only: `_LLhexbin`, `_SHAP`, `_BMA`, `_IC`.
    - The combined `_all` figure is saved at dpi=800 when `ext` is 'png'.
    """
    # Merged results objects (draws/kfold are lists) are not supported.
    if isinstance(results_object.draws, (list, tuple)) or isinstance(results_object.kfold, (list, tuple)):
        warnings.warn(
            "plot_results was passed a *merged* results object (draws/kfold are lists). "
            "This function does not support plotting a merged‐results object. "
            "Please extract individual result objects and plot them separately. "
            "Exiting without plotting.",
            UserWarning
        )
        return

    # Clean up file extension ('.png' and ' png ' both become 'png') and
    # sanitize spec input.
    ext = ext.strip().lstrip('.')
    specs = _sanitize_specs(specs, max_len=6)
    # Use a safe filename prefix even when caller passes project_name=None.
    project_stem = project_name or "no_project_name"

    # Handle odds ratio transformation for logistic regression models
    if oddsratio is True:
        if results_object.model_name == 'Logistic Regression Robust':
            results_object.all_b_exp = np.exp([arr[0][0] for arr in results_object.all_b.copy()])
            results_object.estimates_exp = np.exp(results_object.estimates.copy())
        else:
            raise ValueError("`oddsratio` option is only valid for logistic regression models.")

    # Validate confidence interval value
    if not (0 <= ci <= 1):
        raise ValueError(f"`ci` must lie between 0 and 1; received ci={ci!r}")

    # Prepare output directory for saving figures
    outdir = _prepare_output_dir(Path(figpath) if figpath else None, project_name)

    def _save_current(suffix: str, *, high_res: bool = False) -> None:
        # Save the current pyplot figure as <outdir>/<project_stem><suffix>.<ext>.
        target = os.path.join(outdir, project_stem + suffix + '.' + ext)
        if high_res and ext == 'png':
            plt.savefig(target, bbox_inches='tight', dpi=800)
        else:
            plt.savefig(target, bbox_inches='tight')

    is_univariate = max(len(t) for t in results_object.y_name) == 1

    if is_univariate:
        # Full 8-panel grid (plus the spec matrix under panel f).
        fig = plt.figure(figsize=figsize)
        gs = GridSpec(9, 24, wspace=0.5, hspace=1.5)
        # Set up axes for subplots
        ax1 = fig.add_subplot(gs[0:3, 0:12])
        ax2 = fig.add_subplot(gs[0:3, 13:24])
        ax4 = fig.add_subplot(gs[3:5, 6:12])
        ax3 = fig.add_subplot(gs[3:5, 0:6])
        ax4_cbar = fig.add_subplot(gs[3:5, 12:13])
        ax5 = fig.add_subplot(gs[3:5, 14:23])
        spec_gs = GridSpecFromSubplotSpec(
            2,
            2,
            subplot_spec=gs[5:9, 0:13],
            height_ratios=[2, 1],
            width_ratios=[12, 1],
            hspace=0.12,
            wspace=0.05
        )
        ax6 = fig.add_subplot(spec_gs[0, 0])
        ax6m = fig.add_subplot(spec_gs[1, 0], sharex=ax6)
        ax6m_cbar = fig.add_subplot(spec_gs[1, 1])
        ax6_cbar_spacer = fig.add_subplot(spec_gs[0, 1])
        ax6_cbar_spacer.axis('off')
        ax7 = fig.add_subplot(gs[5:7, 12:21])
        ax8 = fig.add_subplot(gs[7:9, 14:23])

        # Prepare SHAP values and input matrix (control/z variables only).
        # This keeps panel c/d/spec-matrix focused on specification-varying terms.
        shap_vals_full = np.asarray(results_object.shap_return[0])
        shap_x_full = results_object.shap_return[1]
        if not isinstance(shap_x_full, pd.DataFrame):
            shap_x_full = pd.DataFrame(shap_x_full)
        if shap_vals_full.ndim != 2:
            raise ValueError(
                f"Expected 2D SHAP values, received shape {shap_vals_full.shape!r}."
            )
        # Some SHAP pipelines prepend a bias column; drop it if present.
        if shap_vals_full.shape[1] == shap_x_full.shape[1] + 1:
            shap_vals_full = shap_vals_full[:, 1:]
        if shap_vals_full.shape[1] != shap_x_full.shape[1]:
            raise ValueError(
                f"SHAP values columns ({shap_vals_full.shape[1]}) do not match "
                f"feature matrix columns ({shap_x_full.shape[1]})."
            )
        shap_cols = [c for c in results_object.controls if c in shap_x_full.columns]
        shap_idx = [shap_x_full.columns.get_loc(c) for c in shap_cols]
        shap_vals = shap_vals_full[:, shap_idx]
        shap_x = shap_x_full[shap_cols].to_numpy()

        # Generate plots in the grid
        plot_hexbin_r2(results_object, ax1, fig, oddsratio, colormap, title='a.')
        plot_hexbin_log(results_object, ax2, fig, oddsratio, colormap, title='b.')
        cbar_width = 0.046
        cbar_width_fig = cbar_width * ax1.get_position().width
        feature_order = shap_violin(
            ax4,
            shap_vals,
            shap_x,
            shap_cols,
            title='d.',
            clear_yticklabels=True,
            cmap=colormap,
            cbar_ax=ax4_cbar,
            cbar_width=cbar_width,
            cbar_width_fig=cbar_width_fig
        )
        plot_bma(results_object, colormap, ax3, feature_order, title='c.')
        plot_kfolds(results_object, colormap, ax5, title='e.', despine_left=True)
        plot_curve(results_object=results_object, loess=loess, ci=ci, specs=specs,
                   ax=ax6, highlights=highlights, title='f.', oddsratio=oddsratio, colormap=colormap)
        order_idx = _spec_order_idx(results_object, oddsratio)

        # Build the spec-matrix control list: all controls, ordered by
        # SHAP importance (feature_order) where available, with any
        # controls not ranked by SHAP appended at the end.
        controls_set = set(results_object.controls)
        spec_matrix_controls = [c for c in feature_order if c in controls_set]
        for c in results_object.controls:
            if c not in spec_matrix_controls:
                spec_matrix_controls.append(c)
        plot_spec_matrix(
            results_object=results_object,
            ax=ax6m,
            order_idx=order_idx,
            oddsratio=oddsratio,
            controls=spec_matrix_controls,
            bins=spec_matrix_bins,
            heatmap_threshold=spec_matrix_threshold,
            colormap=colormap,
            cbar_ax=ax6m_cbar,
            cbar_width=cbar_width,
            cbar_width_fig=cbar_width_fig
        )
        n_specs = len(results_object.specs_names)
        use_heatmap = (
            spec_matrix_bins is not None
            and spec_matrix_bins > 0
            and n_specs > spec_matrix_threshold
        )
        right_x1 = ax5.get_position().x1
        expanded_left_x0 = ax6_cbar_spacer.get_position().x0
        # Panel g can always expand into the blank spacer to the right of f.
        _set_axes_horizontal_span(expanded_left_x0, right_x1, ax7)
        # Panel h can only expand when the lower row is a dot plot; when a
        # heatmap is used, keep h aligned to the e-panel column.
        if use_heatmap:
            _right_align_axes_to(ax5, ax8)
        else:
            _set_axes_horizontal_span(expanded_left_x0, right_x1, ax8)

        locator = mticker.MaxNLocator(5)
        ax6.xaxis.set_major_locator(locator)
        ax6m.xaxis.set_major_locator(locator)
        ax6.minorticks_off()
        ax6m.minorticks_off()
        ax6.xaxis.set_minor_locator(mticker.NullLocator())
        ax6m.xaxis.set_minor_locator(mticker.NullLocator())
        # Panel f keeps its tick marks; the labels live on the spec matrix below.
        ax6.set_xlabel('')
        ax6.tick_params(axis='x', which='major', bottom=True, labelbottom=False)

        plot_ic(results_object=results_object, ic=ic, specs=specs, ax=ax7,
                colormap=colormap, title='g.', despine_left=True)
        plot_bdist(results_object=results_object, specs=specs, ax=ax8,
                   oddsratio=oddsratio, highlights=highlights, colormap=colormap,
                   title='h.', despine_left=True)
    else:
        # Plot reduced layout when y is multivariate
        fig = plt.figure(figsize=figsize)
        gs = GridSpec(6, 24, wspace=-.25, hspace=5)
        ax1 = fig.add_subplot(gs[0:6, 0:16])
        ax2 = fig.add_subplot(gs[0:3, 17:24])
        ax3 = fig.add_subplot(gs[3:6, 17:24])
        plot_curve(results_object=results_object, loess=loess, ci=ci, specs=specs,
                   ax=ax1, highlights=highlights, title='a.', oddsratio=oddsratio, colormap=colormap)
        plot_hexbin_r2(results_object, ax2, fig, oddsratio, colormap, title='b.', side='right')
        plot_bdist(results_object=results_object, specs=specs, ax=ax3,
                   oddsratio=oddsratio, highlights=highlights, colormap=colormap,
                   title='c.', despine_left=True)

    # Save the full panel figure (high resolution for png in both layouts).
    # The combined figure is intentionally left open, as before.
    _save_current('_all', high_res=True)

    # Plot and save individual panels
    fig, ax = plt.subplots(figsize=(8.5, 5))
    plot_hexbin_r2(results_object, ax, fig, oddsratio, colormap)
    _save_current('_R2hexbin')
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(8.5, 5))
    plot_kfolds(results_object=results_object, colormap=colormap, ax=ax, despine_left=False)
    _save_current('_OOS')
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(12, 7))
    plot_curve(results_object=results_object, loess=loess, ci=ci,
               oddsratio=oddsratio, highlights=highlights, specs=specs, ax=ax, colormap=colormap)
    _save_current('_curve')
    plt.close(fig)

    # Additional subplots only if y is univariate (they rely on the SHAP
    # arrays prepared for the combined figure above).
    if is_univariate:
        fig, ax = plt.subplots(figsize=(8.5, 5))
        plot_hexbin_log(results_object, ax, fig, oddsratio, colormap)
        _save_current('_LLhexbin')
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(8.5, 5))
        feature_order = shap_violin(ax, shap_vals, shap_x, shap_cols, clear_yticklabels=False, cmap=colormap)
        _save_current('_SHAP')
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(8.5, 5))
        plot_bma(results_object, colormap, ax, feature_order)
        _save_current('_BMA')
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(8.5, 5))
        plot_ic(results_object=results_object, ic=ic, specs=specs, ax=ax,
                colormap=colormap, title='g.', despine_left=False)
        _save_current('_IC')
        plt.close(fig)

    # The distribution panel exists in both layouts, so it is saved for both.
    fig, ax = plt.subplots(figsize=(8.5, 5))
    plot_bdist(results_object=results_object, specs=specs, ax=ax,
               oddsratio=oddsratio, despine_left=False, colormap=colormap,
               highlights=highlights, legend_bool=False)
    _save_current('_bdist')
    plt.close(fig)