{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "a243e6bb", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.517156Z", "iopub.status.busy": "2026-02-17T15:56:18.516895Z", "iopub.status.idle": "2026-02-17T15:56:18.528720Z", "shell.execute_reply": "2026-02-17T15:56:18.527512Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FDFI version: 0.0.5\n" ] } ], "source": [ "import fdfi\n", "print('FDFI version:', fdfi.__version__)" ] }, { "cell_type": "markdown", "id": "b13c0b60", "metadata": { "papermill": { "duration": 0.005393, "end_time": "2026-02-12T05:43:22.309499", "exception": false, "start_time": "2026-02-12T05:43:22.304106", "status": "completed" }, "tags": [] }, "source": [ "# Quickstart: FDFI in 5 Minutes\n", "\n", "This tutorial introduces the basics of FDFI (Flow-Disentangled Feature Importance). By the end, you'll be able to:\n", "\n", "1. Create an explainer for any model\n", "2. Compute feature importance\n", "3. Interpret the results\n", "4. 
Get confidence intervals" ] }, { "cell_type": "markdown", "id": "b0216e9f", "metadata": { "papermill": { "duration": 0.002001, "end_time": "2026-02-12T05:43:22.314590", "exception": false, "start_time": "2026-02-12T05:43:22.312589", "status": "completed" }, "tags": [] }, "source": [ "## Setup\n", "\n", "First, let's import the necessary libraries:" ] }, { "cell_type": "code", "execution_count": 2, "id": "1fd3b39e", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.532132Z", "iopub.status.busy": "2026-02-17T15:56:18.531903Z", "iopub.status.idle": "2026-02-17T15:56:18.966970Z", "shell.execute_reply": "2026-02-17T15:56:18.966308Z" }, "papermill": { "duration": 0.535767, "end_time": "2026-02-12T05:43:22.852781", "exception": false, "start_time": "2026-02-12T05:43:22.317014", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "from fdfi.explainers import OTExplainer\n", "\n", "# Set random seed for reproducibility\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "id": "8fe4f8e0", "metadata": { "papermill": { "duration": 0.001302, "end_time": "2026-02-12T05:43:22.891237", "exception": false, "start_time": "2026-02-12T05:43:22.889935", "status": "completed" }, "tags": [] }, "source": [ "## Create a Simple Model\n", "\n", "Let's create a simple model where we know the true feature importance. 
Features 0, 1, and 2 are important (with coefficients 1, 2, and 0.5); features 3 through 9 are pure noise:" ] }, { "cell_type": "code", "execution_count": 3, "id": "e8088eb0", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.968825Z", "iopub.status.busy": "2026-02-17T15:56:18.968675Z", "iopub.status.idle": "2026-02-17T15:56:18.972018Z", "shell.execute_reply": "2026-02-17T15:56:18.971487Z" }, "papermill": { "duration": 0.007344, "end_time": "2026-02-12T05:43:22.899778", "exception": false, "start_time": "2026-02-12T05:43:22.892434", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training data shape: (200, 10)\n", "Test data shape: (100, 10)\n", "Model predictions for test data: [-1.36042558 -3.5175806 1.2950158 -0.90179092 -1.77532221]\n" ] } ], "source": [ "def model(X):\n", " \"\"\"Simple model: y = x0 + 2*x1 + 0.5*x2\"\"\"\n", " return X[:, 0] + 2 * X[:, 1] + 0.5 * X[:, 2]\n", "\n", "# Create training data (used as background distribution)\n", "n_samples = 200\n", "n_features = 10\n", "X_train = np.random.randn(n_samples, n_features)\n", "\n", "# Create test data to explain\n", "X_test = np.random.randn(100, n_features)\n", "\n", "print(f\"Training data shape: {X_train.shape}\")\n", "print(f\"Test data shape: {X_test.shape}\")\n", "print(f\"Model predictions for test data: {model(X_test)[:5]}\")" ] }, { "cell_type": "markdown", "id": "7ee95aea", "metadata": { "papermill": { "duration": 0.001409, "end_time": "2026-02-12T05:43:22.902831", "exception": false, "start_time": "2026-02-12T05:43:22.901422", "status": "completed" }, "tags": [] }, "source": [ "## Create an Explainer\n", "\n", "The `OTExplainer` uses Gaussian optimal transport to compute feature importance:" ] }, { "cell_type": "code", "execution_count": 4, "id": "492ea9db", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.973652Z", "iopub.status.busy": "2026-02-17T15:56:18.973557Z", "iopub.status.idle": "2026-02-17T15:56:18.979093Z",
"shell.execute_reply": "2026-02-17T15:56:18.978593Z" }, "papermill": { "duration": 0.005513, "end_time": "2026-02-12T05:43:22.909797", "exception": false, "start_time": "2026-02-12T05:43:22.904284", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Explainer created!\n" ] } ], "source": [ "# Create the explainer\n", "explainer = OTExplainer(\n", " model, # The model to explain\n", " data=X_train, # Background data\n", " nsamples=50, # Monte Carlo samples per feature\n", ")\n", "\n", "print(\"Explainer created!\")" ] }, { "cell_type": "markdown", "id": "c109f383", "metadata": { "papermill": { "duration": 0.001504, "end_time": "2026-02-12T05:43:22.913031", "exception": false, "start_time": "2026-02-12T05:43:22.911527", "status": "completed" }, "tags": [] }, "source": [ "## Compute Feature Importance\n", "\n", "Call the explainer on test data to get feature importance:" ] }, { "cell_type": "code", "execution_count": 5, "id": "f8b8883f", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.980935Z", "iopub.status.busy": "2026-02-17T15:56:18.980806Z", "iopub.status.idle": "2026-02-17T15:56:18.984413Z", "shell.execute_reply": "2026-02-17T15:56:18.983979Z" }, "papermill": { "duration": 0.006654, "end_time": "2026-02-12T05:43:22.921070", "exception": false, "start_time": "2026-02-12T05:43:22.914416", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature Importance (phi_X):\n", " Feature 0: 0.7829\n", " Feature 1: 4.9322\n", " Feature 2: 0.2430\n", " Feature 3: 0.0113\n", " Feature 4: 0.0224\n", " Feature 5: 0.0095\n", " Feature 6: 0.0159\n", " Feature 7: 0.0033\n", " Feature 8: 0.0010\n", " Feature 9: 0.0369\n" ] } ], "source": [ "# Compute feature importance\n", "results = explainer(X_test)\n", "\n", "# Print the results\n", "print(\"Feature Importance (phi_X):\")\n", "for i, phi in enumerate(results[\"phi_X\"]):\n", " print(f\" 
Feature {i}: {phi:.4f}\")" ] }, { "cell_type": "markdown", "id": "712787f5", "metadata": { "papermill": { "duration": 0.001359, "end_time": "2026-02-12T05:43:22.924029", "exception": false, "start_time": "2026-02-12T05:43:22.922670", "status": "completed" }, "tags": [] }, "source": [ "## Interpret the Results\n", "\n", "The `results` dictionary contains:\n", "- `phi_X`: Feature importance in the original X-space\n", "- `phi_Z`: Feature importance in the disentangled Z-space\n", "- `se_X`, `se_Z`: Standard errors for uncertainty quantification\n", "\n", "Higher values indicate more important features. Because the model is `x0 + 2*x1 + 0.5*x2` and the inputs are independent standard normal draws, we expect Features 0, 1, and 2 to stand out, ordered by the size of their coefficients (Feature 1, then Feature 0, then Feature 2), while the seven noise features should score close to zero." ] }, { "cell_type": "code", "execution_count": 6, "id": "018a1e8c", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.986215Z", "iopub.status.busy": "2026-02-17T15:56:18.986118Z", "iopub.status.idle": "2026-02-17T15:56:18.988657Z", "shell.execute_reply": "2026-02-17T15:56:18.988275Z" }, "papermill": { "duration": 0.0056, "end_time": "2026-02-12T05:43:22.931108", "exception": false, "start_time": "2026-02-12T05:43:22.925508", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Features ranked by importance:\n", " Rank 1: Feature 1 (importance = 4.9322)\n", " Rank 2: Feature 0 (importance = 0.7829)\n", " Rank 3: Feature 2 (importance = 0.2430)\n", " Rank 4: Feature 9 (importance = 0.0369)\n", " Rank 5: Feature 4 (importance = 0.0224)\n", " Rank 6: Feature 6 (importance = 0.0159)\n", " Rank 7: Feature 3 (importance = 0.0113)\n", " Rank 8: Feature 5 (importance = 0.0095)\n", " Rank 9: Feature 7 (importance = 0.0033)\n", " Rank 10: Feature 8 (importance = 0.0010)\n" ] } ], "source": [ "# Sort features by importance\n", "importance = results[\"phi_X\"]\n", "sorted_idx = np.argsort(importance)[::-1]\n", "\n", "print(\"Features ranked by importance:\")\n", "for rank, idx in
enumerate(sorted_idx):\n", " print(f\" Rank {rank+1}: Feature {idx} (importance = {importance[idx]:.4f})\")" ] }, { "cell_type": "markdown", "id": "e9949f90", "metadata": { "papermill": { "duration": 0.00141, "end_time": "2026-02-12T05:43:22.934041", "exception": false, "start_time": "2026-02-12T05:43:22.932631", "status": "completed" }, "tags": [] }, "source": [ "## Get Confidence Intervals\n", "\n", "FDFI provides statistical inference via `conf_int()`. With `alternative=\"greater\"`, each feature gets a one-sided test and a lower confidence bound (the upper bound is infinite). The null value is not necessarily zero: the p-values below are consistent with testing against the practical margin that `summary()` reports (0.0369 here), which is why Feature 9, whose estimate equals that margin, has a p-value of 0.5000:" ] }, { "cell_type": "code", "execution_count": 7, "id": "e1a48247", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:18.990260Z", "iopub.status.busy": "2026-02-17T15:56:18.990162Z", "iopub.status.idle": "2026-02-17T15:56:19.605601Z", "shell.execute_reply": "2026-02-17T15:56:19.604985Z" }, "papermill": { "duration": 0.006332, "end_time": "2026-02-12T05:43:22.941719", "exception": false, "start_time": "2026-02-12T05:43:22.935387", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Confidence Intervals (95%, one-sided):\n", "------------------------------------------------------------\n", " Feature Estimate SE CI Lower P-value\n", "------------------------------------------------------------\n", " 0 0.7829 0.1201 0.5854 0.0000 *\n", " 1 4.9322 0.7104 3.7636 0.0000 *\n", " 2 0.2430 0.0747 0.1202 0.0029 *\n", " 3 0.0113 0.0671 -0.0991 0.6484 \n", " 4 0.0224 0.0672 -0.0881 0.5859 \n", " 5 0.0095 0.0672 -0.1010 0.6587 \n", " 6 0.0159 0.0672 -0.0946 0.6228 \n", " 7 0.0033 0.0671 -0.1071 0.6917 \n", " 8 0.0010 0.0671 -0.1094 0.7037 \n", " 9 0.0369 0.0672 -0.0737 0.5000 \n", "\n", "* = significant at alpha=0.05\n" ] } ], "source": [ "# Compute confidence intervals\n", "ci = explainer.conf_int(\n", " alpha=0.05, # 95% confidence level\n", " target=\"X\", # Use X-space importance\n", " alternative=\"greater\" # One-sided test: is importance above the null margin?\n", ")\n", "\n", "print(\"\\nConfidence Intervals (95%, one-sided):\")\n", "print(\"-\" * 60)\n",
"print(f\"{'Feature':>8} {'Estimate':>10} {'SE':>10} {'CI Lower':>10} {'P-value':>10}\")\n", "print(\"-\" * 60)\n", "for i in range(n_features):\n", " sig = \"*\" if ci[\"reject_null\"][i] else \"\"\n", " print(f\"{i:>8} {ci['score'][i]:>10.4f} {ci['se'][i]:>10.4f} \"\n", " f\"{ci['ci_lower'][i]:>10.4f} {ci['pvalue'][i]:>10.4f} {sig}\")\n", "\n", "print(\"\\n* = significant at alpha=0.05\")" ] }, { "cell_type": "markdown", "id": "5d2740fb", "metadata": { "papermill": { "duration": 0.001646, "end_time": "2026-02-12T05:43:22.944959", "exception": false, "start_time": "2026-02-12T05:43:22.943313", "status": "completed" }, "tags": [] }, "source": [ "## View Summary\n", "\n", "Use the built-in `summary()` method for a formatted output:" ] }, { "cell_type": "code", "execution_count": 8, "id": "b23ce403", "metadata": { "execution": { "iopub.execute_input": "2026-02-17T15:56:19.607326Z", "iopub.status.busy": "2026-02-17T15:56:19.607125Z", "iopub.status.idle": "2026-02-17T15:56:19.615229Z", "shell.execute_reply": "2026-02-17T15:56:19.614623Z" }, "papermill": { "duration": 0.007091, "end_time": "2026-02-12T05:43:22.953424", "exception": false, "start_time": "2026-02-12T05:43:22.946333", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==============================================================================\n", "Feature Importance Results\n", "==============================================================================\n", "Method: OTExplainer\n", "Number of features: 10\n", "Significance level: 0.05\n", "Alternative: greater\n", "Margin method: gap\n", "Practical margin: 0.0369\n", "------------------------------------------------------------------------------\n", " Feature Estimate Std Err CI Lower CI Upper P-value Sig\n", "------------------------------------------------------------------------------\n", " 0 0.7829 0.1201 0.5854 inf 0.0000 ***\n", " 1 4.9322 0.7104 3.7636 inf 0.0000 ***\n", " 2 0.2430 
0.0747 0.1202 inf 0.0029 ***\n", " 3 0.0113 0.0671 -0.0991 inf 0.6484 \n", " 4 0.0224 0.0672 -0.0881 inf 0.5859 \n", " 5 0.0095 0.0672 -0.1010 inf 0.6587 \n", " 6 0.0159 0.0672 -0.0946 inf 0.6228 \n", " 7 0.0033 0.0671 -0.1071 inf 0.6917 \n", " 8 0.0010 0.0671 -0.1094 inf 0.7037 \n", " 9 0.0369 0.0672 -0.0737 inf 0.5000 \n", "==============================================================================\n", "Significant features: 3 / 10\n", "---\n", "Signif. codes: 0 '***' 0.01 '**' 0.05 '*' 0.1 ' ' 1\n", "==============================================================================\n" ] } ], "source": [ "# Print formatted summary (the trailing semicolon stops Jupyter from echoing the returned string a second time)\n", "explainer.summary(alpha=0.05, alternative=\"greater\");" ] }, { "cell_type": "markdown", "id": "0b9460f3", "metadata": { "papermill": { "duration": 0.001626, "end_time": "2026-02-12T05:43:22.956772", "exception": false, "start_time": "2026-02-12T05:43:22.955146", "status": "completed" }, "tags": [] }, "source": [ "## Next Steps\n", "\n", "Now that you've learned the basics, check out these tutorials:\n", "\n", "- **OT Explainer Deep Dive**: Learn more about the Gaussian OT method\n", "- **EOT Explainer**: Entropic OT for non-Gaussian data\n", "- **Confidence Intervals**: Advanced statistical inference" ] } ], "metadata": { "kernelspec": { "display_name": "dfi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" }, "papermill": { "default_parameters": {}, "duration": 1.703521, "end_time": "2026-02-12T05:43:23.175283", "environment_variables": {}, "exception": null, "input_path": "docs/tutorials/quickstart.ipynb", "output_path": "docs/tutorials/quickstart.ipynb", "parameters": {}, "start_time": "2026-02-12T05:43:21.471762", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }