{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "c149bc65", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:00.685447Z", "iopub.status.busy": "2026-05-17T00:42:00.685295Z", "iopub.status.idle": "2026-05-17T00:42:01.132530Z", "shell.execute_reply": "2026-05-17T00:42:01.132105Z" } }, "outputs": [], "source": [ "import fdfi\n", "print('FDFI version:', fdfi.__version__)" ] }, { "cell_type": "markdown", "id": "c02efd2d", "metadata": { "papermill": { "duration": 0.005769, "end_time": "2026-02-12T05:43:37.568161", "exception": false, "start_time": "2026-02-12T05:43:37.562392", "status": "completed" }, "tags": [] }, "source": [ "# EOTExplainer: Semicontinuous Entropic Optimal Transport\n", "\n", "This tutorial covers the `EOTExplainer`, which uses **semicontinuous entropic optimal transport** with **population backward attribution** (best linear projection) to compute disentangled feature importance.\n", "\n", "## What You'll Learn\n", "\n", "1. How EOT whitening + semicontinuous transport disentangles features\n", "2. How the population backward attribution maps Z-importance to X-importance\n", "3. How to run attribution inference with confidence intervals\n", "4. 
How epsilon controls the transport shrinkage\n" ] }, { "cell_type": "code", "execution_count": null, "id": "22022569", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.134209Z", "iopub.status.busy": "2026-05-17T00:42:01.134075Z", "iopub.status.idle": "2026-05-17T00:42:01.136017Z", "shell.execute_reply": "2026-05-17T00:42:01.135613Z" }, "papermill": { "duration": 0.788061, "end_time": "2026-02-12T05:43:38.359344", "exception": false, "start_time": "2026-02-12T05:43:37.571283", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "from fdfi.explainers import EOTExplainer\n", "from fdfi.plots import confidence_interval_plot, diagnostics_plot, summary_bar\n", "\n", "np.random.seed(42)\n" ] }, { "cell_type": "markdown", "id": "bc1fdd30", "metadata": { "papermill": { "duration": 0.001562, "end_time": "2026-02-12T05:43:38.396820", "exception": false, "start_time": "2026-02-12T05:43:38.395258", "status": "completed" }, "tags": [] }, "source": [ "## Why Semicontinuous EOT?\n", "\n", "Gaussian OT whitens data via $\\Sigma^{-1/2}$, which assumes linear structure. 
**Semicontinuous EOT** solves the entropic transport problem between the empirical source and a continuous $\\mathcal{N}(0, I)$ target analytically:\n", "\n", "$$Z = s \\cdot X_{\\text{whitened}}, \\quad s = \\frac{2}{2 + \\varepsilon}$$\n", "\n", "The **population backward attribution** computes the best linear projection $E[X_w \\mid Z]$ using the analytically known coupling moments:\n", "\n", "$$M_w = E_\\pi[ZZ^\\top]^{-1} E_\\pi[ZX_w^\\top]$$\n", "\n", "Then the weight matrix $W = L \\cdot M_w$ maps Z-space importance to X-space via:\n", "\n", "$$\\phi_{X,j} = \\sum_k W_{jk}^2 \\cdot \\phi_{Z,k}$$\n" ] }, { "cell_type": "markdown", "id": "caa03543", "metadata": { "papermill": { "duration": 0.001674, "end_time": "2026-02-12T05:43:38.399974", "exception": false, "start_time": "2026-02-12T05:43:38.398300", "status": "completed" }, "tags": [] }, "source": [ "## Synthetic Data: Relevant vs Null Features\n", "\n", "We build a dataset with correlated features. The model directly uses $X_0, X_2, X_4$, but because $X_1$ is highly correlated with $X_0$ ($\\rho = 0.7$) and $X_3$ is correlated with $X_2$ ($\\rho = 0.5$), they also carry predictive signal. FDFI's design goal is to detect **all features that provide predictive information**, so the relevant set is $\\{X_0, X_1, X_2, X_3, X_4\\}$, while $X_5, \\ldots, X_9$ are truly null." 
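, "\n", "\n", "As a rough worked check of why the correlated features count as relevant, assume the features are zero-mean, unit-variance, and jointly Gaussian (this matches the data-generating code in this tutorial, but it is an assumption of the sketch, not a property of FDFI). The conditional mean is then linear in the conditioning feature:\n", "\n", "$$E[X_0 \\mid X_1] = 0.7\\, X_1, \\qquad E[X_2 \\mid X_3] = 0.5\\, X_3$$\n", "\n", "So a model term $\\beta_0 X_0$ has conditional mean $0.7\\, \\beta_0 X_1$ given $X_1$ alone, which is nonzero whenever $\\beta_0 \\neq 0$. Under the same assumptions $X_5, \\ldots, X_9$ are independent of $X_0, X_2, X_4$, so the corresponding conditional means are zero and these features carry no predictive signal."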
] }, { "cell_type": "code", "execution_count": null, "id": "89ec974f", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.137501Z", "iopub.status.busy": "2026-05-17T00:42:01.137406Z", "iopub.status.idle": "2026-05-17T00:42:01.142017Z", "shell.execute_reply": "2026-05-17T00:42:01.141596Z" }, "papermill": { "duration": 0.142511, "end_time": "2026-02-12T05:43:38.544232", "exception": false, "start_time": "2026-02-12T05:43:38.401721", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Correlated synthetic data with known active features\n", "n_train = 400\n", "n_test = 150\n", "d = 10\n", "\n", "# Build a covariance matrix with block correlations\n", "rng = np.random.default_rng(42)\n", "Sigma = np.eye(d)\n", "# Correlate features 0-1 and 2-3\n", "Sigma[0, 1] = Sigma[1, 0] = 0.7\n", "Sigma[2, 3] = Sigma[3, 2] = 0.5\n", "\n", "X_train = rng.multivariate_normal(np.zeros(d), Sigma, size=n_train)\n", "X_test = rng.multivariate_normal(np.zeros(d), Sigma, size=n_test)\n", "\n", "# Model directly uses features 0, 2, 4\n", "active_idx = [0, 2, 4]\n", "# Features 1 and 3 carry predictive signal via correlation\n", "# → all 5 are \"relevant\"; features 5-9 are truly null\n", "relevant_idx = [0, 1, 2, 3, 4]\n", "null_idx = [5, 6, 7, 8, 9]\n", "\n", "def exp_model(X):\n", " return 3.0 * X[:, 0] + 2.0 * X[:, 2] + 1.5 * X[:, 4]\n", "\n", "print(\"Train shape:\", X_train.shape)\n", "print(\"Test shape:\", X_test.shape)\n", "print(\"Relevant features:\", relevant_idx, \"(model uses 0,2,4; 1,3 correlated)\")\n", "print(\"Null features:\", null_idx)\n", "print(\"Correlation(X0, X1):\", f\"{np.corrcoef(X_train[:, 0], X_train[:, 1])[0, 1]:.3f}\")\n", "print(\"Correlation(X2, X3):\", f\"{np.corrcoef(X_train[:, 2], X_train[:, 3])[0, 1]:.3f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "37bd4e00", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.143492Z", "iopub.status.busy": "2026-05-17T00:42:01.143392Z", 
"iopub.status.idle": "2026-05-17T00:42:01.146224Z", "shell.execute_reply": "2026-05-17T00:42:01.145695Z" }, "papermill": { "duration": 0.006289, "end_time": "2026-02-12T05:43:38.552692", "exception": false, "start_time": "2026-02-12T05:43:38.546403", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Sanity check: model predictions\n", "y_preview = exp_model(X_test[:5])\n", "print(\"Preview predictions:\", np.round(y_preview, 3))\n", "print(\"Response variance:\", f\"{np.var(exp_model(X_train)):.3f}\")\n" ] }, { "cell_type": "markdown", "id": "63ea52c7", "metadata": { "papermill": { "duration": 0.001823, "end_time": "2026-02-12T05:43:38.556565", "exception": false, "start_time": "2026-02-12T05:43:38.554742", "status": "completed" }, "tags": [] }, "source": [ "## Basic EOTExplainer Usage\n", "\n", "Create an explainer, compute importance, and inspect results.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c869c63a", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.147722Z", "iopub.status.busy": "2026-05-17T00:42:01.147617Z", "iopub.status.idle": "2026-05-17T00:42:01.170864Z", "shell.execute_reply": "2026-05-17T00:42:01.170390Z" }, "papermill": { "duration": 0.02748, "end_time": "2026-02-12T05:43:38.585709", "exception": false, "start_time": "2026-02-12T05:43:38.558229", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "explainer = EOTExplainer(\n", " exp_model,\n", " data=X_train,\n", " nsamples=60,\n", " auto_epsilon=True,\n", " random_state=0,\n", ")\n", "\n", "results = explainer(X_test)\n", "phi_X = results[\"phi_X\"]\n", "\n", "print(\"Feature importance (phi_X):\")\n", "print(\"-\" * 55)\n", "print(f\"{'Feature':>8} {'phi_X':>10} {'Status':>12}\")\n", "print(\"-\" * 55)\n", "for i in range(d):\n", " status = \"model\" if i in active_idx else (\"correlated\" if i in relevant_idx else \"null\")\n", " print(f\"{'X_' + str(i):>8} {phi_X[i]:>10.4f} {status:>12}\")\n", "\n", "print(f\"\\nAuto 
epsilon: {explainer.epsilon:.4f}\")\n", "print(f\"Forward shrinkage s: {explainer.s_fwd:.4f}\")\n", "print(f\"Backward weight matrix W shape: {explainer.W.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feature_names = [f\"X{i}\" for i in range(d)]\n", "\n", "summary_bar(\n", " results[\"phi_X\"],\n", " results[\"se_X\"],\n", " feature_names,\n", " show=False,\n", ")" ], "id": "eot-explainer-09" }, { "cell_type": "markdown", "id": "692d413d", "metadata": {}, "source": [ "## Attribution Inference\n", "\n", "Use one-sided testing to identify features with significant predictive importance. We expect all 5 relevant features ($X_0$–$X_4$) to be detected, while the 5 null features ($X_5$–$X_9$) should not." ] }, { "cell_type": "code", "execution_count": null, "id": "7c7a6e95", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.172372Z", "iopub.status.busy": "2026-05-17T00:42:01.172259Z", "iopub.status.idle": "2026-05-17T00:42:01.791487Z", "shell.execute_reply": "2026-05-17T00:42:01.790978Z" } }, "outputs": [], "source": [ "# Default conf_int: margin_method=\"auto\" (gap for d<30, mixture for d>=30)\n", "ci = explainer.conf_int(\n", " alpha=0.05,\n", " target=\"X\",\n", " alternative=\"greater\",\n", " verbose=True,\n", ")\n", "\n", "attribution_idx = np.where(ci[\"reject_null\"])[0]\n", "expected = set(relevant_idx)\n", "detected = set(attribution_idx.tolist())\n", "\n", "print(f\"\\nMargin method: {ci['margin_method']}, margin: {ci['margin']:.4f}\")\n", "print(\"Detected features:\", sorted(detected))\n", "print(\"Relevant features:\", sorted(expected))\n", "print(\"True positives:\", sorted(expected & detected))\n", "print(\"False positives:\", sorted(detected - expected))\n", "print(\"Missed:\", sorted(expected - detected))\n", "print()\n", "for i in range(d):\n", " tag = \"*\" if ci[\"reject_null\"][i] else \"\"\n", " status = \"model\" if i in active_idx else (\"corr\" if i in 
relevant_idx else \"null\")\n", " print(f\" X_{i} [{status:>5}]: phi={ci['score'][i]:.4f} se={ci['se'][i]:.4f}\"\n", " f\" z={ci['zscore'][i]:.2f} rank={ci['ranking'][i]:>2} p={ci['pvalue'][i]:.4f} {tag}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feature_names = [f\"X{i}\" for i in range(d)]\n", "\n", "confidence_interval_plot(\n", " ci,\n", " feature_names=feature_names,\n", " show=False,\n", ")" ], "id": "eot-explainer-12" }, { "cell_type": "markdown", "id": "d45e88c2", "metadata": { "papermill": { "duration": 0.002768, "end_time": "2026-02-12T05:43:38.945609", "exception": false, "start_time": "2026-02-12T05:43:38.942841", "status": "completed" }, "tags": [] }, "source": [ "## Z-Space vs X-Space Importance\n", "\n", "The EOT decomposition first computes importance in the disentangled Z-space, then maps back to X-space via the backward weight matrix $W$.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e2241247", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.793329Z", "iopub.status.busy": "2026-05-17T00:42:01.793136Z", "iopub.status.idle": "2026-05-17T00:42:01.796494Z", "shell.execute_reply": "2026-05-17T00:42:01.796060Z" }, "papermill": { "duration": 0.274662, "end_time": "2026-02-12T05:43:39.222726", "exception": false, "start_time": "2026-02-12T05:43:38.948064", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "phi_Z = results[\"phi_Z\"]\n", "phi_X = results[\"phi_X\"]\n", "\n", "print(f\"{'Feature':>8} {'phi_Z':>10} {'phi_X':>10}\")\n", "print(\"-\" * 32)\n", "for i in range(d):\n", " print(f\"{'X_' + str(i):>8} {phi_Z[i]:>10.4f} {phi_X[i]:>10.4f}\")\n", "\n", "print(f\"\\nTotal phi_Z: {phi_Z.sum():.4f}\")\n", "print(f\"Total phi_X: {phi_X.sum():.4f}\")\n", "print(\"\\nNote: phi_Z measures importance in the disentangled space.\")\n", "print(\"phi_X maps it back to original features via the backward weights W.\")\n" ] }, { "cell_type": 
"markdown", "id": "c0d20c3c", "metadata": { "papermill": { "duration": 0.002331, "end_time": "2026-02-12T05:43:39.227767", "exception": false, "start_time": "2026-02-12T05:43:39.225436", "status": "completed" }, "tags": [] }, "source": [ "## Effect of Epsilon on Attribution\n", "\n", "Epsilon sets the strength of the entropic regularization. Smaller epsilon gives sharper transport (closer to exact OT), while larger epsilon shrinks toward Gaussian transport. Under the formula $s = 2 / (2 + \\varepsilon)$ given earlier, the values tried below would correspond to $s \\approx 0.9995$, $0.9950$, and $0.9524$; compare these with the printed `s_fwd`.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "26149aff", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.798163Z", "iopub.status.busy": "2026-05-17T00:42:01.798061Z", "iopub.status.idle": "2026-05-17T00:42:01.849298Z", "shell.execute_reply": "2026-05-17T00:42:01.848814Z" }, "papermill": { "duration": 0.011575, "end_time": "2026-02-12T05:43:39.241748", "exception": false, "start_time": "2026-02-12T05:43:39.230173", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "epsilons = [1e-3, 0.01, 0.1]\n", "all_phi = {}\n", "\n", "for eps in epsilons:\n", "    exp_eps = EOTExplainer(\n", "        exp_model,\n", "        data=X_train,\n", "        nsamples=60,\n", "        epsilon=eps,\n", "        random_state=0,\n", "    )\n", "    res = exp_eps(X_test)\n", "    all_phi[eps] = res[\"phi_X\"]\n", "    # Use :g so that eps=0.001 is not rounded to 0.00 in the printout\n", "    # null_mean averages only the truly null features (5-9), not the correlated ones\n", "    print(f\"eps={eps:g} s={exp_eps.s_fwd:.4f} \"\n", "          f\"active_mean={res['phi_X'][active_idx].mean():.4f} \"\n", "          f\"null_mean={res['phi_X'][null_idx].mean():.4f}\")\n", "\n", "print()\n", "header = f\"{'Feature':>8}\" + \"\".join(f\"{'eps=' + str(e):>12}\" for e in epsilons)\n", "print(header)\n", "print(\"-\" * len(header))\n", "for i in range(d):\n", "    row = f\"{'X_' + str(i):>8}\"\n", "    for eps in epsilons:\n", "        row += f\"{all_phi[eps][i]:>12.4f}\"\n", "    print(row)\n" ] }, { "cell_type": "markdown", "id": "13007a36", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:26:13.234601Z", "iopub.status.busy": "2026-02-17T16:26:13.234489Z", "iopub.status.idle": 
"2026-02-17T16:26:13.237114Z", "shell.execute_reply": "2026-02-17T16:26:13.236656Z" }, "papermill": { "duration": 0.008126, "end_time": "2026-02-12T05:43:39.252595", "exception": false, "start_time": "2026-02-12T05:43:39.244469", "status": "completed" }, "tags": [] }, "source": [ "## Compare with OTExplainer (Gaussian Baseline)\n", "\n", "The `OTExplainer` uses plain Gaussian whitening ($W = L$). The `EOTExplainer` adds the population backward projection ($W = L \\cdot M_w$), which can better handle non-Gaussian structure.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2af05493", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.850804Z", "iopub.status.busy": "2026-05-17T00:42:01.850693Z", "iopub.status.idle": "2026-05-17T00:42:01.869799Z", "shell.execute_reply": "2026-05-17T00:42:01.869337Z" }, "papermill": { "duration": 0.016598, "end_time": "2026-02-12T05:43:39.271857", "exception": false, "start_time": "2026-02-12T05:43:39.255259", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from fdfi.explainers import OTExplainer\n", "\n", "explainer_ot = OTExplainer(\n", " exp_model,\n", " data=X_train,\n", " nsamples=60,\n", " random_state=0,\n", ")\n", "results_ot = explainer_ot(X_test)\n", "\n", "phi_ot = results_ot[\"phi_X\"]\n", "phi_eot = results[\"phi_X\"]\n", "\n", "print(f\"{'Feature':>8} {'OT (Gauss)':>12} {'EOT (Semicont)':>15} {'Status':>10}\")\n", "print(\"-\" * 49)\n", "for i in range(d):\n", " status = \"model\" if i in active_idx else (\"corr\" if i in relevant_idx else \"null\")\n", " print(f\"{'X_' + str(i):>8} {phi_ot[i]:>12.4f} {phi_eot[i]:>15.4f} {status:>10}\")\n", "\n", "ratio_ot = phi_ot[relevant_idx].mean() / phi_ot[null_idx].mean()\n", "ratio_eot = phi_eot[relevant_idx].mean() / phi_eot[null_idx].mean()\n", "print(f\"\\nRelevant/null ratio (OT): {ratio_ot:.2f}x\")\n", "print(f\"Relevant/null ratio (EOT): {ratio_eot:.2f}x\")" ] }, { "cell_type": "markdown", "id": "ee842231", "metadata": { 
"papermill": { "duration": 0.006106, "end_time": "2026-02-12T05:43:39.280514", "exception": false, "start_time": "2026-02-12T05:43:39.274408", "status": "completed" }, "tags": [] }, "source": [ "## Diagnostics and Summary\n", "\n", "Use the `diagnostics` attribute to inspect transport quality and `summary()` for a tabular overview.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6d8207d4", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:42:01.871324Z", "iopub.status.busy": "2026-05-17T00:42:01.871202Z", "iopub.status.idle": "2026-05-17T00:42:01.875764Z", "shell.execute_reply": "2026-05-17T00:42:01.875360Z" }, "papermill": { "duration": 0.035844, "end_time": "2026-02-12T05:43:39.319161", "exception": false, "start_time": "2026-02-12T05:43:39.283317", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "diag = explainer.diagnostics\n", "print(\"Diagnostics:\")\n", "print(f\" Latent independence (median dCor): {diag['latent_independence_median']:.6f} [{diag['latent_independence_label']}]\")\n", "print(f\" Distribution fidelity (MMD): {diag['distribution_fidelity_mmd']:.6f} [{diag['distribution_fidelity_label']}]\")\n", "print()\n", "\n", "# Standardized summary table\n", "_ = explainer.summary(alpha=0.05, target=\"X\", alternative=\"greater\")\n", "\n", "\n", "diagnostics_plot(diag, feature_names=feature_names, show=False)\n" ] }, { "cell_type": "markdown", "id": "872aba4b", "metadata": { "papermill": { "duration": 0.002478, "end_time": "2026-02-12T05:43:39.324421", "exception": false, "start_time": "2026-02-12T05:43:39.321943", "status": "completed" }, "tags": [] }, "source": [ "## Quick Reference\n", "\n", "```python\n", "import numpy as np\n", "from fdfi.explainers import EOTExplainer\n", "\n", "explainer = EOTExplainer(\n", "    model,\n", "    data=X_train,\n", "    nsamples=60,\n", "    auto_epsilon=True,  # median-distance heuristic\n", "    random_state=0,\n", ")\n", "\n", "results = explainer(X_test)\n", "# results[\"phi_X\"] — X-space feature importance\n", "# 
results[\"phi_Z\"] — Z-space (disentangled) importance\n", "\n", "# Attribution inference\n", "ci = explainer.conf_int(alpha=0.05, target=\"X\", alternative=\"greater\")\n", "significant = np.where(ci[\"reject_null\"])[0]\n", "\n", "# Inspect transport quality\n", "explainer.diagnostics\n", "```\n" ] }, { "cell_type": "markdown", "id": "0d89628e", "metadata": { "papermill": { "duration": 0.002542, "end_time": "2026-02-12T05:43:39.329450", "exception": false, "start_time": "2026-02-12T05:43:39.326908", "status": "completed" }, "tags": [] }, "source": [ "## Summary\n", "\n", "Key takeaways:\n", "\n", "1. `EOTExplainer` uses **semicontinuous entropic OT** — the forward map $Z = s \\cdot X_w$ is analytical (no Sinkhorn needed).\n", "2. **Population backward attribution** computes $W = L \\cdot M_w$ using the best linear projection from the coupling moments.\n", "3. FDFI detects **all features with predictive signal**, including correlated features — not only those directly in the model.\n", "4. `epsilon` controls regularization: smaller → closer to exact OT, larger → more Gaussian shrinkage.\n", "5. Use `auto_epsilon=True` for automatic tuning via the median-distance heuristic.\n", "6. `conf_int()` provides rigorous attribution inference with confidence intervals." ] } ], "metadata": { "kernelspec": { "display_name": "dfi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" }, "papermill": { "default_parameters": {}, "duration": 2.71807, "end_time": "2026-02-12T05:43:39.550872", "environment_variables": {}, "exception": null, "input_path": "docs/tutorials/eot_explainer.ipynb", "output_path": "docs/tutorials/eot_explainer.ipynb", "parameters": {}, "start_time": "2026-02-12T05:43:36.832802", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }