"""
Plotting utilities for Flow-Disentangled Feature Importance.
The functions in this module provide static Matplotlib visualizations for
global FDFI scores, per-sample UEIFs, confidence intervals, disentanglement
diagnostics, and feature-correlation structure.
"""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.cluster import hierarchy
from scipy.spatial import distance
ArrayLike = Union[np.ndarray, Sequence[float]]
FeatureRef = Union[int, str]
__all__ = [
"summary_plot",
"waterfall_plot",
"force_plot",
"dependence_plot",
"correlation_heatmap",
"summary_bar",
"confidence_interval_plot",
"diagnostics_plot",
]
def _as_1d(values: ArrayLike, name: str) -> np.ndarray:
"""Return a finite-shape 1D float array."""
arr = np.asarray(values, dtype=float)
if arr.ndim != 1:
raise ValueError(f"{name} must be 1-dimensional; got shape {arr.shape}")
return arr
def _as_2d(values: ArrayLike, name: str) -> np.ndarray:
"""Return a 2D float array."""
arr = np.asarray(values, dtype=float)
if arr.ndim != 2:
raise ValueError(f"{name} must be 2-dimensional; got shape {arr.shape}")
return arr
def _feature_names(feature_names: Optional[Sequence[Any]], n_features: int) -> List[str]:
"""Validate or create feature names."""
if feature_names is None:
return [f"Feature {i}" for i in range(n_features)]
names = [str(name) for name in feature_names]
if len(names) != n_features:
raise ValueError(
f"feature_names has length {len(names)} but expected {n_features}"
)
return names
def _resolve_feature_index(
feature: FeatureRef,
feature_names: Optional[Sequence[Any]],
n_features: int,
name: str,
) -> int:
"""Resolve an integer or string feature reference to a column index."""
if isinstance(feature, str):
names = _feature_names(feature_names, n_features)
if feature not in names:
raise ValueError(f"{name}={feature!r} is not in feature_names")
return names.index(feature)
if isinstance(feature, (int, np.integer)):
idx = int(feature)
if idx < 0:
idx += n_features
if idx < 0 or idx >= n_features:
raise ValueError(f"{name} index {feature} is out of bounds")
return idx
raise TypeError(f"{name} must be an integer index or feature name")
def _top_order(scores: np.ndarray, max_display: Optional[int]) -> np.ndarray:
"""Return stable descending order by absolute score."""
if max_display is not None and max_display <= 0:
raise ValueError("max_display must be positive")
order = np.argsort(-np.abs(scores), kind="stable")
if max_display is not None:
order = order[: min(int(max_display), len(order))]
return order
def _fig_ax(
ax: Optional[Axes],
figsize: Optional[Tuple[float, float]],
) -> Tuple[Figure, Axes]:
"""Create or reuse a Matplotlib figure/axes pair."""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
return fig, ax
def _finish_figure(
fig: Figure,
savepath: Optional[str],
show: bool,
*,
tight_layout: bool = True,
dpi: int = 150,
bbox_inches: str = "tight",
) -> None:
"""Apply layout, optional save, and optional display."""
if tight_layout:
fig.tight_layout()
if savepath is not None:
fig.savefig(savepath, dpi=dpi, bbox_inches=bbox_inches)
if show:
plt.show()
def _sanitize_se(se: Optional[ArrayLike], n_features: int) -> np.ndarray:
"""Return nonnegative finite standard errors suitable for error bars."""
if se is None:
return np.zeros(n_features, dtype=float)
se_arr = _as_1d(se, "se_X")
if se_arr.shape[0] != n_features:
raise ValueError(f"se_X has length {se_arr.shape[0]} but expected {n_features}")
se_arr = np.abs(se_arr.astype(float, copy=True))
finite = np.isfinite(se_arr)
max_finite = float(np.max(se_arr[finite])) if np.any(finite) else 0.0
se_arr = np.where(np.isnan(se_arr), 0.0, se_arr)
se_arr = np.where(np.isinf(se_arr), max_finite, se_arr)
return se_arr
def _summary_dataframe(
feature_names: Sequence[str],
phi: np.ndarray,
se: np.ndarray,
):
"""Build the sorted summary table, using pandas when available."""
try:
import pandas as pd
except ImportError as exc: # pragma: no cover - pandas is installed via seaborn
raise ImportError(
"summary_bar returns a pandas DataFrame. Install pandas to use it."
) from exc
return (
pd.DataFrame(
{
"feature": list(feature_names),
"phi": np.abs(phi),
"se": se,
}
)
.sort_values("phi", ascending=False, kind="mergesort")
.reset_index(drop=True)
)
def _group_remaining(
values: np.ndarray,
names: Sequence[str],
max_display: int,
) -> Tuple[np.ndarray, List[str], np.ndarray]:
"""Select top waterfall/force features and optionally group the remainder."""
if max_display <= 0:
raise ValueError("max_display must be positive")
n_features = values.shape[0]
order = _top_order(values, None)
if n_features <= max_display:
top = order
return values[top], [names[i] for i in top], top
top_count = max_display - 1
top = order[:top_count]
rest = order[top_count:]
grouped_values = np.concatenate([values[top], [values[rest].sum()]])
grouped_names = [names[i] for i in top]
grouped_names.append(f"{len(rest)} remaining features")
grouped_order = np.concatenate([top, [-1]])
return grouped_values, grouped_names, grouped_order
def _feature_value_label(name: str, value: Optional[float]) -> str:
"""Combine a feature name and displayed value for single-sample plots."""
if value is None or not np.isfinite(value):
return name
return f"{name} = {value:.3g}"
[docs]
def summary_bar(
phi_X: ArrayLike,
se_X: Optional[ArrayLike] = None,
feature_names: Optional[Sequence[Any]] = None,
group_colors: Optional[Mapping[str, Any]] = None,
savepath: Optional[str] = None,
max_display: Optional[int] = None,
ax: Optional[Axes] = None,
show: bool = True,
**kwargs: Any,
):
"""
Plot global FDFI feature importance as a sorted bar chart.
Parameters
----------
phi_X : array-like of shape (n_features,)
Global FDFI scores such as ``results["phi_X"]`` or ``results["phi_Z"]``.
Bar lengths use ``abs(phi_X)`` so signed attribution summaries are also
supported.
se_X : array-like of shape (n_features,), optional
Standard errors such as ``results["se_X"]``. Missing values default to
zero. NaN, inf, and negative entries are sanitized before plotting.
feature_names : sequence of str, optional
Feature names. Defaults to ``Feature 0``, ``Feature 1``, ...
group_colors : mapping, optional
Mapping from feature name to a Matplotlib color. Missing features use a
neutral gray. When omitted, a colormap gradient is used.
savepath : str, optional
Path where the figure should be saved.
max_display : int, optional
Maximum number of features to show.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
**kwargs
Styling options. Common keys include ``figsize``, ``title``, ``cmap``,
``capsize``, ``elinewidth``, ``dpi``, and ``bbox_inches``.
Returns
-------
fig : matplotlib.figure.Figure
The figure object.
ax : matplotlib.axes.Axes
The axes object.
importance_df : pandas.DataFrame
Sorted table with columns ``feature``, ``phi``, and ``se``.
Examples
--------
>>> fig, ax, table = summary_bar(
... results["phi_X"],
... se_X=results["se_X"],
... feature_names=feature_names,
... show=False,
... )
"""
phi = _as_1d(phi_X, "phi_X")
names = _feature_names(feature_names, phi.shape[0])
se = _sanitize_se(se_X, phi.shape[0])
importance_df = _summary_dataframe(names, phi, se)
if max_display is not None:
if max_display <= 0:
raise ValueError("max_display must be positive")
plot_df = importance_df.head(int(max_display)).copy()
else:
plot_df = importance_df
figsize = kwargs.get("figsize", (8.0, max(3.5, 0.34 * len(plot_df) + 1.5)))
fig, ax = _fig_ax(ax, figsize)
if group_colors is None:
cmap = plt.get_cmap(kwargs.get("cmap", "viridis"))
values = plot_df["phi"].to_numpy(dtype=float)
if len(values) == 0 or float(values.max()) == float(values.min()):
normalized = np.full(len(values), 0.55)
else:
normalized = (values - values.min()) / (values.max() - values.min())
colors = [cmap(value) for value in normalized]
else:
colors = [group_colors.get(feature, "#888888") for feature in plot_df["feature"]]
y_pos = np.arange(len(plot_df))
ax.barh(
y_pos,
plot_df["phi"],
xerr=plot_df["se"],
color=colors,
edgecolor="white",
linewidth=0.7,
capsize=kwargs.get("capsize", 4),
error_kw={
"elinewidth": kwargs.get("elinewidth", 1.2),
"ecolor": kwargs.get("error_color", "#222222"),
"capthick": kwargs.get("elinewidth", 1.2),
},
zorder=3,
)
ax.set_yticks(y_pos)
ax.set_yticklabels(plot_df["feature"], fontsize=kwargs.get("tick_fontsize", 9))
ax.invert_yaxis()
ax.set_xlabel(kwargs.get("xlabel", "Mean absolute FDFI score"))
ax.set_ylabel(kwargs.get("ylabel", "Feature"))
ax.set_title(
kwargs.get("title", "Global FDFI Feature Importance"),
fontsize=kwargs.get("title_fontsize", 11),
)
ax.grid(axis="x", linestyle="--", alpha=0.35, zorder=0)
ax.set_axisbelow(True)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax, importance_df
[docs]
def summary_plot(
shap_values: ArrayLike,
features: Optional[ArrayLike] = None,
feature_names: Optional[Sequence[Any]] = None,
max_display: int = 20,
show: bool = True,
ax: Optional[Axes] = None,
savepath: Optional[str] = None,
**kwargs: Any,
):
"""
Create a SHAP-like summary plot for FDFI attributions.
Parameters
----------
shap_values : array-like of shape (n_features,) or (n_samples, n_features)
FDFI attribution values. Use aggregate arrays such as ``results["phi_X"]``
for bar summaries, or per-sample arrays such as ``explainer.ueifs_X`` for
beeswarm summaries.
features : array-like of shape (n_samples, n_features), optional
Feature values used for point colors in a 2D beeswarm plot.
feature_names : sequence of str, optional
Feature names.
max_display : int, default=20
Maximum number of features to display.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
savepath : str, optional
Path where the figure should be saved.
**kwargs
Styling options. For 1D input, forwarded to :func:`summary_bar`.
Returns
-------
fig, ax : tuple
Matplotlib figure and axes for 2D input.
fig, ax, importance_df : tuple
For 1D input, the return value from :func:`summary_bar`.
Examples
--------
>>> summary_plot(explainer.ueifs_X, features=X_test, show=False)
>>> summary_plot(results["phi_X"], se_X=results["se_X"], show=False)
"""
values = np.asarray(shap_values, dtype=float)
if values.ndim == 1:
return summary_bar(
values,
se_X=kwargs.pop("se_X", None),
feature_names=feature_names,
group_colors=kwargs.pop("group_colors", None),
savepath=savepath,
max_display=max_display,
ax=ax,
show=show,
**kwargs,
)
if values.ndim != 2:
raise ValueError(f"shap_values must be 1D or 2D; got shape {values.shape}")
n_samples, n_features = values.shape
names = _feature_names(feature_names, n_features)
if max_display <= 0:
raise ValueError("max_display must be positive")
feature_values = None
if features is not None:
feature_values = _as_2d(features, "features")
if feature_values.shape != values.shape:
raise ValueError(
"features must have the same shape as shap_values; "
f"got {feature_values.shape} and {values.shape}"
)
mean_abs = np.nanmean(np.abs(values), axis=0)
order = _top_order(mean_abs, max_display)
plot_names = [names[i] for i in order]
figsize = kwargs.get("figsize", (8.5, max(3.5, 0.35 * len(order) + 1.6)))
fig, ax = _fig_ax(ax, figsize)
cmap = kwargs.get("cmap", "coolwarm")
dot_size = kwargs.get("dot_size", 18)
alpha = kwargs.get("alpha", 0.75)
color_values = None
scatter = None
if feature_values is not None:
color_values = feature_values[:, order].reshape(-1)
finite_color = color_values[np.isfinite(color_values)]
if finite_color.size:
vmin, vmax = np.percentile(finite_color, [5, 95])
if vmin == vmax:
vmin, vmax = float(finite_color.min()), float(finite_color.max())
else:
vmin, vmax = None, None
else:
vmin, vmax = None, None
for row, feature_index in enumerate(order):
x = values[:, feature_index]
finite = np.isfinite(x)
if not np.any(finite):
continue
x = x[finite]
offsets = ((np.arange(x.shape[0]) % 9) - 4) * 0.035
y = np.full(x.shape[0], row, dtype=float) + offsets
if feature_values is None:
scatter = ax.scatter(
x,
y,
s=dot_size,
alpha=alpha,
color=kwargs.get("color", "#1f77b4"),
edgecolors="none",
rasterized=kwargs.get("rasterized", False),
)
else:
c = feature_values[:, feature_index][finite]
scatter = ax.scatter(
x,
y,
s=dot_size,
alpha=alpha,
c=c,
cmap=cmap,
vmin=vmin,
vmax=vmax,
edgecolors="none",
rasterized=kwargs.get("rasterized", False),
)
ax.axvline(0.0, color="#555555", linewidth=0.8, zorder=0)
ax.set_yticks(np.arange(len(order)))
ax.set_yticklabels(plot_names, fontsize=kwargs.get("tick_fontsize", 9))
ax.invert_yaxis()
ax.set_xlabel(kwargs.get("xlabel", "FDFI attribution value"))
ax.set_ylabel(kwargs.get("ylabel", "Feature"))
ax.set_title(
kwargs.get("title", "FDFI Per-sample Attribution Summary"),
fontsize=kwargs.get("title_fontsize", 11),
)
ax.grid(axis="x", linestyle="--", alpha=0.25, zorder=0)
if feature_values is not None and scatter is not None and kwargs.get(
"color_bar", True
):
cbar = fig.colorbar(scatter, ax=ax, pad=0.02)
cbar.set_label(kwargs.get("colorbar_label", "Feature value"))
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax
[docs]
def waterfall_plot(
shap_values: ArrayLike,
features: Optional[ArrayLike] = None,
feature_names: Optional[Sequence[Any]] = None,
max_display: int = 10,
base_value: float = 0.0,
show: bool = True,
ax: Optional[Axes] = None,
savepath: Optional[str] = None,
**kwargs: Any,
):
"""
Create a single-explanation waterfall plot.
Parameters
----------
shap_values : array-like of shape (n_features,)
Single-sample FDFI attributions, for example ``explainer.ueifs_X[0]``.
features : array-like of shape (n_features,), optional
Feature values for label annotations.
feature_names : sequence of str, optional
Feature names.
max_display : int, default=10
Maximum number of features to display. Extra features are summed into a
final "remaining features" row.
base_value : float, default=0.0
Starting value for the additive explanation.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
savepath : str, optional
Path where the figure should be saved.
**kwargs
Styling options.
Returns
-------
fig, ax : tuple
Matplotlib figure and axes.
Examples
--------
>>> waterfall_plot(explainer.ueifs_X[0], feature_names=feature_names, show=False)
"""
values = _as_1d(shap_values, "shap_values")
n_features = values.shape[0]
names = _feature_names(feature_names, n_features)
feature_values = None
if features is not None:
feature_values = _as_1d(features, "features")
if feature_values.shape[0] != n_features:
raise ValueError(
f"features has length {feature_values.shape[0]} but expected {n_features}"
)
grouped_values, grouped_names, grouped_order = _group_remaining(
values, names, max_display
)
labels = []
for label, original_index in zip(grouped_names, grouped_order):
if original_index == -1 or feature_values is None:
labels.append(label)
else:
labels.append(_feature_value_label(label, feature_values[original_index]))
figsize = kwargs.get("figsize", (8.5, max(3.2, 0.45 * len(grouped_values) + 1.4)))
fig, ax = _fig_ax(ax, figsize)
positive_color = kwargs.get("positive_color", "#d62728")
negative_color = kwargs.get("negative_color", "#1f77b4")
running = float(base_value)
y_pos = np.arange(len(grouped_values))
for y, contribution in zip(y_pos, grouped_values):
color = positive_color if contribution >= 0 else negative_color
ax.barh(
y,
contribution,
left=running,
color=color,
alpha=0.82,
edgecolor="white",
linewidth=0.7,
)
next_value = running + contribution
ax.plot(
[next_value, next_value],
[y - 0.38, y + 0.38],
color="#555555",
linewidth=0.6,
alpha=0.5,
)
running = next_value
final_value = float(base_value + grouped_values.sum())
ax.axvline(base_value, color="#333333", linestyle="--", linewidth=1.0)
ax.axvline(final_value, color="#333333", linestyle="-", linewidth=1.0)
ax.set_yticks(y_pos)
ax.set_yticklabels(labels, fontsize=kwargs.get("tick_fontsize", 9))
ax.invert_yaxis()
ax.set_xlabel(kwargs.get("xlabel", "Model output"))
ax.set_title(
kwargs.get(
"title",
f"FDFI Waterfall (base={base_value:.3g}, final={final_value:.3g})",
),
fontsize=kwargs.get("title_fontsize", 11),
)
ax.grid(axis="x", linestyle="--", alpha=0.25)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax
[docs]
def force_plot(
base_value: float,
shap_values: ArrayLike,
features: Optional[ArrayLike] = None,
feature_names: Optional[Sequence[Any]] = None,
max_display: int = 10,
show: bool = True,
ax: Optional[Axes] = None,
savepath: Optional[str] = None,
**kwargs: Any,
):
"""
Create a static force-style contribution plot.
Parameters
----------
base_value : float
Baseline model output.
shap_values : array-like of shape (n_features,) or (n_samples, n_features)
FDFI attributions. If a 2D array is supplied, values are averaged across
samples before plotting.
features : array-like, optional
Feature values for single-sample label annotations.
feature_names : sequence of str, optional
Feature names.
max_display : int, default=10
Maximum number of features to display. Extra features are grouped.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
savepath : str, optional
Path where the figure should be saved.
**kwargs
Styling options.
Returns
-------
fig, ax : tuple
Matplotlib figure and axes.
Examples
--------
>>> force_plot(0.0, explainer.ueifs_X[0], feature_names=feature_names, show=False)
"""
values = np.asarray(shap_values, dtype=float)
if values.ndim == 2:
values = np.nanmean(values, axis=0)
elif values.ndim != 1:
raise ValueError(f"shap_values must be 1D or 2D; got shape {values.shape}")
n_features = values.shape[0]
names = _feature_names(feature_names, n_features)
feature_values = None
if features is not None:
feature_values = _as_1d(features, "features")
if feature_values.shape[0] != n_features:
raise ValueError(
f"features has length {feature_values.shape[0]} but expected {n_features}"
)
grouped_values, grouped_names, grouped_order = _group_remaining(
values, names, max_display
)
labels = []
for label, original_index in zip(grouped_names, grouped_order):
if original_index == -1 or feature_values is None:
labels.append(label)
else:
labels.append(_feature_value_label(label, feature_values[original_index]))
figsize = kwargs.get("figsize", (9.0, 2.8))
fig, ax = _fig_ax(ax, figsize)
positive_color = kwargs.get("positive_color", "#d62728")
negative_color = kwargs.get("negative_color", "#1f77b4")
height = kwargs.get("height", 0.45)
pos_cursor = float(base_value)
neg_cursor = float(base_value)
for contribution, label in zip(grouped_values, labels):
if contribution >= 0:
left = pos_cursor
pos_cursor += contribution
y = 0.18
color = positive_color
va = "bottom"
text_y = y + height / 2 + 0.04
else:
left = neg_cursor + contribution
neg_cursor += contribution
y = -0.18 - height
color = negative_color
va = "top"
text_y = y - 0.04
ax.barh(
y,
abs(contribution),
left=left,
height=height,
color=color,
alpha=0.82,
edgecolor="white",
linewidth=0.7,
)
if abs(contribution) > kwargs.get("label_min_width", 0.0):
ax.text(
left + abs(contribution) / 2,
text_y,
label,
ha="center",
va=va,
fontsize=kwargs.get("label_fontsize", 8),
rotation=kwargs.get("label_rotation", 0),
)
final_value = float(base_value + grouped_values.sum())
ax.axvline(base_value, color="#333333", linestyle="--", linewidth=1.0)
ax.axvline(final_value, color="#333333", linestyle="-", linewidth=1.0)
ax.text(base_value, 0.78, "base", ha="center", fontsize=9)
ax.text(final_value, 0.78, "final", ha="center", fontsize=9)
ax.set_yticks([])
ax.set_xlabel(kwargs.get("xlabel", "Model output"))
ax.set_title(
kwargs.get(
"title", f"FDFI Force Plot (base={base_value:.3g}, final={final_value:.3g})"
),
fontsize=kwargs.get("title_fontsize", 11),
)
ax.grid(axis="x", linestyle="--", alpha=0.25)
ax.set_ylim(-0.9, 1.0)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax
[docs]
def dependence_plot(
feature_idx: FeatureRef,
shap_values: ArrayLike,
features: ArrayLike,
feature_names: Optional[Sequence[Any]] = None,
interaction_index: Optional[FeatureRef] = None,
show: bool = True,
ax: Optional[Axes] = None,
savepath: Optional[str] = None,
**kwargs: Any,
):
"""
Create a feature-dependence scatter plot.
Parameters
----------
feature_idx : int or str
Feature index or feature name to plot.
shap_values : array-like of shape (n_samples, n_features)
Per-sample FDFI attributions such as ``explainer.ueifs_X``.
features : array-like of shape (n_samples, n_features)
Feature values for the same samples.
feature_names : sequence of str, optional
Feature names. Required when ``feature_idx`` or ``interaction_index`` is
a string.
interaction_index : int or str, optional
Feature used to color points.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
savepath : str, optional
Path where the figure should be saved.
**kwargs
Styling options.
Returns
-------
fig, ax : tuple
Matplotlib figure and axes.
Examples
--------
>>> dependence_plot("age", explainer.ueifs_X, X_test,
... feature_names=feature_names, show=False)
"""
values = _as_2d(shap_values, "shap_values")
feature_values = _as_2d(features, "features")
if feature_values.shape != values.shape:
raise ValueError(
"features must have the same shape as shap_values; "
f"got {feature_values.shape} and {values.shape}"
)
n_samples, n_features = values.shape
names = _feature_names(feature_names, n_features)
idx = _resolve_feature_index(feature_idx, names, n_features, "feature_idx")
color = kwargs.get("color", "#1f77b4")
color_values = None
color_label = None
if interaction_index is not None:
color_idx = _resolve_feature_index(
interaction_index, names, n_features, "interaction_index"
)
color_values = feature_values[:, color_idx]
color_label = names[color_idx]
figsize = kwargs.get("figsize", (6.5, 4.8))
fig, ax = _fig_ax(ax, figsize)
if color_values is None:
scatter = ax.scatter(
feature_values[:, idx],
values[:, idx],
s=kwargs.get("dot_size", 28),
alpha=kwargs.get("alpha", 0.75),
color=color,
edgecolors="none",
)
else:
scatter = ax.scatter(
feature_values[:, idx],
values[:, idx],
s=kwargs.get("dot_size", 28),
alpha=kwargs.get("alpha", 0.75),
c=color_values,
cmap=kwargs.get("cmap", "viridis"),
edgecolors="none",
)
cbar = fig.colorbar(scatter, ax=ax, pad=0.02)
cbar.set_label(color_label)
ax.axhline(0.0, color="#555555", linewidth=0.8, zorder=0)
ax.set_xlabel(kwargs.get("xlabel", names[idx]))
ax.set_ylabel(kwargs.get("ylabel", f"FDFI attribution for {names[idx]}"))
ax.set_title(
kwargs.get("title", f"FDFI Dependence: {names[idx]}"),
fontsize=kwargs.get("title_fontsize", 11),
)
ax.grid(linestyle="--", alpha=0.25)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax
[docs]
def correlation_heatmap(
X_background: ArrayLike,
feature_names: Optional[Sequence[Any]] = None,
savepath: Optional[str] = None,
show: bool = True,
ax: Optional[Axes] = None,
**kwargs: Any,
):
"""
Plot a clustered Pearson correlation heatmap for background features.
Parameters
----------
X_background : array-like of shape (n_samples, n_features)
Background or training feature matrix used to estimate correlation
structure.
feature_names : sequence of str, optional
Feature names. Defaults to ``Feature 0``, ``Feature 1``, ...
savepath : str, optional
Path where the figure should be saved.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
**kwargs
Styling options including ``figsize``, ``cmap``, ``vmin``, ``vmax``,
``fontsize``, ``dpi``, and ``bbox_inches``.
Returns
-------
fig : matplotlib.figure.Figure
The figure object.
ax : matplotlib.axes.Axes
The axes object containing the heatmap.
feature_names_reordered : list of str
Feature names in clustered order.
Examples
--------
>>> correlation_heatmap(X_background, feature_names, show=False)
"""
X = _as_2d(X_background, "X_background")
n_samples, n_features = X.shape
names = _feature_names(feature_names, n_features)
if n_samples < kwargs.get("sample_warning_threshold", 50):
warnings.warn(
f"X_background has only {n_samples} samples; correlation estimates "
"may be unstable. Use a representative background set when possible.",
UserWarning,
stacklevel=2,
)
corr = np.corrcoef(X, rowvar=False)
corr = np.asarray(corr, dtype=float)
if corr.ndim == 0:
corr = np.array([[1.0]])
corr = np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=0.0)
np.fill_diagonal(corr, 1.0)
if n_features > 1:
dist_matrix = 1.0 - np.abs(corr)
dist_matrix = np.clip(dist_matrix, 0.0, 1.0)
np.fill_diagonal(dist_matrix, 0.0)
linkage_matrix = hierarchy.linkage(
distance.squareform(dist_matrix, checks=False), method="average"
)
leaf_order = hierarchy.dendrogram(linkage_matrix, no_plot=True)["leaves"]
else:
leaf_order = [0]
corr_reordered = corr[np.ix_(leaf_order, leaf_order)]
names_reordered = [names[i] for i in leaf_order]
figsize = kwargs.get("figsize", (8.5, 7.0))
fig, ax = _fig_ax(ax, figsize)
im = ax.imshow(
corr_reordered,
cmap=kwargs.get("cmap", "RdBu_r"),
vmin=kwargs.get("vmin", -1),
vmax=kwargs.get("vmax", 1),
aspect="auto",
)
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(kwargs.get("colorbar_label", "Pearson r"))
ax.set_xticks(np.arange(n_features))
ax.set_yticks(np.arange(n_features))
fontsize = kwargs.get("fontsize", 9)
ax.set_xticklabels(names_reordered, rotation=45, ha="right", fontsize=fontsize)
ax.set_yticklabels(names_reordered, fontsize=fontsize)
ax.set_title(
kwargs.get(
"title",
"Pearson Correlation Matrix (clustered by absolute correlation)",
),
fontsize=kwargs.get("title_fontsize", 11),
)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax, names_reordered
[docs]
def confidence_interval_plot(
ci_results: Mapping[str, Any],
feature_names: Optional[Sequence[Any]] = None,
max_display: int = 20,
ax: Optional[Axes] = None,
show: bool = True,
savepath: Optional[str] = None,
**kwargs: Any,
):
"""
Plot FDFI confidence intervals from ``conf_int()`` output.
Parameters
----------
ci_results : mapping
Dictionary returned by ``explainer.conf_int()``. Required keys are
``score``, ``ci_lower``, and ``ci_upper``. ``reject_null`` and
``ranking`` are used when present.
feature_names : sequence of str, optional
Feature names. If ``ci_results`` contains ``groups``, those group names
are used by default.
max_display : int, default=20
Maximum number of rows to display.
ax : matplotlib.axes.Axes, optional
Existing axes to draw on.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
savepath : str, optional
Path where the figure should be saved.
**kwargs
Styling options.
Returns
-------
fig, ax : tuple
Matplotlib figure and axes.
Examples
--------
>>> ci = explainer.conf_int(alpha=0.05, target="X")
>>> confidence_interval_plot(ci, feature_names=feature_names, show=False)
"""
required = ("score", "ci_lower", "ci_upper")
missing = [key for key in required if key not in ci_results]
if missing:
raise ValueError(f"ci_results is missing required keys: {missing}")
if max_display <= 0:
raise ValueError("max_display must be positive")
score = _as_1d(ci_results["score"], "ci_results['score']")
lower = _as_1d(ci_results["ci_lower"], "ci_results['ci_lower']")
upper = _as_1d(ci_results["ci_upper"], "ci_results['ci_upper']")
n_features = score.shape[0]
if lower.shape[0] != n_features or upper.shape[0] != n_features:
raise ValueError("score, ci_lower, and ci_upper must have the same length")
if feature_names is None and "groups" in ci_results:
names = _feature_names(ci_results["groups"], n_features)
else:
names = _feature_names(feature_names, n_features)
if "ranking" in ci_results:
ranking = _as_1d(ci_results["ranking"], "ci_results['ranking']")
if ranking.shape[0] != n_features:
raise ValueError("ranking must have the same length as score")
order = np.argsort(ranking, kind="stable")
else:
order = np.argsort(-np.abs(score), kind="stable")
order = order[: min(max_display, n_features)]
reject = np.asarray(ci_results.get("reject_null", np.zeros(n_features, dtype=bool)))
if reject.shape[0] != n_features:
raise ValueError("reject_null must have the same length as score")
plot_score = score[order]
plot_lower = lower[order]
plot_upper = upper[order]
plot_names = [names[i] for i in order]
plot_reject = reject[order].astype(bool)
finite_values = np.concatenate(
[
plot_score[np.isfinite(plot_score)],
plot_lower[np.isfinite(plot_lower)],
plot_upper[np.isfinite(plot_upper)],
]
)
if finite_values.size == 0:
finite_min, finite_max = -1.0, 1.0
else:
finite_min = float(finite_values.min())
finite_max = float(finite_values.max())
if finite_min == finite_max:
finite_min -= 1.0
finite_max += 1.0
pad = 0.08 * (finite_max - finite_min)
clip_min = finite_min - pad
clip_max = finite_max + pad
lower_plot = np.where(np.isfinite(plot_lower), plot_lower, clip_min)
upper_plot = np.where(np.isfinite(plot_upper), plot_upper, clip_max)
xerr = np.vstack([plot_score - lower_plot, upper_plot - plot_score])
xerr = np.maximum(xerr, 0.0)
figsize = kwargs.get("figsize", (8.0, max(3.5, 0.34 * len(order) + 1.5)))
fig, ax = _fig_ax(ax, figsize)
y_pos = np.arange(len(order))
colors = np.where(plot_reject, kwargs.get("significant_color", "#d62728"), "#777777")
for y, value, err, color in zip(y_pos, plot_score, xerr.T, colors):
ax.errorbar(
value,
y,
xerr=err.reshape(2, 1),
fmt="o",
color=kwargs.get("interval_color", "#333333"),
ecolor=color,
markerfacecolor="white",
markeredgecolor=color,
markersize=kwargs.get("markersize", 5),
capsize=kwargs.get("capsize", 3),
linewidth=kwargs.get("linewidth", 1.2),
zorder=3,
)
margin = ci_results.get("margin", 0.0)
if np.ndim(margin) == 0:
ax.axvline(float(margin), color="#555555", linestyle="--", linewidth=0.9)
ax.axvline(0.0, color="#999999", linestyle=":", linewidth=0.8)
ax.set_yticks(y_pos)
ax.set_yticklabels(plot_names, fontsize=kwargs.get("tick_fontsize", 9))
ax.invert_yaxis()
ax.set_xlabel(kwargs.get("xlabel", "FDFI score with confidence interval"))
ax.set_title(
kwargs.get("title", "FDFI Confidence Intervals"),
fontsize=kwargs.get("title_fontsize", 11),
)
ax.grid(axis="x", linestyle="--", alpha=0.25)
ax.set_xlim(clip_min, clip_max)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, ax
[docs]
def diagnostics_plot(
diagnostics: Mapping[str, Any],
feature_names: Optional[Sequence[Any]] = None,
ax: Optional[Axes] = None,
show: bool = True,
savepath: Optional[str] = None,
**kwargs: Any,
):
"""
Plot FDFI disentanglement diagnostics.
Parameters
----------
diagnostics : mapping
Diagnostics dictionary from an explainer. Existing keys include
``latent_independence_median``, ``distribution_fidelity_mmd``,
``latent_independence_label``, ``distribution_fidelity_label``, and
optionally ``latent_independence_dcor``.
feature_names : sequence of str, optional
Feature names for the latent dCor matrix when displayed.
ax : matplotlib.axes.Axes, optional
Existing axes for the scalar diagnostics bar plot. When omitted and a
dCor matrix is available, a two-panel figure is created.
show : bool, default=True
Whether to display the figure via ``plt.show()``.
savepath : str, optional
Path where the figure should be saved.
**kwargs
Styling options.
Returns
-------
fig : matplotlib.figure.Figure
The figure object.
ax_or_axes : matplotlib.axes.Axes or numpy.ndarray
The axes object(s).
Examples
--------
>>> diagnostics_plot(explainer.diagnostics, show=False)
"""
metric_keys = [
("latent_independence_median", "Median dCor", "latent_independence_label"),
("distribution_fidelity_mmd", "MMD", "distribution_fidelity_label"),
]
available = [
(key, label, label_key)
for key, label, label_key in metric_keys
if key in diagnostics
]
if not available:
raise ValueError(
"diagnostics must contain at least one of "
"'latent_independence_median' or 'distribution_fidelity_mmd'"
)
dcor = diagnostics.get("latent_independence_dcor")
include_matrix = kwargs.get("include_matrix", ax is None and dcor is not None)
if ax is None and include_matrix:
fig, axes = plt.subplots(
1,
2,
figsize=kwargs.get("figsize", (10.0, 4.0)),
gridspec_kw={"width_ratios": [1.0, 1.25]},
)
metric_ax = axes[0]
matrix_ax = axes[1]
returned_axes: Union[Axes, np.ndarray] = axes
else:
fig, metric_ax = _fig_ax(ax, kwargs.get("figsize", (5.8, 3.8)))
matrix_ax = None
returned_axes = metric_ax
label_colors = {
"GOOD": kwargs.get("good_color", "#2ca02c"),
"MODERATE": kwargs.get("moderate_color", "#ffbf00"),
"POOR": kwargs.get("poor_color", "#d62728"),
}
labels = [label for _, label, _ in available]
values = np.array([float(diagnostics[key]) for key, _, _ in available])
quality = [str(diagnostics.get(label_key, "")).upper() for _, _, label_key in available]
colors = [label_colors.get(label, "#777777") for label in quality]
x_pos = np.arange(len(values))
metric_ax.bar(x_pos, values, color=colors, edgecolor="white", linewidth=0.7)
metric_ax.set_xticks(x_pos)
metric_ax.set_xticklabels(labels, rotation=20, ha="right")
metric_ax.set_ylabel(kwargs.get("ylabel", "Diagnostic value"))
metric_ax.set_title(
kwargs.get("title", "FDFI Diagnostics"),
fontsize=kwargs.get("title_fontsize", 11),
)
metric_ax.grid(axis="y", linestyle="--", alpha=0.25)
for x, value, label in zip(x_pos, values, quality):
text = f"{value:.3g}"
if label:
text = f"{text}\n{label.title()}"
metric_ax.text(x, value, text, ha="center", va="bottom", fontsize=8)
if matrix_ax is not None:
dcor_arr = _as_2d(dcor, "diagnostics['latent_independence_dcor']")
if dcor_arr.shape[0] != dcor_arr.shape[1]:
raise ValueError("latent_independence_dcor must be a square matrix")
dcor_names = _feature_names(feature_names, dcor_arr.shape[0])
im = matrix_ax.imshow(dcor_arr, cmap=kwargs.get("matrix_cmap", "magma_r"))
matrix_ax.set_title("Latent dCor Matrix", fontsize=kwargs.get("title_fontsize", 11))
matrix_ax.set_xticks(np.arange(dcor_arr.shape[1]))
matrix_ax.set_yticks(np.arange(dcor_arr.shape[0]))
matrix_ax.set_xticklabels(
dcor_names, rotation=45, ha="right", fontsize=kwargs.get("matrix_fontsize", 8)
)
matrix_ax.set_yticklabels(dcor_names, fontsize=kwargs.get("matrix_fontsize", 8))
fig.colorbar(im, ax=matrix_ax, fraction=0.046, pad=0.04)
_finish_figure(
fig,
savepath,
show,
dpi=kwargs.get("dpi", 150),
bbox_inches=kwargs.get("bbox_inches", "tight"),
)
return fig, returned_axes