{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "c149bc65", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:45.578461Z", "iopub.status.busy": "2026-02-18T00:54:45.578136Z", "iopub.status.idle": "2026-02-18T00:54:45.588565Z", "shell.execute_reply": "2026-02-18T00:54:45.587727Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FDFI version: 0.0.5\n" ] } ], "source": [ "import fdfi\n", "print('FDFI version:', fdfi.__version__)" ] }, { "cell_type": "markdown", "id": "c02efd2d", "metadata": { "papermill": { "duration": 0.005769, "end_time": "2026-02-12T05:43:37.568161", "exception": false, "start_time": "2026-02-12T05:43:37.562392", "status": "completed" }, "tags": [] }, "source": [ "# EOTExplainer: Semicontinuous Entropic Optimal Transport\n", "\n", "This tutorial covers the `EOTExplainer`, which uses **semicontinuous entropic optimal transport** with **population backward attribution** (best linear projection) to compute disentangled feature importance.\n", "\n", "## What You'll Learn\n", "\n", "1. How EOT whitening + semicontinuous transport disentangles features\n", "2. How the population backward attribution maps Z-importance to X-importance\n", "3. How to run attribution inference with confidence intervals\n", "4. 
How epsilon controls the transport shrinkage\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "22022569", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:45.591026Z", "iopub.status.busy": "2026-02-18T00:54:45.590822Z", "iopub.status.idle": "2026-02-18T00:54:47.756268Z", "shell.execute_reply": "2026-02-18T00:54:47.755594Z" }, "papermill": { "duration": 0.788061, "end_time": "2026-02-12T05:43:38.359344", "exception": false, "start_time": "2026-02-12T05:43:37.571283", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "from fdfi.explainers import EOTExplainer\n", "\n", "np.random.seed(42)\n" ] }, { "cell_type": "markdown", "id": "bc1fdd30", "metadata": { "papermill": { "duration": 0.001562, "end_time": "2026-02-12T05:43:38.396820", "exception": false, "start_time": "2026-02-12T05:43:38.395258", "status": "completed" }, "tags": [] }, "source": [ "## Why Semicontinuous EOT?\n", "\n", "Gaussian OT whitens data via $\\Sigma^{-1/2}$, which removes only linear (second-order) dependence. **Semicontinuous EOT** solves, in closed form, the entropic transport problem between the empirical distribution of the whitened data $X_w$ and a continuous $\\mathcal{N}(0, I)$ target. The resulting forward map is a scalar shrinkage:\n", "\n", "$$Z = s \\cdot X_w, \\quad s = \\frac{2}{2 + \\varepsilon}$$\n", "\n", "The **population backward attribution** computes the best linear approximation of $E[X_w \\mid Z]$ from the analytically known moments of the coupling $\\pi$:\n", "\n", "$$M_w = E_\\pi[ZZ^\\top]^{-1} E_\\pi[ZX_w^\\top]$$\n", "\n", "The weight matrix $W = L \\cdot M_w$, where $L$ is the whitening factor that maps $X_w$ back to the original feature scale, then transfers Z-space importance to X-space:\n", "\n", "$$\\phi_{X,j} = \\sum_k W_{jk}^2 \\cdot \\phi_{Z,k}$$\n" ] }, { "cell_type": "markdown", "id": "caa03543", "metadata": { "papermill": { "duration": 0.001674, "end_time": "2026-02-12T05:43:38.399974", "exception": false, "start_time": "2026-02-12T05:43:38.398300", "status": "completed" }, "tags": [] }, "source": [ "## Synthetic Data: Relevant vs Null Features\n", "\n", "We build a dataset with correlated features. The model directly uses $X_0, X_2, X_4$, but because $X_1$ is highly correlated with $X_0$ ($\\rho = 0.7$) and $X_3$ is correlated with $X_2$ ($\\rho = 0.5$), they also carry predictive signal. FDFI's design goal is to detect **all features that provide predictive information**, so the relevant set is $\\{X_0, X_1, X_2, X_3, X_4\\}$, while $X_5, \\ldots, X_9$ are truly null."
] }, { "cell_type": "code", "execution_count": 3, "id": "89ec974f", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:47.758283Z", "iopub.status.busy": "2026-02-18T00:54:47.758129Z", "iopub.status.idle": "2026-02-18T00:54:47.763121Z", "shell.execute_reply": "2026-02-18T00:54:47.762544Z" }, "papermill": { "duration": 0.142511, "end_time": "2026-02-12T05:43:38.544232", "exception": false, "start_time": "2026-02-12T05:43:38.401721", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train shape: (400, 10)\n", "Test shape: (150, 10)\n", "Relevant features: [0, 1, 2, 3, 4] (model uses 0,2,4; 1,3 correlated)\n", "Null features: [5, 6, 7, 8, 9]\n", "Correlation(X0, X1): 0.684\n", "Correlation(X2, X3): 0.481\n" ] } ], "source": [ "# Correlated synthetic data with known active features\n", "n_train = 400\n", "n_test = 150\n", "d = 10\n", "\n", "# Build a covariance matrix with block correlations\n", "rng = np.random.default_rng(42)\n", "Sigma = np.eye(d)\n", "# Correlate features 0-1 and 2-3\n", "Sigma[0, 1] = Sigma[1, 0] = 0.7\n", "Sigma[2, 3] = Sigma[3, 2] = 0.5\n", "\n", "X_train = rng.multivariate_normal(np.zeros(d), Sigma, size=n_train)\n", "X_test = rng.multivariate_normal(np.zeros(d), Sigma, size=n_test)\n", "\n", "# Model directly uses features 0, 2, 4\n", "active_idx = [0, 2, 4]\n", "# Features 1 and 3 carry predictive signal via correlation\n", "# → all 5 are \"relevant\"; features 5-9 are truly null\n", "relevant_idx = [0, 1, 2, 3, 4]\n", "null_idx = [5, 6, 7, 8, 9]\n", "\n", "def exp_model(X):\n", " return 3.0 * X[:, 0] + 2.0 * X[:, 2] + 1.5 * X[:, 4]\n", "\n", "print(\"Train shape:\", X_train.shape)\n", "print(\"Test shape:\", X_test.shape)\n", "print(\"Relevant features:\", relevant_idx, \"(model uses 0,2,4; 1,3 correlated)\")\n", "print(\"Null features:\", null_idx)\n", "print(\"Correlation(X0, X1):\", f\"{np.corrcoef(X_train[:, 0], X_train[:, 1])[0, 1]:.3f}\")\n", 
"print(\"Correlation(X2, X3):\", f\"{np.corrcoef(X_train[:, 2], X_train[:, 3])[0, 1]:.3f}\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "37bd4e00", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:47.764502Z", "iopub.status.busy": "2026-02-18T00:54:47.764390Z", "iopub.status.idle": "2026-02-18T00:54:47.767196Z", "shell.execute_reply": "2026-02-18T00:54:47.766639Z" }, "papermill": { "duration": 0.006289, "end_time": "2026-02-12T05:43:38.552692", "exception": false, "start_time": "2026-02-12T05:43:38.546403", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Preview predictions: [-0.623 -0.873 -2.557 3.994 -1.331]\n", "Response variance: 14.504\n" ] } ], "source": [ "# Sanity check: model predictions\n", "y_preview = exp_model(X_test[:5])\n", "print(\"Preview predictions:\", np.round(y_preview, 3))\n", "print(\"Response variance:\", f\"{np.var(exp_model(X_train)):.3f}\")\n" ] }, { "cell_type": "markdown", "id": "63ea52c7", "metadata": { "papermill": { "duration": 0.001823, "end_time": "2026-02-12T05:43:38.556565", "exception": false, "start_time": "2026-02-12T05:43:38.554742", "status": "completed" }, "tags": [] }, "source": [ "## Basic EOTExplainer Usage\n", "\n", "Create an explainer, compute importance, and inspect results.\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "c869c63a", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:47.768878Z", "iopub.status.busy": "2026-02-18T00:54:47.768750Z", "iopub.status.idle": "2026-02-18T00:54:49.848341Z", "shell.execute_reply": "2026-02-18T00:54:49.847839Z" }, "papermill": { "duration": 0.02748, "end_time": "2026-02-12T05:43:38.585709", "exception": false, "start_time": "2026-02-12T05:43:38.558229", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature importance (phi_X):\n", "-------------------------------------------------------\n", " 
Feature phi_X Status\n", "-------------------------------------------------------\n", " X_0 6.2095 model\n", " X_1 1.9614 correlated\n", " X_2 3.6030 model\n", " X_3 0.5729 correlated\n", " X_4 1.6349 model\n", " X_5 0.0253 null\n", " X_6 0.0245 null\n", " X_7 0.0353 null\n", " X_8 0.0266 null\n", " X_9 0.0177 null\n", "\n", "Auto epsilon: 0.4700\n", "Forward shrinkage s: 0.9817\n", "Backward weight matrix W shape: (10, 10)\n" ] } ], "source": [ "explainer = EOTExplainer(\n", " exp_model,\n", " data=X_train,\n", " nsamples=60,\n", " auto_epsilon=True,\n", " random_state=0,\n", ")\n", "\n", "results = explainer(X_test)\n", "phi_X = results[\"phi_X\"]\n", "\n", "print(\"Feature importance (phi_X):\")\n", "print(\"-\" * 55)\n", "print(f\"{'Feature':>8} {'phi_X':>10} {'Status':>12}\")\n", "print(\"-\" * 55)\n", "for i in range(d):\n", " status = \"model\" if i in active_idx else (\"correlated\" if i in relevant_idx else \"null\")\n", " print(f\"{'X_' + str(i):>8} {phi_X[i]:>10.4f} {status:>12}\")\n", "\n", "print(f\"\\nAuto epsilon: {explainer.epsilon:.4f}\")\n", "print(f\"Forward shrinkage s: {explainer.s_fwd:.4f}\")\n", "print(f\"Backward weight matrix W shape: {explainer.W.shape}\")" ] }, { "cell_type": "markdown", "id": "692d413d", "metadata": {}, "source": [ "## Attribution Inference\n", "\n", "Use one-sided testing to identify features with significant predictive importance. We expect all 5 relevant features ($X_0$–$X_4$) to be detected, while the 5 null features ($X_5$–$X_9$) should not." 
] }, { "cell_type": "code", "execution_count": 6, "id": "7c7a6e95", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:49.849917Z", "iopub.status.busy": "2026-02-18T00:54:49.849760Z", "iopub.status.idle": "2026-02-18T00:54:49.853047Z", "shell.execute_reply": "2026-02-18T00:54:49.852607Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[margin] method=auto → gap (d=10 < 30)\n", "[margin] gap: sorted phi range [0.0177, 6.2095], largest log-gap between rank 4 (0.0353) and 5 (0.5729), ratio=16.2x, margin=0.0353\n", "\n", "Margin method: gap, margin: 0.0353\n", "Detected features: [0, 1, 2, 3, 4]\n", "Relevant features: [0, 1, 2, 3, 4]\n", "True positives: [0, 1, 2, 3, 4]\n", "False positives: []\n", "Missed: []\n", "\n", " X_0 [model]: phi=6.2095 se=0.6972 p=0.0000 *\n", " X_1 [ corr]: phi=1.9614 se=0.2225 p=0.0000 *\n", " X_2 [model]: phi=3.6030 se=0.5154 p=0.0000 *\n", " X_3 [ corr]: phi=0.5729 se=0.1745 p=0.0010 *\n", " X_4 [model]: phi=1.6349 se=0.2475 p=0.0000 *\n", " X_5 [ null]: phi=0.0253 se=0.1653 p=0.5243 \n", " X_6 [ null]: phi=0.0245 se=0.1653 p=0.5261 \n", " X_7 [ null]: phi=0.0353 se=0.1653 p=0.5000 \n", " X_8 [ null]: phi=0.0266 se=0.1653 p=0.5211 \n", " X_9 [ null]: phi=0.0177 se=0.1653 p=0.5425 \n" ] } ], "source": [ "# Default conf_int: margin_method=\"auto\" (gap for d<30, mixture for d>=30)\n", "ci = explainer.conf_int(\n", " alpha=0.05,\n", " target=\"X\",\n", " alternative=\"greater\",\n", " verbose=True,\n", ")\n", "\n", "attribution_idx = np.where(ci[\"reject_null\"])[0]\n", "expected = set(relevant_idx)\n", "detected = set(attribution_idx.tolist())\n", "\n", "print(f\"\\nMargin method: {ci['margin_method']}, margin: {ci['margin']:.4f}\")\n", "print(\"Detected features:\", sorted(detected))\n", "print(\"Relevant features:\", sorted(expected))\n", "print(\"True positives:\", sorted(expected & detected))\n", "print(\"False positives:\", sorted(detected - expected))\n", "print(\"Missed:\", 
sorted(expected - detected))\n", "print()\n", "for i in range(d):\n", " tag = \"*\" if ci[\"reject_null\"][i] else \"\"\n", " status = \"model\" if i in active_idx else (\"corr\" if i in relevant_idx else \"null\")\n", " print(f\" X_{i} [{status:>5}]: phi={ci['score'][i]:.4f} se={ci['se'][i]:.4f} p={ci['pvalue'][i]:.4f} {tag}\")" ] }, { "cell_type": "markdown", "id": "d45e88c2", "metadata": { "papermill": { "duration": 0.002768, "end_time": "2026-02-12T05:43:38.945609", "exception": false, "start_time": "2026-02-12T05:43:38.942841", "status": "completed" }, "tags": [] }, "source": [ "## Z-Space vs X-Space Importance\n", "\n", "The EOT decomposition first computes importance in the disentangled Z-space, then maps back to X-space via the backward weight matrix $W$.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "e2241247", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:49.854629Z", "iopub.status.busy": "2026-02-18T00:54:49.854531Z", "iopub.status.idle": "2026-02-18T00:54:50.246801Z", "shell.execute_reply": "2026-02-18T00:54:50.246319Z" }, "papermill": { "duration": 0.274662, "end_time": "2026-02-12T05:43:39.222726", "exception": false, "start_time": "2026-02-12T05:43:38.948064", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Feature phi_Z phi_X\n", "--------------------------------\n", " X_0 8.3475 6.2095\n", " X_1 1.0850 1.9614\n", " X_2 3.7892 3.6030\n", " X_3 0.3915 0.5729\n", " X_4 1.8826 1.6349\n", " X_5 0.0185 0.0253\n", " X_6 0.0191 0.0245\n", " X_7 0.0209 0.0353\n", " X_8 0.0201 0.0266\n", " X_9 0.0187 0.0177\n", "\n", "Total phi_Z: 15.5931\n", "Total phi_X: 14.1110\n", "\n", "Note: phi_Z measures importance in the disentangled space.\n", "phi_X maps it back to original features via the backward weights W.\n" ] } ], "source": [ "phi_Z = results[\"phi_Z\"]\n", "phi_X = results[\"phi_X\"]\n", "\n", "print(f\"{'Feature':>8} {'phi_Z':>10} {'phi_X':>10}\")\n", 
"print(\"-\" * 32)\n", "for i in range(d):\n", "    print(f\"{'X_' + str(i):>8} {phi_Z[i]:>10.4f} {phi_X[i]:>10.4f}\")\n", "\n", "print(f\"\\nTotal phi_Z: {phi_Z.sum():.4f}\")\n", "print(f\"Total phi_X: {phi_X.sum():.4f}\")\n", "print(\"\\nNote: phi_Z measures importance in the disentangled space.\")\n", "print(\"phi_X maps it back to original features via the backward weights W.\")\n" ] }, { "cell_type": "markdown", "id": "c0d20c3c", "metadata": { "papermill": { "duration": 0.002331, "end_time": "2026-02-12T05:43:39.227767", "exception": false, "start_time": "2026-02-12T05:43:39.225436", "status": "completed" }, "tags": [] }, "source": [ "## Effect of Epsilon on Attribution\n", "\n", "Epsilon sets the strength of the entropic regularization. Smaller epsilon gives sharper transport that approaches exact (unregularized) OT, while larger epsilon increases the forward shrinkage (smaller $s$). Over the range tried below, the X-space importances change only slightly.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "26149aff", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:50.248619Z", "iopub.status.busy": "2026-02-18T00:54:50.248492Z", "iopub.status.idle": "2026-02-18T00:54:50.659800Z", "shell.execute_reply": "2026-02-18T00:54:50.659347Z" }, "papermill": { "duration": 0.011575, "end_time": "2026-02-12T05:43:39.241748", "exception": false, "start_time": "2026-02-12T05:43:39.230173", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "eps=0.001 s=1.0000 active_mean=3.9655 non_model_mean=0.3813\n", "eps=0.01 s=1.0000 active_mean=3.9654 non_model_mean=0.3812\n", "eps=0.1 s=0.9989 active_mean=3.9553 non_model_mean=0.3801\n", "\n", " Feature eps=0.001 eps=0.01 eps=0.1\n", "--------------------------------------------\n", " X_0 6.4625 6.4623 6.4461\n", " X_1 2.0585 2.0584 2.0513\n", " X_2 3.7603 3.7602 3.7495\n", " X_3 0.5681 0.5681 0.5675\n", " X_4 1.6737 1.6737 1.6705\n", " X_5 0.0068 0.0068 0.0068\n", " X_6 0.0049 0.0049 0.0050\n", " X_7 0.0190 0.0190 0.0191\n", " X_8 0.0105 0.0105 0.0104\n", " X_9 0.0009 0.0009 0.0010\n" ]
} ], "source": [ "epsilons = [1e-3, 0.01, 0.1]\n", "all_phi = {}\n", "# Features not used directly by the model: correlated (1, 3) and null (5-9)\n", "non_model_idx = [i for i in range(d) if i not in active_idx]\n", "\n", "for eps in epsilons:\n", "    exp_eps = EOTExplainer(\n", "        exp_model,\n", "        data=X_train,\n", "        nsamples=60,\n", "        epsilon=eps,\n", "        random_state=0,\n", "    )\n", "    res = exp_eps(X_test)\n", "    all_phi[eps] = res[\"phi_X\"]\n", "    print(f\"eps={eps:g} s={exp_eps.s_fwd:.4f} \"\n", "          f\"active_mean={res['phi_X'][active_idx].mean():.4f} \"\n", "          f\"non_model_mean={res['phi_X'][non_model_idx].mean():.4f}\")\n", "\n", "print()\n", "header = f\"{'Feature':>8}\" + \"\".join(f\"{'eps=' + str(e):>12}\" for e in epsilons)\n", "print(header)\n", "print(\"-\" * len(header))\n", "for i in range(d):\n", "    row = f\"{'X_' + str(i):>8}\"\n", "    for eps in epsilons:\n", "        row += f\"{all_phi[eps][i]:>12.4f}\"\n", "    print(row)\n" ] }, { "cell_type": "markdown", "id": "13007a36", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:26:13.234601Z", "iopub.status.busy": "2026-02-17T16:26:13.234489Z", "iopub.status.idle": "2026-02-17T16:26:13.237114Z", "shell.execute_reply": "2026-02-17T16:26:13.236656Z" }, "papermill": { "duration": 0.008126, "end_time": "2026-02-12T05:43:39.252595", "exception": false, "start_time": "2026-02-12T05:43:39.244469", "status": "completed" }, "tags": [] }, "source": [ "## Compare with OTExplainer (Gaussian Baseline)\n", "\n", "The `OTExplainer` uses plain Gaussian whitening ($W = L$). 
The `EOTExplainer` adds the population backward projection ($W = L \\cdot M_w$), which can better handle non-Gaussian structure. Because the synthetic data here are exactly Gaussian, no advantage for `EOTExplainer` should be expected in this example: `OTExplainer` gives the higher relevant/null ratio, and its scores match the $\\varepsilon = 0.001$ run of the epsilon sweep above to four decimal places.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "2af05493", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:50.661842Z", "iopub.status.busy": "2026-02-18T00:54:50.661701Z", "iopub.status.idle": "2026-02-18T00:54:51.429749Z", "shell.execute_reply": "2026-02-18T00:54:51.429304Z" }, "papermill": { "duration": 0.016598, "end_time": "2026-02-12T05:43:39.271857", "exception": false, "start_time": "2026-02-12T05:43:39.255259", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Feature OT (Gauss) EOT (Semicont) Status\n", "-------------------------------------------------\n", " X_0 6.4625 6.2095 model\n", " X_1 2.0585 1.9614 corr\n", " X_2 3.7603 3.6030 model\n", " X_3 0.5681 0.5729 corr\n", " X_4 1.6737 1.6349 model\n", " X_5 0.0068 0.0253 null\n", " X_6 0.0049 0.0245 null\n", " X_7 0.0190 0.0353 null\n", " X_8 0.0105 0.0266 null\n", " X_9 0.0009 0.0177 null\n", "\n", "Relevant/null ratio (OT): 344.14x\n", "Relevant/null ratio (EOT): 108.07x\n" ] } ], "source": [ "from fdfi.explainers import OTExplainer\n", "\n", "explainer_ot = OTExplainer(\n", "    exp_model,\n", "    data=X_train,\n", "    nsamples=60,\n", "    random_state=0,\n", ")\n", "results_ot = explainer_ot(X_test)\n", "\n", "phi_ot = results_ot[\"phi_X\"]\n", "phi_eot = results[\"phi_X\"]\n", "\n", "print(f\"{'Feature':>8} {'OT (Gauss)':>12} {'EOT (Semicont)':>15} {'Status':>10}\")\n", "print(\"-\" * 49)\n", "for i in range(d):\n", "    status = \"model\" if i in active_idx else (\"corr\" if i in relevant_idx else \"null\")\n", "    print(f\"{'X_' + str(i):>8} {phi_ot[i]:>12.4f} {phi_eot[i]:>15.4f} {status:>10}\")\n", "\n", "ratio_ot = phi_ot[relevant_idx].mean() / phi_ot[null_idx].mean()\n", "ratio_eot = phi_eot[relevant_idx].mean() / phi_eot[null_idx].mean()\n", "print(f\"\\nRelevant/null ratio (OT): 
{ratio_ot:.2f}x\")\n", "print(f\"Relevant/null ratio (EOT): {ratio_eot:.2f}x\")" ] }, { "cell_type": "markdown", "id": "ee842231", "metadata": { "papermill": { "duration": 0.006106, "end_time": "2026-02-12T05:43:39.280514", "exception": false, "start_time": "2026-02-12T05:43:39.274408", "status": "completed" }, "tags": [] }, "source": [ "## Diagnostics and Summary\n", "\n", "Use `diagnostics` to inspect transport quality and `summary()` for a tabular overview.\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "6d8207d4", "metadata": { "execution": { "iopub.execute_input": "2026-02-18T00:54:51.431808Z", "iopub.status.busy": "2026-02-18T00:54:51.431642Z", "iopub.status.idle": "2026-02-18T00:54:51.435454Z", "shell.execute_reply": "2026-02-18T00:54:51.434858Z" }, "papermill": { "duration": 0.035844, "end_time": "2026-02-12T05:43:39.319161", "exception": false, "start_time": "2026-02-12T05:43:39.283317", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Diagnostics:\n", " Latent independence (median dCor): 0.074783 [GOOD]\n", " Distribution fidelity (MMD): 0.011061 [GOOD]\n", "\n", "==============================================================================\n", "Feature Importance Results\n", "==============================================================================\n", "Method: EOTExplainer\n", "Number of features: 10\n", "Significance level: 0.05\n", "Alternative: greater\n", "Margin method: gap\n", "Practical margin: 0.0353\n", "------------------------------------------------------------------------------\n", " Feature Estimate Std Err CI Lower CI Upper P-value Sig\n", "------------------------------------------------------------------------------\n", " 0 6.2095 0.6972 5.0628 inf 0.0000 ***\n", " 1 1.9614 0.2225 1.5955 inf 0.0000 ***\n", " 2 3.6030 0.5154 2.7553 inf 0.0000 ***\n", " 3 0.5729 0.1745 0.2858 inf 0.0010 ***\n", " 4 1.6349 0.2475 1.2278 inf 0.0000 ***\n", " 5 0.0253 0.1653 
-0.2466 inf 0.5243 \n", " 6 0.0245 0.1653 -0.2474 inf 0.5261 \n", " 7 0.0353 0.1653 -0.2365 inf 0.5000 \n", " 8 0.0266 0.1653 -0.2453 inf 0.5211 \n", " 9 0.0177 0.1653 -0.2542 inf 0.5425 \n", "==============================================================================\n", "Significant features: 5 / 10\n", "---\n", "Signif. codes: 0 '***' 0.01 '**' 0.05 '*' 0.1 ' ' 1\n", "==============================================================================\n" ] } ], "source": [ "diag = explainer.diagnostics\n", "print(\"Diagnostics:\")\n", "print(f\" Latent independence (median dCor): {diag['latent_independence_median']:.6f} [{diag['latent_independence_label']}]\")\n", "print(f\" Distribution fidelity (MMD): {diag['distribution_fidelity_mmd']:.6f} [{diag['distribution_fidelity_label']}]\")\n", "print()\n", "\n", "# Standardized summary table\n", "_ = explainer.summary(alpha=0.05, target=\"X\", alternative=\"greater\")\n" ] }, { "cell_type": "markdown", "id": "872aba4b", "metadata": { "papermill": { "duration": 0.002478, "end_time": "2026-02-12T05:43:39.324421", "exception": false, "start_time": "2026-02-12T05:43:39.321943", "status": "completed" }, "tags": [] }, "source": [ "## Quick Reference\n", "\n", "```python\n", "import numpy as np\n", "from fdfi.explainers import EOTExplainer\n", "\n", "explainer = EOTExplainer(\n", "    model,\n", "    data=X_train,\n", "    nsamples=60,\n", "    auto_epsilon=True,  # median-distance heuristic\n", "    random_state=0,\n", ")\n", "\n", "results = explainer(X_test)\n", "# results[\"phi_X\"]: X-space feature importance\n", "# results[\"phi_Z\"]: Z-space (disentangled) importance\n", "\n", "# Attribution inference\n", "ci = explainer.conf_int(alpha=0.05, target=\"X\", alternative=\"greater\")\n", "significant = np.where(ci[\"reject_null\"])[0]\n", "\n", "# Inspect transport quality\n", "explainer.diagnostics\n", "```\n" ] }, { "cell_type": "markdown", "id": "0d89628e", "metadata": { "papermill": { "duration": 0.002542, "end_time": "2026-02-12T05:43:39.329450", 
"exception": false, "start_time": "2026-02-12T05:43:39.326908", "status": "completed" }, "tags": [] }, "source": [ "## Summary\n", "\n", "Key takeaways:\n", "\n", "1. `EOTExplainer` uses **semicontinuous entropic OT**: the forward map $Z = s \\cdot X_w$ is available in closed form, so no Sinkhorn iterations are needed.\n", "2. **Population backward attribution** computes $W = L \\cdot M_w$ using the best linear projection from the coupling moments.\n", "3. FDFI detects **all features with predictive signal**, including correlated features, not only those used directly by the model.\n", "4. `epsilon` controls the entropic regularization: smaller values approach exact OT, larger values increase the forward shrinkage (smaller $s$).\n", "5. Use `auto_epsilon=True` for automatic tuning via the median-distance heuristic.\n", "6. `conf_int()` provides p-values and confidence intervals for attribution inference." ] } ], "metadata": { "kernelspec": { "display_name": "dfi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" }, "papermill": { "default_parameters": {}, "duration": 2.71807, "end_time": "2026-02-12T05:43:39.550872", "environment_variables": {}, "exception": null, "input_path": "docs/tutorials/eot_explainer.ipynb", "output_path": "docs/tutorials/eot_explainer.ipynb", "parameters": {}, "start_time": "2026-02-12T05:43:36.832802", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }