{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "f2756f07", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:58.701768Z", "iopub.status.busy": "2026-05-17T00:41:58.701493Z", "iopub.status.idle": "2026-05-17T00:41:59.169078Z", "shell.execute_reply": "2026-05-17T00:41:59.168507Z" } }, "outputs": [], "source": [ "import fdfi\n", "print('FDFI version:', fdfi.__version__)" ] }, { "cell_type": "markdown", "id": "64771235", "metadata": { "papermill": { "duration": 0.005461, "end_time": "2026-02-12T05:43:29.993132", "exception": false, "start_time": "2026-02-12T05:43:29.987671", "status": "completed" }, "tags": [] }, "source": [ "# OTExplainer: Gaussian Optimal Transport\n", "\n", "This tutorial provides a deep dive into the `OTExplainer`, which uses Gaussian optimal transport for computing feature importance.\n", "\n", "## What You'll Learn\n", "\n", "1. Mathematical foundation of Gaussian OT\n", "2. Key hyperparameters and their effects\n", "3. When to use OTExplainer vs other methods\n", "4. Best practices for real-world usage" ] }, { "cell_type": "code", "execution_count": null, "id": "1b6990e4", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.170912Z", "iopub.status.busy": "2026-05-17T00:41:59.170748Z", "iopub.status.idle": "2026-05-17T00:41:59.413788Z", "shell.execute_reply": "2026-05-17T00:41:59.413280Z" }, "papermill": { "duration": 0.785167, "end_time": "2026-02-12T05:43:30.781871", "exception": false, "start_time": "2026-02-12T05:43:29.996704", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from fdfi.explainers import OTExplainer\n", "from fdfi.plots import correlation_heatmap, diagnostics_plot, summary_bar\n", "\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "id": "c051506a", "metadata": { "papermill": { "duration": 0.036953, "end_time": "2026-02-12T05:43:30.820461", "exception": false, "start_time": "2026-02-12T05:43:30.783508", "status": "completed" }, "tags": [] }, "source": [ "## Mathematical Background\n", "\n", "The key insight of OTExplainer is to use the **Gaussian optimal transport map** to create counterfactual distributions.\n", "\n", "Given data $X$ with mean $\\mu$ and covariance $\\Sigma$, we compute:\n", "\n", "$$Z = L^{-1}(X - \\mu)$$\n", "\n", "where $\\Sigma = LL^T$ (Cholesky decomposition). In Z-space, features are uncorrelated.\n", "\n", "To measure the importance of feature $j$, we:\n", "1. Replace $Z_j$ with an independent sample from $N(0, 1)$\n", "2. Transform back to X-space\n", "3. Compare model outputs" ] }, { "cell_type": "markdown", "id": "e3a59f3d", "metadata": { "papermill": { "duration": 0.001245, "end_time": "2026-02-12T05:43:30.823037", "exception": false, "start_time": "2026-02-12T05:43:30.821792", "status": "completed" }, "tags": [] }, "source": [ "## Setup: Create Correlated Data\n", "\n", "Let's create data with correlated features to see how OTExplainer handles dependencies:" ] }, { "cell_type": "code", "execution_count": null, "id": "04031dcd", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.415628Z", "iopub.status.busy": "2026-05-17T00:41:59.415469Z", "iopub.status.idle": "2026-05-17T00:41:59.418848Z", "shell.execute_reply": "2026-05-17T00:41:59.418443Z" }, "papermill": { "duration": 0.007164, "end_time": "2026-02-12T05:43:30.831368", "exception": false, "start_time": "2026-02-12T05:43:30.824204", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create covariance matrix with correlations\n", "n_features = 5\n", "correlation = 0.7\n", "\n", "cov_matrix = np.eye(n_features)\n", "cov_matrix[0, 1] = cov_matrix[1, 0] = correlation # Features 0 and 1 are correlated\n", "cov_matrix[2, 3] = cov_matrix[3, 2] = correlation # Features 2 and 3 are correlated\n", "\n", "# Generate correlated data\n", "n_samples = 500\n", "X_train = np.random.multivariate_normal(\n", " mean=np.zeros(n_features),\n", " cov=cov_matrix,\n", " size=n_samples\n", ")\n", "\n", "print(\"Correlation matrix of training data:\")\n", "print(np.corrcoef(X_train.T).round(2))" ] }, { "cell_type": "code", "execution_count": null, "id": "d88e79e9", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.420208Z", "iopub.status.busy": "2026-05-17T00:41:59.420117Z", "iopub.status.idle": "2026-05-17T00:41:59.422226Z", "shell.execute_reply": "2026-05-17T00:41:59.421828Z" }, "papermill": { "duration": 0.005076, "end_time": "2026-02-12T05:43:30.837969", "exception": false, "start_time": "2026-02-12T05:43:30.832893", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create model: only feature 0 matters\n", "def model(X):\n", " return X[:, 0] ** 2\n", "\n", "# Test data\n", "X_test = np.random.multivariate_normal(\n", " mean=np.zeros(n_features),\n", " cov=cov_matrix,\n", " size=50\n", ")" ] }, { "cell_type": "markdown", "id": "c805f8df", "metadata": { "papermill": { "duration": 0.001252, "end_time": "2026-02-12T05:43:30.840651", "exception": false, "start_time": "2026-02-12T05:43:30.839399", "status": "completed" }, "tags": [] }, "source": [ "## Effect of nsamples\n", "\n", "The `nsamples` parameter controls Monte Carlo variance. Let's see its effect:" ] }, { "cell_type": "code", "execution_count": null, "id": "cb0dc14d", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.423594Z", "iopub.status.busy": "2026-05-17T00:41:59.423501Z", "iopub.status.idle": "2026-05-17T00:41:59.680496Z", "shell.execute_reply": "2026-05-17T00:41:59.679911Z" }, "papermill": { "duration": 0.180749, "end_time": "2026-02-12T05:43:31.022720", "exception": false, "start_time": "2026-02-12T05:43:30.841971", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Compare different nsamples values\n", "nsamples_values = [10, 50, 200]\n", "results_by_nsamples = {}\n", "feature_names = [f\"X{i}\" for i in range(n_features)]\n", "\n", "for ns in nsamples_values:\n", " explainer = OTExplainer(model, data=X_train, nsamples=ns)\n", " results_by_nsamples[ns] = explainer(X_test)\n", "\n", "fig, axes = plt.subplots(1, len(nsamples_values), figsize=(13, 4), sharex=True)\n", "for ax, ns in zip(axes, nsamples_values):\n", " summary_bar(\n", " results_by_nsamples[ns][\"phi_X\"],\n", " results_by_nsamples[ns][\"se_X\"],\n", " feature_names,\n", " ax=ax,\n", " show=False,\n", " title=f\"nsamples={ns}\",\n", " max_display=n_features,\n", " )\n", "fig" ] }, { "cell_type": "markdown", "id": "8ccc9e70", "metadata": { "papermill": { "duration": 0.001628, "end_time": "2026-02-12T05:43:31.026370", "exception": false, "start_time": "2026-02-12T05:43:31.024742", "status": "completed" }, "tags": [] }, "source": [ "**Key observation**: Higher `nsamples` gives smaller error bars (lower variance) but takes longer to compute." ] }, { "cell_type": "markdown", "id": "db4df2f5", "metadata": { "papermill": { "duration": 0.001524, "end_time": "2026-02-12T05:43:31.029749", "exception": false, "start_time": "2026-02-12T05:43:31.028225", "status": "completed" }, "tags": [] }, "source": [ "## Effect of sampling_method\n", "\n", "OTExplainer supports three sampling methods for counterfactual generation:" ] }, { "cell_type": "code", "execution_count": null, "id": "51487042", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.681968Z", "iopub.status.busy": "2026-05-17T00:41:59.681841Z", "iopub.status.idle": "2026-05-17T00:41:59.728322Z", "shell.execute_reply": "2026-05-17T00:41:59.727795Z" }, "papermill": { "duration": 0.011995, "end_time": "2026-02-12T05:43:31.043695", "exception": false, "start_time": "2026-02-12T05:43:31.031700", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Compare sampling methods\n", "sampling_methods = [\"resample\", \"permutation\", \"normal\"]\n", "results_by_method = {}\n", "\n", "for method in sampling_methods:\n", " explainer = OTExplainer(\n", " model, \n", " data=X_train, \n", " nsamples=100,\n", " sampling_method=method\n", " )\n", " results_by_method[method] = explainer(X_test)\n", "\n", "# Compare results\n", "print(\"Feature importance by sampling method:\")\n", "print(\"-\" * 50)\n", "print(f\"{'Feature':>8}\", end=\"\")\n", "for method in sampling_methods:\n", " print(f\"{method:>12}\", end=\"\")\n", "print()\n", "print(\"-\" * 50)\n", "\n", "for i in range(n_features):\n", " print(f\"{i:>8}\", end=\"\")\n", " for method in sampling_methods:\n", " phi = results_by_method[method][\"phi_X\"][i]\n", " print(f\"{phi:>12.4f}\", end=\"\")\n", " print()" ] }, { "cell_type": "markdown", "id": "7f3e40a2", "metadata": { "papermill": { "duration": 0.001749, "end_time": "2026-02-12T05:43:31.047408", "exception": false, "start_time": "2026-02-12T05:43:31.045659", "status": "completed" }, "tags": [] }, "source": [ "**Sampling methods explained:**\n", "\n", "- `resample`: Sample from the background data (preserves marginal distribution)\n", "- `permutation`: Permute values within test set (no new values introduced)\n", "- `normal`: Sample from standard normal (strongest Gaussian assumption)" ] }, { "cell_type": "markdown", "id": "a506c2ce", "metadata": { "papermill": { "duration": 0.001681, "end_time": "2026-02-12T05:43:31.050909", "exception": false, "start_time": "2026-02-12T05:43:31.049228", "status": "completed" }, "tags": [] }, "source": [ "## Visualizing the Z-space Transformation\n", "\n", "Let's visualize how OTExplainer transforms correlated data to uncorrelated Z-space:" ] }, { "cell_type": "code", "execution_count": null, "id": "fe9642eb", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.730033Z", "iopub.status.busy": "2026-05-17T00:41:59.729914Z", "iopub.status.idle": "2026-05-17T00:41:59.861045Z", "shell.execute_reply": "2026-05-17T00:41:59.860240Z" }, "papermill": { "duration": 0.128911, "end_time": "2026-02-12T05:43:31.181421", "exception": false, "start_time": "2026-02-12T05:43:31.052510", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create explainer and access internal matrices\n", "explainer = OTExplainer(model, data=X_train, nsamples=50)\n", "feature_names = [f\"X{i}\" for i in range(n_features)]\n", "\n", "# Background feature correlation before whitening\n", "correlation_heatmap(X_train, feature_names, show=False)\n", "\n", "# Transform to Z-space\n", "Z_train = (X_train - explainer.mean) @ explainer.L_inv\n", "\n", "# Plot X-space vs Z-space\n", "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n", "\n", "# X-space (correlated)\n", "axes[0].scatter(X_train[:, 0], X_train[:, 1], alpha=0.5, s=10)\n", "axes[0].set_xlabel(\"X₀\")\n", "axes[0].set_ylabel(\"X₁\")\n", "axes[0].set_title(f\"X-space (correlation = {np.corrcoef(X_train[:, 0], X_train[:, 1])[0, 1]:.2f})\")\n", "axes[0].axis('equal')\n", "\n", "# Z-space (uncorrelated)\n", "axes[1].scatter(Z_train[:, 0], Z_train[:, 1], alpha=0.5, s=10)\n", "axes[1].set_xlabel(\"Z₀\")\n", "axes[1].set_ylabel(\"Z₁\")\n", "axes[1].set_title(f\"Z-space (correlation = {np.corrcoef(Z_train[:, 0], Z_train[:, 1])[0, 1]:.2f})\")\n", "axes[1].axis('equal')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "c5532f3f", "metadata": {}, "source": [ "## Shared Diagnostics and Summary\n", "\n", "Use the built-in diagnostics and `summary()` methods for consistent reporting\n", "across OT/EOT/Flow explainers.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "866ce1ea", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:59.862854Z", "iopub.status.busy": "2026-05-17T00:41:59.862725Z", "iopub.status.idle": "2026-05-17T00:41:59.866872Z", "shell.execute_reply": "2026-05-17T00:41:59.866465Z" } }, "outputs": [], "source": [ "_ = explainer(X_test)\n", "diag = explainer.diagnostics\n", "print(\"OT diagnostics\")\n", "print(\"-\" * 50)\n", "print(f\"Latent independence (median dCor): {diag['latent_independence_median']:.6f} [{diag['latent_independence_label']}]\")\n", "print(f\"Distribution fidelity (MMD): {diag['distribution_fidelity_mmd']:.6f} [{diag['distribution_fidelity_label']}]\")\n", "\n", "print(\"\\nX-space summary\")\n", "_ = explainer.summary(alpha=0.05, target='X', alternative='greater')\n", "\n", "\n", "diagnostics_plot(diag, feature_names=feature_names, show=False)\n" ] }, { "cell_type": "markdown", "id": "97247a77", "metadata": { "papermill": { "duration": 0.001894, "end_time": "2026-02-12T05:43:31.185570", "exception": false, "start_time": "2026-02-12T05:43:31.183676", "status": "completed" }, "tags": [] }, "source": [ "## Best Practices\n", "\n", "### When to use OTExplainer\n", "\n", "✅ **Good for:**\n", "- Continuous features\n", "- Roughly Gaussian data\n", "- Fast computation\n", "- Stable results\n", "\n", "❌ **Consider EOTExplainer for:**\n", "- Heavily non-Gaussian data\n", "- Multimodal distributions\n", "- Mixed categorical/continuous features\n", "\n", "### Recommended settings\n", "\n", "```python\n", "explainer = OTExplainer(\n", " model,\n", " data=X_train,\n", " nsamples=50, # Good balance of speed/accuracy\n", " sampling_method=\"resample\", # Preserves marginal distribution\n", ")\n", "```" ] }, { "cell_type": "markdown", "id": "4fd67813", "metadata": { "papermill": { "duration": 0.001946, "end_time": "2026-02-12T05:43:31.189580", "exception": false, "start_time": "2026-02-12T05:43:31.187634", "status": "completed" }, "tags": [] }, "source": [ "## Summary\n", "\n", "Key takeaways:\n", "\n", "1. OTExplainer uses Gaussian OT to disentangle correlated features.\n", "2. Higher `nsamples` reduces variance at the cost of computation time.\n", "3. `sampling_method=\"resample\"` is recommended for most cases.\n", "4. The transformation to Z-space removes correlations for clean attribution.\n", "5. Shared diagnostics and `summary()` provide standardized inference reporting.\n" ] } ], "metadata": { "kernelspec": { "display_name": "dfi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" }, "papermill": { "default_parameters": {}, "duration": 2.254231, "end_time": "2026-02-12T05:43:31.408797", "environment_variables": {}, "exception": null, "input_path": "docs/tutorials/ot_explainer.ipynb", "output_path": "docs/tutorials/ot_explainer.ipynb", "parameters": {}, "start_time": "2026-02-12T05:43:29.154566", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }