{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "a3208230", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:03.941070Z", "iopub.status.busy": "2026-02-17T16:01:03.940931Z", "iopub.status.idle": "2026-02-17T16:01:03.948147Z", "shell.execute_reply": "2026-02-17T16:01:03.947431Z" } }, "outputs": [], "source": [ "import fdfi\n", "print('FDFI version:', fdfi.__version__)" ] }, { "cell_type": "markdown", "id": "a6923e4a", "metadata": { "papermill": { "duration": 0.007708, "end_time": "2026-02-12T05:43:45.872831", "exception": false, "start_time": "2026-02-12T05:43:45.865123", "status": "completed" }, "tags": [] }, "source": [ "# FlowExplainer: Flow-Disentangled Feature Importance\n", "\n", "This tutorial covers the `FlowExplainer`, which uses **normalizing flows** to compute feature importance. By the end, you'll understand:\n", "\n", "1. How Flow-DFI differs from OT-based methods\n", "2. The difference between CPI and SCPI methods\n", "3. How to use FlowExplainer with default and custom flow models\n", "4. 
When to choose FlowExplainer over OTExplainer" ] }, { "cell_type": "markdown", "id": "8f3e56e6", "metadata": { "papermill": { "duration": 0.003491, "end_time": "2026-02-12T05:43:45.880304", "exception": false, "start_time": "2026-02-12T05:43:45.876813", "status": "completed" }, "tags": [] }, "source": [ "## Setup\n", "\n", "First, let's import the necessary libraries:" ] }, { "cell_type": "code", "execution_count": null, "id": "7c6820dd", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:03.950075Z", "iopub.status.busy": "2026-02-17T16:01:03.949917Z", "iopub.status.idle": "2026-02-17T16:01:04.379302Z", "shell.execute_reply": "2026-02-17T16:01:04.378711Z" }, "papermill": { "duration": 0.525849, "end_time": "2026-02-12T05:43:46.462357", "exception": false, "start_time": "2026-02-12T05:43:45.936508", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '../..') # Add project root to path\n", "\n", "import numpy as np\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "from fdfi.explainers import FlowExplainer, OTExplainer\n", "from fdfi.plots import confidence_interval_plot, diagnostics_plot, summary_bar\n", "\n", "# Set random seed for reproducibility\n", "np.random.seed(42)\n", "\n" ] }, { "cell_type": "markdown", "id": "34f4b836", "metadata": { "papermill": { "duration": 0.001672, "end_time": "2026-02-12T05:43:46.466014", "exception": false, "start_time": "2026-02-12T05:43:46.464342", "status": "completed" }, "tags": [] }, "source": [ "## Background: Why Flow-DFI?\n", "\n", "**OTExplainer** and **EOTExplainer** use optimal transport to create a disentangled representation. However, they rely on parametric assumptions (Gaussian or kernel-based).\n", "\n", "**FlowExplainer** uses a **normalizing flow** — a learned neural network that maps between the original feature space X and a disentangled latent space Z. 
This is:\n", "\n", "- **More flexible**: Learns complex non-linear transformations\n", "- **Data-driven**: No parametric assumptions\n", "- **Invertible**: Can map X → Z and Z → X\n", "\n", "The tradeoff is that it requires training a neural network, which takes longer." ] }, { "cell_type": "markdown", "id": "0a01ebcd", "metadata": { "papermill": { "duration": 0.001575, "end_time": "2026-02-12T05:43:46.469195", "exception": false, "start_time": "2026-02-12T05:43:46.467620", "status": "completed" }, "tags": [] }, "source": [ "## Create Test Data\n", "\n", "Let's create data with correlated features where only some features are truly important:" ] }, { "cell_type": "code", "execution_count": null, "id": "3993bd28", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:04.382108Z", "iopub.status.busy": "2026-02-17T16:01:04.381911Z", "iopub.status.idle": "2026-02-17T16:01:04.386185Z", "shell.execute_reply": "2026-02-17T16:01:04.385629Z" }, "papermill": { "duration": 0.007333, "end_time": "2026-02-12T05:43:46.478102", "exception": false, "start_time": "2026-02-12T05:43:46.470769", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create correlated features\n", "n_samples = 500\n", "n_features = 8\n", "\n", "# Covariance matrix with correlations\n", "cov = np.eye(n_features)\n", "cov[0, 1] = cov[1, 0] = 0.7 # Features 0 and 1 are correlated\n", "cov[2, 3] = cov[3, 2] = 0.5 # Features 2 and 3 are correlated\n", "\n", "X = np.random.multivariate_normal(np.zeros(n_features), cov, size=n_samples)\n", "\n", "# Model: only features 0, 1, 2 are active\n", "def model(X):\n", " return X[:, 0] + 2 * X[:, 1] + 0.5 * X[:, 2]\n", "\n", "# Split into train/test\n", "X_train, X_test = X[:400], X[400:]\n", "\n", "print(f\"Training data: {X_train.shape}\")\n", "print(f\"Test data: {X_test.shape}\")\n", "print(f\"\\nActive features: 0, 1, 2\")\n", "print(f\"Null features: 3, 4, 5, 6, 7\")" ] }, { "cell_type": "markdown", "id": "bc92ea4f", "metadata": { 
"papermill": { "duration": 0.001792, "end_time": "2026-02-12T05:43:46.482027", "exception": false, "start_time": "2026-02-12T05:43:46.480235", "status": "completed" }, "tags": [] }, "source": [ "## Basic Usage: CPI Method (Default)\n", "\n", "The simplest way to use FlowExplainer is with the default CPI (Conditional Permutation Importance) method:" ] }, { "cell_type": "code", "execution_count": null, "id": "3b3caa62", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:04.388238Z", "iopub.status.busy": "2026-02-17T16:01:04.388121Z", "iopub.status.idle": "2026-02-17T16:01:08.018901Z", "shell.execute_reply": "2026-02-17T16:01:08.017922Z" }, "papermill": { "duration": 5.229684, "end_time": "2026-02-12T05:43:51.713610", "exception": false, "start_time": "2026-02-12T05:43:46.483926", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create FlowExplainer with CPI method\n", "explainer_cpi = FlowExplainer(\n", " model,\n", " data=X_train,\n", " method='cpi', # Conditional Permutation Importance\n", " nsamples=30, # Monte Carlo samples per feature\n", " num_steps=200, # Flow training steps (use more for better results)\n", " random_state=42,\n", ")\n", "\n", "# Compute importance\n", "results_cpi = explainer_cpi(X_test)\n", "\n", "print(\"\\nCPI Feature Importance:\")\n", "print(\"-\" * 40)\n", "for i, phi in enumerate(results_cpi['phi_X']):\n", " marker = \"*\" if i < 3 else \"\" # Mark active features\n", " print(f\" Feature {i}: {phi:8.4f} {marker}\")\n", "print(\"\\n* = active feature\")" ] }, { "cell_type": "markdown", "id": "10223cd3", "metadata": { "papermill": { "duration": 0.008607, "end_time": "2026-02-12T05:43:51.731024", "exception": false, "start_time": "2026-02-12T05:43:51.722417", "status": "completed" }, "tags": [] }, "source": [ "## CPI vs SCPI: What's the Difference?\n", "\n", "FlowExplainer provides two methods for computing importance in the **latent Z-space**:\n", "\n", "**CPI (Conditional Permutation 
Importance)**\n", "- Squared difference after averaging predictions\n", "- Formula: $\\phi_{Z,j}^{CPI} = (Y - \\mathbb{E}_b[f(\\tilde{X}_b^{(j)})])^2$\n", "\n", "**SCPI (Sobol-CPI)**\n", "- Conditional variance of counterfactual predictions\n", "- Formula: $\\phi_{Z,j}^{SCPI} = \\text{Var}_b[f(\\tilde{X}_b^{(j)})]$\n", "- Related to Sobol total-order sensitivity indices\n", "\n", "For **L2 loss** with independent (disentangled) features, **CPI and SCPI give similar results**, since both measure how much the model output changes when feature $j$ is permuted.\n", "\n", "**Jacobian Transformation to X-space**\n", "\n", "Both methods compute importance in Z-space (disentangled features). To attribute importance to the **original features** $X_l$, FlowExplainer uses the **Jacobian** $H = \\frac{\\partial X}{\\partial Z}$:\n", "\n", "$$\\phi_{X,l} = \\sum_{k=1}^{d} H_{lk}^2 \\cdot \\phi_{Z,k}$$\n", "\n", "This correctly accounts for how each latent dimension affects each original feature." ] }, { "cell_type": "code", "execution_count": null, "id": "f9d94ca0", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:08.020566Z", "iopub.status.busy": "2026-02-17T16:01:08.020396Z", "iopub.status.idle": "2026-02-17T16:01:09.817724Z", "shell.execute_reply": "2026-02-17T16:01:09.816908Z" }, "papermill": { "duration": 3.495416, "end_time": "2026-02-12T05:43:55.234856", "exception": false, "start_time": "2026-02-12T05:43:51.739440", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create FlowExplainer with SCPI method\n", "explainer_scpi = FlowExplainer(\n", " model,\n", " data=X_train,\n", " method='scpi', # Sobol-CPI\n", " nsamples=30,\n", " num_steps=200,\n", " random_state=42,\n", ")\n", "\n", "results_scpi = explainer_scpi(X_test)\n", "\n", "# Compare CPI vs SCPI\n", "print(\"Comparison: CPI vs SCPI\")\n", "print(\"-\" * 50)\n", "print(f\"{'Feature':>8} {'CPI':>12} {'SCPI':>12} {'Active':>10}\")\n", "print(\"-\" * 50)\n", "for i in 
range(n_features):\n", " active = \"Yes\" if i < 3 else \"No\"\n", " print(f\"{i:>8} {results_cpi['phi_X'][i]:>12.4f} {results_scpi['phi_X'][i]:>12.4f} {active:>10}\")" ] }, { "cell_type": "markdown", "id": "877171ff", "metadata": { "papermill": { "duration": 0.014726, "end_time": "2026-02-12T05:43:55.265758", "exception": false, "start_time": "2026-02-12T05:43:55.251032", "status": "completed" }, "tags": [] }, "source": [ "## Computing Both Methods at Once\n", "\n", "Use `method='both'` to compute CPI and SCPI simultaneously (more efficient than running twice):" ] }, { "cell_type": "code", "execution_count": null, "id": "4735221c", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:09.819661Z", "iopub.status.busy": "2026-02-17T16:01:09.819515Z", "iopub.status.idle": "2026-02-17T16:01:11.699991Z", "shell.execute_reply": "2026-02-17T16:01:11.699264Z" }, "papermill": { "duration": 3.633958, "end_time": "2026-02-12T05:43:58.914400", "exception": false, "start_time": "2026-02-12T05:43:55.280442", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Compute both CPI and SCPI\n", "explainer_both = FlowExplainer(\n", " model,\n", " data=X_train,\n", " method='both',\n", " nsamples=30,\n", " num_steps=200,\n", " random_state=42,\n", ")\n", "\n", "results_both = explainer_both(X_test)\n", "\n", "print(\"Result keys:\", list(results_both.keys()))\n", "print(\"\\nCPI importance (phi_Z):\", results_both['phi_Z'][:3].round(4))\n", "print(\"SCPI importance (phi_Z_scpi):\", results_both['phi_Z_scpi'][:3].round(4))" ] }, { "cell_type": "markdown", "id": "4b254a26", "metadata": { "papermill": { "duration": 0.020779, "end_time": "2026-02-12T05:43:58.956251", "exception": false, "start_time": "2026-02-12T05:43:58.935472", "status": "completed" }, "tags": [] }, "source": [ "## Confidence Intervals and Summary\n", "\n", "Use the built-in `conf_int()` and `summary()` methods for quick, reproducible inference.\n", "This avoids custom ad hoc diagnostics 
code and keeps reporting consistent across explainers.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "edd0887e", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:11.702265Z", "iopub.status.busy": "2026-02-17T16:01:11.702076Z", "iopub.status.idle": "2026-02-17T16:01:11.705753Z", "shell.execute_reply": "2026-02-17T16:01:11.705104Z" }, "papermill": { "duration": 0.033446, "end_time": "2026-02-12T05:43:59.010808", "exception": false, "start_time": "2026-02-12T05:43:58.977362", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Shared diagnostics (computed automatically)\n", "flow_diag = explainer_cpi.diagnostics\n", "print(\"Flow diagnostics\")\n", "print(\"-\" * 72)\n", "print(f\"Latent independence (median dCor): {flow_diag['latent_independence_median']:.6f} [{flow_diag['latent_independence_label']}]\")\n", "print(f\"Distribution fidelity (MMD): {flow_diag['distribution_fidelity_mmd']:.6f} [{flow_diag['distribution_fidelity_label']}]\")\n", "\n", "# Standardized inference summary (v0.0.2 defaults use mixture methods)\n", "print(\"\\nX-space summary\")\n", "_ = explainer_cpi.summary(alpha=0.05, target='X', alternative='greater')\n", "\n", "\n", "feature_names = [f\"X{i}\" for i in range(n_features)]\n", "summary_bar(\n", " results_cpi[\"phi_X\"],\n", " results_cpi[\"se_X\"],\n", " feature_names,\n", " show=False,\n", ")\n", "\n", "ci_flow = explainer_cpi.conf_int(alpha=0.05, target=\"X\", alternative=\"greater\")\n", "confidence_interval_plot(ci_flow, feature_names=feature_names, show=False)\n", "diagnostics_plot(flow_diag, feature_names=feature_names, show=False)\n" ] }, { "cell_type": "markdown", "id": "46fcd173", "metadata": { "papermill": { "duration": 0.021235, "end_time": "2026-02-12T05:43:59.052601", "exception": false, "start_time": "2026-02-12T05:43:59.031366", "status": "completed" }, "tags": [] }, "source": [ "## Sampling Methods\n", "\n", "FlowExplainer supports different ways to generate 
counterfactual values in Z-space:\n", "\n", "- `'resample'`: Sample from background data (default)\n", "- `'permutation'`: Permute within test set\n", "- `'normal'`: Sample from standard normal\n", "- `'condperm'`: Conditional permutation (regress Z_j | Z_{-j})" ] }, { "cell_type": "code", "execution_count": null, "id": "1b281856", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:11.707276Z", "iopub.status.busy": "2026-02-17T16:01:11.707157Z", "iopub.status.idle": "2026-02-17T16:01:17.164457Z", "shell.execute_reply": "2026-02-17T16:01:17.163760Z" }, "papermill": { "duration": 10.791123, "end_time": "2026-02-12T05:44:09.863905", "exception": false, "start_time": "2026-02-12T05:43:59.072782", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Try different sampling methods\n", "sampling_methods = ['resample', 'permutation', 'normal']\n", "results_by_method = {}\n", "\n", "for method in sampling_methods:\n", " exp = FlowExplainer(\n", " model, X_train,\n", " sampling_method=method,\n", " nsamples=30,\n", " num_steps=200,\n", " random_state=42,\n", " )\n", " results_by_method[method] = exp(X_test)\n", "\n", "print(\"Importance by Sampling Method:\")\n", "print(\"-\" * 55)\n", "print(f\"{'Feature':>8} {'resample':>12} {'permutation':>12} {'normal':>12}\")\n", "print(\"-\" * 55)\n", "for i in range(n_features):\n", " r = results_by_method['resample']['phi_X'][i]\n", " p = results_by_method['permutation']['phi_X'][i]\n", " n = results_by_method['normal']['phi_X'][i]\n", " print(f\"{i:>8} {r:>12.4f} {p:>12.4f} {n:>12.4f}\")" ] }, { "cell_type": "markdown", "id": "5502193f", "metadata": { "papermill": { "duration": 0.039814, "end_time": "2026-02-12T05:44:09.944466", "exception": false, "start_time": "2026-02-12T05:44:09.904652", "status": "completed" }, "tags": [] }, "source": [ "## Using a Custom Flow Model\n", "\n", "You can train a flow model separately and pass it to FlowExplainer. 
This is useful when:\n", "- You want more control over flow training\n", "- You have a pre-trained flow\n", "- You want to use the same flow for multiple explainers" ] }, { "cell_type": "code", "execution_count": null, "id": "55db939e", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:17.166820Z", "iopub.status.busy": "2026-02-17T16:01:17.166670Z", "iopub.status.idle": "2026-02-17T16:01:20.114388Z", "shell.execute_reply": "2026-02-17T16:01:20.113774Z" }, "papermill": { "duration": 7.247732, "end_time": "2026-02-12T05:44:17.230586", "exception": false, "start_time": "2026-02-12T05:44:09.982854", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from fdfi.models import FlowMatchingModel\n", "\n", "# Train a custom flow model\n", "custom_flow = FlowMatchingModel(\n", " X=X_train,\n", " dim=n_features,\n", " hidden_dim=64,\n", " num_blocks=2,\n", ")\n", "custom_flow.fit(num_steps=300, verbose='final')\n", "\n", "# Use the pre-trained flow\n", "explainer_custom = FlowExplainer(\n", " model,\n", " data=X_train,\n", " flow_model=custom_flow, # Pass pre-trained flow\n", " fit_flow=False, # Don't retrain\n", " nsamples=30,\n", " random_state=42,\n", ")\n", "\n", "results_custom = explainer_custom(X_test)\n", "print(\"\\nImportance with custom flow:\", results_custom['phi_X'][:4].round(4))" ] }, { "cell_type": "markdown", "id": "8ca82e62", "metadata": { "papermill": { "duration": 0.053469, "end_time": "2026-02-12T05:44:17.334267", "exception": false, "start_time": "2026-02-12T05:44:17.280798", "status": "completed" }, "tags": [] }, "source": [ "## Comparison: FlowExplainer vs OTExplainer\n", "\n", "Let's compare FlowExplainer to OTExplainer on the same data:" ] }, { "cell_type": "code", "execution_count": null, "id": "e3ca716d", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T16:01:20.116360Z", "iopub.status.busy": "2026-02-17T16:01:20.116216Z", "iopub.status.idle": "2026-02-17T16:01:20.132518Z", "shell.execute_reply": 
"2026-02-17T16:01:20.131852Z" }, "papermill": { "duration": 0.056569, "end_time": "2026-02-12T05:44:17.441899", "exception": false, "start_time": "2026-02-12T05:44:17.385330", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# OTExplainer for comparison\n", "explainer_ot = OTExplainer(\n", " model,\n", " data=X_train,\n", " nsamples=30,\n", " random_state=42,\n", ")\n", "results_ot = explainer_ot(X_test)\n", "\n", "print(\"Comparison: FlowExplainer vs OTExplainer\")\n", "print(\"-\" * 55)\n", "print(f\"{'Feature':>8} {'Flow (CPI)':>12} {'Flow (SCPI)':>12} {'OT':>12}\")\n", "print(\"-\" * 55)\n", "for i in range(n_features):\n", " f_cpi = results_cpi['phi_X'][i]\n", " f_scpi = results_scpi['phi_X'][i]\n", " ot = results_ot['phi_X'][i]\n", " print(f\"{i:>8} {f_cpi:>12.4f} {f_scpi:>12.4f} {ot:>12.4f}\")\n", "\n", "# Summary statistics\n", "active_mask = np.array([True, True, True, False, False, False, False, False])\n", "\n", "print(\"\\n\" + \"=\" * 55)\n", "print(\"Summary: Active vs Null Feature Importance\")\n", "print(\"=\" * 55)\n", "for name, phi in [('Flow CPI', results_cpi['phi_X']), \n", " ('Flow SCPI', results_scpi['phi_X']),\n", " ('OT', results_ot['phi_X'])]:\n", " active_mean = phi[active_mask].mean()\n", " null_mean = phi[~active_mask].mean()\n", " ratio = active_mean / null_mean if null_mean > 0 else float('inf')\n", " print(f\"{name:>12}: active={active_mean:.4f}, null={null_mean:.4f}, ratio={ratio:.2f}x\")" ] }, { "cell_type": "markdown", "id": "11d46d96", "metadata": { "papermill": { "duration": 0.044792, "end_time": "2026-02-12T05:44:17.531581", "exception": false, "start_time": "2026-02-12T05:44:17.486789", "status": "completed" }, "tags": [] }, "source": [ "## Z-space vs X-space Importance\n", "\n", "FlowExplainer provides both:\n", "- **phi_Z**: Importance in the disentangled latent space\n", "- **phi_X**: Importance attributed to original features (via Jacobian)\n", "\n", "| Space | Meaning |\n", "|-------|---------|\n", "| 
Z-space (phi_Z) | Importance of independent latent factors |\n", "| X-space (phi_X) | Importance of original correlated features |\n", "\n", "For linear transformations (like OTExplainer), the Jacobian `H` of the map is constant, so the two are related by `phi_X = (H ** 2) @ phi_Z`, the matrix form of $\\phi_{X,l} = \\sum_{k} H_{lk}^2 \\cdot \\phi_{Z,k}$. For flows, the Jacobian varies with position.\n", "\n", "## When to Use FlowExplainer\n", "\n", "Choose FlowExplainer when the dependence between features is non-linear or non-Gaussian, so that the parametric assumptions behind OTExplainer and EOTExplainer are doubtful, and you can afford the extra time needed to train the flow. When those assumptions are reasonable, OTExplainer is the faster option." ] }, { "cell_type": "markdown", "id": "bdbe9510", "metadata": { "papermill": { "duration": 0.044818, "end_time": "2026-02-12T05:44:17.622518", "exception": false, "start_time": "2026-02-12T05:44:17.577700", "status": "completed" }, "tags": [] }, "source": [ "## Summary\n", "\n", "In this tutorial, you learned:\n", "\n", "1. **FlowExplainer** uses normalizing flows for flexible, data-driven feature importance.\n", "2. **CPI** takes the squared difference after averaging predictions; **SCPI** takes the conditional variance of the counterfactual predictions (Sobol-style).\n", "3. Use `method='both'` to compute CPI and SCPI simultaneously.\n", "4. The **sampling methods** (`'resample'`, `'permutation'`, `'normal'`, `'condperm'`) generate counterfactual values in Z-space in different ways.\n", "5. You can use **custom flow models** for more control.\n", "6. Shared diagnostics and `conf_int`/`summary` provide a consistent inference workflow.\n", "7. 
Strict one-sided testing (`alternative='greater'`) is useful for feature screening.\n", "\n", "## Next Steps\n", "\n", "- Try FlowExplainer on your own data.\n", "- Compare X-space and Z-space significance.\n", "- Cross-check with OT/EOT for consistency.\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "fdfi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" }, "papermill": { "default_parameters": {}, "duration": 35.25608, "end_time": "2026-02-12T05:44:20.284198", "environment_variables": {}, "exception": null, "input_path": "docs/tutorials/flow_explainer.ipynb", "output_path": "docs/tutorials/flow_explainer.ipynb", "parameters": {}, "start_time": "2026-02-12T05:43:45.028118", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }