Visualization with FDFI

This tutorial shows the plotting helpers in fdfi.plots for background data, aggregate FDFI scores, per-sample UEIFs, inference output, and explainer diagnostics.

[1]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from fdfi.explainers import OTExplainer
from fdfi.plots import (
    confidence_interval_plot,
    correlation_heatmap,
    dependence_plot,
    diagnostics_plot,
    force_plot,
    summary_bar,
    summary_plot,
    waterfall_plot,
)

rng = np.random.default_rng(42)

Simulated Data

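Draw background and test samples from a correlated Gaussian so the correlation heatmap and dependence plot have structure to show. X0 and X1 are strongly positively correlated, X2 and X3 are negatively correlated, and the model depends only on X0, X2, and X4.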

[2]:
n_train = 200
n_test = 60
n_features = 6

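# Baseline correlation of 0.25 between all features, plus one strong
# positive pair (X0, X1) and one negative pair (X2, X3).
cov = np.full((n_features, n_features), 0.25)
np.fill_diagonal(cov, 1.0)
cov[0, 1] = cov[1, 0] = 0.75
cov[2, 3] = cov[3, 2] = -0.55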

X_background = rng.multivariate_normal(np.zeros(n_features), cov, size=n_train)
X_test = rng.multivariate_normal(np.zeros(n_features), cov, size=n_test)
feature_names = [f"X{i}" for i in range(n_features)]

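# Ground-truth model: linear in X0 and X2, quadratic in X4; X1, X3, X5 are unused.
def model(X):
    return 1.5 * X[:, 0] - 0.8 * X[:, 2] + 0.4 * X[:, 4] ** 2

# Build the OT explainer on the background data, then explain the test rows.
explainer = OTExplainer(model, data=X_background, nsamples=40)
results = explainer(X_test)
# One-sided ("greater") intervals at alpha = 0.05 for the X-space scores.
ci = explainer.conf_int(alpha=0.05, target="X", alternative="greater")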
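Before explaining anything, it can help to confirm that the covariance above is a valid (positive definite) matrix and that samples drawn from it reproduce the intended correlations. The check below is a NumPy-only sketch and is not part of fdfi. It uses a larger sample than the tutorial so the empirical correlations sit close to their targets.

```python
import numpy as np

# Same covariance as the tutorial: 0.25 everywhere off the diagonal,
# a strong positive pair (X0, X1) and a negative pair (X2, X3).
n_features = 6
cov = np.full((n_features, n_features), 0.25)
np.fill_diagonal(cov, 1.0)
cov[0, 1] = cov[1, 0] = 0.75
cov[2, 3] = cov[3, 2] = -0.55

# A valid covariance matrix is symmetric positive definite,
# so every eigenvalue should be strictly positive.
eigenvalues = np.linalg.eigvalsh(cov)
print("smallest eigenvalue:", eigenvalues.min())

# With unit variances, cov is also the target correlation matrix;
# compare it with the correlation of a large sample.
rng = np.random.default_rng(0)
X = rng.multivariate_normal(np.zeros(n_features), cov, size=5000)
emp_corr = np.corrcoef(X, rowvar=False)
print("max |empirical - target|:", np.abs(emp_corr - cov).max())
```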

Background Correlation

correlation_heatmap clusters features using 1 - |correlation| as the distance, plots the correlation matrix in the clustered order, and returns the figure, axes, and reordered feature names.

[3]:
fig, ax, reordered = correlation_heatmap(
    X_background,
    feature_names,
    show=False,
)
display(fig)
plt.close(fig)
reordered
(Figure: clustered correlation heatmap of the background features.)
[3]:
['X4', 'X0', 'X1', 'X5', 'X2', 'X3']
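The following is a minimal sketch of clustering features on the distance 1 - |correlation| with SciPy, for intuition about the reordering. The average-linkage choice and the helper name cluster_order are assumptions for illustration. This is not fdfi's implementation, so its order need not match the output above. Hierarchical clustering keeps each merged cluster contiguous in the leaf order, so strongly correlated pairs such as X0 and X1 end up next to each other.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.spatial.distance import squareform


def cluster_order(X, feature_names, method="average"):
    """Hypothetical helper: order features by hierarchical clustering on 1 - |corr|.

    The linkage method is an assumption; fdfi may use a different one.
    """
    corr = np.corrcoef(X, rowvar=False)
    dist = 1.0 - np.abs(corr)
    # Force exact symmetry and a zero diagonal before condensing.
    dist = (dist + dist.T) / 2.0
    np.fill_diagonal(dist, 0.0)
    condensed = squareform(dist, checks=False)
    order = leaves_list(linkage(condensed, method=method))
    return [feature_names[i] for i in order]


# Stand-in for X_background with the same correlation pattern (not the tutorial's draw).
n_features = 6
cov = np.full((n_features, n_features), 0.25)
np.fill_diagonal(cov, 1.0)
cov[0, 1] = cov[1, 0] = 0.75
cov[2, 3] = cov[3, 2] = -0.55
rng = np.random.default_rng(42)
X = rng.multivariate_normal(np.zeros(n_features), cov, size=200)
names = [f"X{i}" for i in range(n_features)]
print(cluster_order(X, names))
```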

Global Scores

summary_bar draws an aggregate score vector (phi_X or phi_Z) as a bar chart with its standard errors. It returns the figure, the axes, and an importance table sorted by score.

[4]:
fig, ax, importance_table = summary_bar(
    results["phi_X"],
    results["se_X"],
    feature_names,
    show=False,
)
display(fig)
plt.close(fig)
importance_table.head()
(Figure: bar chart of phi_X with standard errors.)
[4]:
  feature       phi        se
0      X0  1.380897  0.266202
1      X1  0.598447  0.080826
2      X2  0.376329  0.062175
3      X4  0.290675  0.064366
4      X3  0.238167  0.026621
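The importance table above is sorted by score from largest to smallest. Below is a minimal sketch of that sort, followed by a horizontal bar chart with standard-error whiskers. The helper importance_bar, the one-standard-error whisker width, and the example numbers are hypothetical. They are not summary_bar's actual internals.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend for scripts; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np


def importance_bar(phi, se, feature_names):
    """Hypothetical helper: sort features by score and draw bars with SE whiskers."""
    phi = np.asarray(phi, dtype=float)
    se = np.asarray(se, dtype=float)
    order = np.argsort(phi)[::-1]  # feature indices, largest score first
    names = [feature_names[i] for i in order]

    fig, ax = plt.subplots(figsize=(5, 3))
    # barh draws from the bottom up, so reverse to put the top feature on top.
    ax.barh(names[::-1], phi[order][::-1], xerr=se[order][::-1], capsize=3)
    ax.set_xlabel("importance score (whiskers: +/- 1 SE, assumed)")
    return fig, ax, names, phi[order], se[order]


# Illustrative inputs only; these are not fdfi results.
fig, ax, names, phi_sorted, se_sorted = importance_bar(
    [0.9, 0.1, 0.5], [0.1, 0.05, 0.08], ["a", "b", "c"]
)
plt.close(fig)
print(list(zip(names, phi_sorted, se_sorted)))
```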

Per-sample UEIFs

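After the explainer runs, explainer.ueifs_X stores the per-sample X-space UEIFs, one row per test sample (row 0 pairs with X_test[0]). Use the full matrix for the summary and dependence views, and a single row for the waterfall and force views.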
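The cells below pair explainer.ueifs_X with X_test: the full matrix for dependence_plot, and row 0 with X_test[0] for waterfall_plot. This suggests both arrays are laid out as samples by features. Under that assumption, a dependence view is essentially a scatter of one UEIF column against the same feature's raw values, optionally colored by a second feature. The helper simple_dependence and the random demo arrays are hypothetical stand-ins, not fdfi's dependence_plot.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend for scripts; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np


def simple_dependence(j, values, X, interaction=None, names=None):
    """Hypothetical helper: scatter column j of `values` against feature j of X.

    Assumes `values` and X are both (n_samples, n_features) and row-aligned.
    If `interaction` is given, points are colored by that feature's values.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    scatter_kwargs = {"s": 12}
    if interaction is not None:
        scatter_kwargs.update(c=X[:, interaction], cmap="coolwarm")
    points = ax.scatter(X[:, j], values[:, j], **scatter_kwargs)
    if interaction is not None:
        label = names[interaction] if names else f"feature {interaction}"
        fig.colorbar(points, ax=ax, label=label)
    feature_label = names[j] if names else f"feature {j}"
    ax.set_xlabel(feature_label)
    ax.set_ylabel(f"per-sample value for {feature_label}")
    return fig, ax


# Random demo arrays standing in for X_test and explainer.ueifs_X.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(50, 3))
values_demo = 1.5 * X_demo
fig, ax = simple_dependence(0, values_demo, X_demo, interaction=1, names=["X0", "X1", "X2"])
plt.close(fig)
```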

[5]:
fig, ax = summary_plot(
    explainer.ueifs_X,
    features=X_test,
    feature_names=feature_names,
    show=False,
)
display(fig)
plt.close(fig)

fig, ax = dependence_plot(
    "X0",
    explainer.ueifs_X,
    X_test,
    feature_names=feature_names,
    interaction_index="X1",
    show=False,
)
display(fig)
plt.close(fig)

fig, ax = waterfall_plot(
    explainer.ueifs_X[0],
    features=X_test[0],
    feature_names=feature_names,
    max_display=5,
    show=False,
)
display(fig)
plt.close(fig)

fig, ax = force_plot(
    0.0,  # base value the per-feature contributions are drawn from
    explainer.ueifs_X[0],
    features=X_test[0],
    feature_names=feature_names,
    max_display=5,
    show=False,
)
display(fig)
plt.close(fig)
(Figures: summary plot; dependence plot of X0 colored by X1; waterfall plot and force plot for the first test sample.)
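The max_display=5 argument limits how many bars the waterfall and force views draw. A common convention, used by SHAP's waterfall plot, keeps the max_display - 1 largest contributions by absolute value and folds the rest into a single "other features" bar. Whether fdfi's waterfall_plot does exactly this is an assumption. The hypothetical helper below computes each bar's start and end under that convention. It begins at a base value (0.0 in the force_plot call above) and ends at the base value plus the sum of all contributions.

```python
import numpy as np


def waterfall_steps(values, feature_names, base_value=0.0, max_display=5):
    """Hypothetical helper: (label, start, end) for each waterfall bar.

    The largest-|value| features get their own bar; if there are more than
    max_display features, the remainder is summed into one final bar.
    """
    values = np.asarray(values, dtype=float)
    order = np.argsort(-np.abs(values))  # largest magnitude first
    n_keep = max_display - 1 if len(values) > max_display else len(values)
    rows = [(feature_names[i], values[i]) for i in order[:n_keep]]
    rest = order[n_keep:]
    if len(rest) > 0:
        rows.append((f"{len(rest)} other features", values[rest].sum()))

    steps, running = [], base_value
    for label, v in rows:
        steps.append((label, running, running + v))
        running += v
    return steps


# Illustrative per-sample values for six features; not fdfi output.
values_demo = np.array([0.5, -1.2, 0.1, 0.05, 0.3, -0.02])
names_demo = [f"X{i}" for i in range(6)]
steps = waterfall_steps(values_demo, names_demo, base_value=0.0, max_display=5)
for label, start, end in steps:
    print(f"{label:>16}: {start:+.2f} -> {end:+.2f}")
```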

Inference and Diagnostics

confidence_interval_plot takes the dictionary returned by conf_int(). diagnostics_plot takes the diagnostics dictionary that the OT, EOT, and Flow explainers share, available as explainer.diagnostics.

[6]:
fig, ax = confidence_interval_plot(
    ci,
    feature_names=feature_names,
    show=False,
)
display(fig)
plt.close(fig)

fig, axes = diagnostics_plot(
    explainer.diagnostics,
    feature_names=feature_names,
    show=False,
)
display(fig)
plt.close(fig)
(Figures: confidence interval plot; explainer diagnostics plot.)
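A confidence-interval plot shows one point estimate per feature, with whiskers for its interval. With alternative="greater" the intervals are one-sided, so the upper bound can be infinite and must be clipped before drawing. The sketch below assumes a hypothetical dictionary layout with "estimate", "lower", and "upper" arrays. The real keys returned by explainer.conf_int() may differ, so adapt the lookups. The dashed line at zero is an assumed reference for a test of whether a score exceeds zero.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend for scripts; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np


def ci_plot(ci, feature_names):
    """Hypothetical helper: point estimates with (possibly one-sided) CI whiskers.

    Assumes `ci` holds equal-length arrays under the hypothetical keys
    "estimate", "lower", and "upper"; fdfi's conf_int() output may differ.
    """
    est = np.asarray(ci["estimate"], dtype=float)
    lower = np.asarray(ci["lower"], dtype=float)
    upper = np.asarray(ci["upper"], dtype=float)

    # Infinite bounds (e.g. the upper end of a "greater" interval) cannot be
    # drawn, so clip every bound to a margin around the finite values.
    finite = np.concatenate([est, lower, upper])
    finite = finite[np.isfinite(finite)]
    span = (finite.max() - finite.min()) or 1.0
    left, right = finite.min() - 0.1 * span, finite.max() + 0.1 * span
    lower_drawn = np.clip(lower, left, right)
    upper_drawn = np.clip(upper, left, right)

    y = np.arange(len(est))
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.errorbar(est, y, xerr=[est - lower_drawn, upper_drawn - est], fmt="o", capsize=3)
    ax.axvline(0.0, color="grey", linestyle="--", linewidth=1)  # assumed null value
    ax.set_yticks(y)
    ax.set_yticklabels(feature_names)
    ax.set_xlim(left, right)
    return fig, ax


# Illustrative one-sided intervals; not fdfi output.
demo_ci = {
    "estimate": [1.0, 0.2, -0.1],
    "lower": [0.5, -0.1, -0.4],
    "upper": [np.inf, np.inf, np.inf],
}
fig, ax = ci_plot(demo_ci, ["X0", "X1", "X2"])
plt.close(fig)
```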