{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "a243e6bb", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:56.781683Z", "iopub.status.busy": "2026-05-17T00:41:56.781516Z", "iopub.status.idle": "2026-05-17T00:41:57.246800Z", "shell.execute_reply": "2026-05-17T00:41:57.246259Z" } }, "outputs": [], "source": [ "import fdfi\n", "print('FDFI version:', fdfi.__version__)" ] }, { "cell_type": "markdown", "id": "b13c0b60", "metadata": { "papermill": { "duration": 0.005393, "end_time": "2026-02-12T05:43:22.309499", "exception": false, "start_time": "2026-02-12T05:43:22.304106", "status": "completed" }, "tags": [] }, "source": [ "# Quickstart: FDFI in 5 Minutes\n", "\n", "This tutorial introduces the basics of FDFI (Flow-Disentangled Feature Importance). By the end, you'll be able to:\n", "\n", "1. Create an explainer for any model\n", "2. Compute feature importance\n", "3. Interpret the results\n", "4. Get confidence intervals" ] }, { "cell_type": "markdown", "id": "b0216e9f", "metadata": { "papermill": { "duration": 0.002001, "end_time": "2026-02-12T05:43:22.314590", "exception": false, "start_time": "2026-02-12T05:43:22.312589", "status": "completed" }, "tags": [] }, "source": [ "## Setup\n", "\n", "First, let's import the necessary libraries:" ] }, { "cell_type": "code", "execution_count": null, "id": "1fd3b39e", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:57.248754Z", "iopub.status.busy": "2026-05-17T00:41:57.248586Z", "iopub.status.idle": "2026-05-17T00:41:57.250725Z", "shell.execute_reply": "2026-05-17T00:41:57.250287Z" }, "papermill": { "duration": 0.535767, "end_time": "2026-02-12T05:43:22.852781", "exception": false, "start_time": "2026-02-12T05:43:22.317014", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "from fdfi.explainers import OTExplainer\n", "from fdfi.plots import confidence_interval_plot, summary_bar\n", "\n", "# Set random seed for reproducibility\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "id": "8fe4f8e0", "metadata": { "papermill": { "duration": 0.001302, "end_time": "2026-02-12T05:43:22.891237", "exception": false, "start_time": "2026-02-12T05:43:22.889935", "status": "completed" }, "tags": [] }, "source": [ "## Create a Simple Model\n", "\n", "Let's create a simple linear model whose true feature importance we know. Features 0, 1, and 2 enter the model (with coefficients 1, 2, and 0.5); the remaining seven features are pure noise:" ] }, { "cell_type": "code", "execution_count": null, "id": "e8088eb0", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:57.252277Z", "iopub.status.busy": "2026-05-17T00:41:57.252152Z", "iopub.status.idle": "2026-05-17T00:41:57.255236Z", "shell.execute_reply": "2026-05-17T00:41:57.254854Z" }, "papermill": { "duration": 0.007344, "end_time": "2026-02-12T05:43:22.899778", "exception": false, "start_time": "2026-02-12T05:43:22.892434", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def model(X):\n", " \"\"\"Simple model: y = x0 + 2*x1 + 0.5*x2\"\"\"\n", " return X[:, 0] + 2 * X[:, 1] + 0.5 * X[:, 2]\n", "\n", "# Create training data (used as background distribution)\n", "n_samples = 200\n", "n_features = 10\n", "X_train = np.random.randn(n_samples, n_features)\n", "\n", "# Create test data to explain\n", "X_test = np.random.randn(100, n_features)\n", "\n", "print(f\"Training data shape: {X_train.shape}\")\n", "print(f\"Test data shape: {X_test.shape}\")\n", "print(f\"First 5 model predictions on test data: {model(X_test)[:5]}\")" ] }, { "cell_type": "markdown", "id": "7ee95aea", "metadata": { "papermill": { "duration": 0.001409, "end_time": "2026-02-12T05:43:22.902831", "exception": false, "start_time": "2026-02-12T05:43:22.901422", "status": "completed" }, "tags": [] }, "source": [ "## Create an Explainer\n", "\n", "The `OTExplainer` uses Gaussian optimal transport to compute feature importance:" ] }, { "cell_type": "code", "execution_count": null, "id": "492ea9db", "metadata": { 
"execution": { "iopub.execute_input": "2026-05-17T00:41:57.256572Z", "iopub.status.busy": "2026-05-17T00:41:57.256485Z", "iopub.status.idle": "2026-05-17T00:41:57.261845Z", "shell.execute_reply": "2026-05-17T00:41:57.261313Z" }, "papermill": { "duration": 0.005513, "end_time": "2026-02-12T05:43:22.909797", "exception": false, "start_time": "2026-02-12T05:43:22.904284", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Create the explainer\n", "explainer = OTExplainer(\n", " model, # The model to explain\n", " data=X_train, # Background data\n", " nsamples=50, # Monte Carlo samples per feature\n", ")\n", "\n", "print(\"Explainer created!\")" ] }, { "cell_type": "markdown", "id": "c109f383", "metadata": { "papermill": { "duration": 0.001504, "end_time": "2026-02-12T05:43:22.913031", "exception": false, "start_time": "2026-02-12T05:43:22.911527", "status": "completed" }, "tags": [] }, "source": [ "## Compute Feature Importance\n", "\n", "Call the explainer on the test data to compute feature importance:" ] }, { "cell_type": "code", "execution_count": null, "id": "f8b8883f", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:57.263389Z", "iopub.status.busy": "2026-05-17T00:41:57.263276Z", "iopub.status.idle": "2026-05-17T00:41:57.267630Z", "shell.execute_reply": "2026-05-17T00:41:57.267250Z" }, "papermill": { "duration": 0.006654, "end_time": "2026-02-12T05:43:22.921070", "exception": false, "start_time": "2026-02-12T05:43:22.914416", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Compute feature importance\n", "results = explainer(X_test)\n", "\n", "# Print the results\n", "print(\"Feature Importance (phi_X):\")\n", "for i, phi in enumerate(results[\"phi_X\"]):\n", " print(f\" Feature {i}: {phi:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize Feature Importance\n", "\n", "Pass the importance scores (`phi_X`) and their standard errors (`se_X`) from `results` to `summary_bar` to plot the global importance of each feature:"
], "id": "quickstart-10" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feature_names = [f\"X{i}\" for i in range(n_features)]\n", "\n", "fig, ax, importance_table = summary_bar(\n", " results[\"phi_X\"],\n", " results[\"se_X\"],\n", " feature_names,\n", " max_display=8,\n", " show=False,\n", ")\n", "importance_table.head()" ], "id": "quickstart-11" }, { "cell_type": "markdown", "id": "712787f5", "metadata": { "papermill": { "duration": 0.001359, "end_time": "2026-02-12T05:43:22.924029", "exception": false, "start_time": "2026-02-12T05:43:22.922670", "status": "completed" }, "tags": [] }, "source": [ "## Interpret the Results\n", "\n", "The `results` dictionary contains:\n", "- `phi_X`: Feature importance in the original X-space\n", "- `phi_Z`: Feature importance in the disentangled Z-space\n", "- `se_X`, `se_Z`: Standard errors of `phi_X` and `phi_Z`, for uncertainty quantification\n", "\n", "Higher values indicate more important features. Since our model is `x0 + 2*x1 + 0.5*x2`, we expect Features 0, 1, and 2 to have the highest importance. Feature 1, which has the largest coefficient, should rank first. The seven noise features should score close to zero."
] }, { "cell_type": "code", "execution_count": null, "id": "018a1e8c", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:57.269039Z", "iopub.status.busy": "2026-05-17T00:41:57.268949Z", "iopub.status.idle": "2026-05-17T00:41:57.271244Z", "shell.execute_reply": "2026-05-17T00:41:57.270869Z" }, "papermill": { "duration": 0.0056, "end_time": "2026-02-12T05:43:22.931108", "exception": false, "start_time": "2026-02-12T05:43:22.925508", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Sort features by importance\n", "importance = results[\"phi_X\"]\n", "sorted_idx = np.argsort(importance)[::-1]\n", "\n", "print(\"Features ranked by importance:\")\n", "for rank, idx in enumerate(sorted_idx):\n", " print(f\" Rank {rank+1}: Feature {idx} (importance = {importance[idx]:.4f})\")" ] }, { "cell_type": "markdown", "id": "e9949f90", "metadata": { "papermill": { "duration": 0.00141, "end_time": "2026-02-12T05:43:22.934041", "exception": false, "start_time": "2026-02-12T05:43:22.932631", "status": "completed" }, "tags": [] }, "source": [ "## Get Confidence Intervals\n", "\n", "FDFI provides statistical inference via `conf_int()`:" ] }, { "cell_type": "code", "execution_count": null, "id": "e1a48247", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:57.272617Z", "iopub.status.busy": "2026-05-17T00:41:57.272529Z", "iopub.status.idle": "2026-05-17T00:41:57.890192Z", "shell.execute_reply": "2026-05-17T00:41:57.889698Z" }, "papermill": { "duration": 0.006332, "end_time": "2026-02-12T05:43:22.941719", "exception": false, "start_time": "2026-02-12T05:43:22.935387", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Compute confidence intervals\n", "ci = explainer.conf_int(\n", " alpha=0.05, # 95% confidence level\n", " target=\"X\", # Use X-space importance\n", " alternative=\"greater\" # Test if importance > 0\n", ")\n", "\n", "print(\"\\nConfidence Intervals (95%, one-sided):\")\n", "print(\"-\" * 70)\n", 
"print(f\"{'Feature':>8} {'Estimate':>10} {'SE':>10} {'Z-score':>10} {'Rank':>6} {'P-value':>10}\")\n", "print(\"-\" * 70)\n", "for i in range(n_features):\n", " sig = \"*\" if ci[\"reject_null\"][i] else \"\"\n", " print(f\"{i:>8} {ci['score'][i]:>10.4f} {ci['se'][i]:>10.4f} \"\n", " f\"{ci['zscore'][i]:>10.4f} {ci['ranking'][i]:>6} {ci['pvalue'][i]:>10.4f} {sig}\")\n", "\n", "print(\"\\n* = significant at alpha=0.05\")\n", "print(\"\\nNote: 'zscore' = (score - margin) / se, 'ranking' = rank by descending z-score.\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "confidence_interval_plot(\n", " ci,\n", " feature_names=feature_names,\n", " max_display=8,\n", " show=False,\n", ")" ], "id": "quickstart-16" }, { "cell_type": "markdown", "id": "5d2740fb", "metadata": { "papermill": { "duration": 0.001646, "end_time": "2026-02-12T05:43:22.944959", "exception": false, "start_time": "2026-02-12T05:43:22.943313", "status": "completed" }, "tags": [] }, "source": [ "## View Summary\n", "\n", "Use the built-in `summary()` method for a formatted output:" ] }, { "cell_type": "code", "execution_count": null, "id": "b23ce403", "metadata": { "execution": { "iopub.execute_input": "2026-05-17T00:41:57.891998Z", "iopub.status.busy": "2026-05-17T00:41:57.891787Z", "iopub.status.idle": "2026-05-17T00:41:57.898020Z", "shell.execute_reply": "2026-05-17T00:41:57.897599Z" }, "papermill": { "duration": 0.007091, "end_time": "2026-02-12T05:43:22.953424", "exception": false, "start_time": "2026-02-12T05:43:22.946333", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Print formatted summary\n", "explainer.summary(alpha=0.05, alternative=\"greater\")" ] }, { "cell_type": "markdown", "id": "0b9460f3", "metadata": { "papermill": { "duration": 0.001626, "end_time": "2026-02-12T05:43:22.956772", "exception": false, "start_time": "2026-02-12T05:43:22.955146", "status": "completed" }, "tags": [] }, "source": [ "## Next Steps\n", 
"\n", "Now that you've learned the basics, check out these tutorials:\n", "\n", "- **OT Explainer Deep Dive**: Learn more about the Gaussian OT method\n", "- **EOT Explainer**: Entropic OT for non-Gaussian data\n", "- **Confidence Intervals**: Advanced statistical inference" ] } ], "metadata": { "kernelspec": { "display_name": "dfi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" }, "papermill": { "default_parameters": {}, "duration": 1.703521, "end_time": "2026-02-12T05:43:23.175283", "environment_variables": {}, "exception": null, "input_path": "docs/tutorials/quickstart.ipynb", "output_path": "docs/tutorials/quickstart.ipynb", "parameters": {}, "start_time": "2026-02-12T05:43:21.471762", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }