Installation
FDFI can be installed from source. We recommend using a conda environment for managing dependencies.
Using Conda (Recommended)
Create and activate a conda environment:
conda create -n fdfi python=3.10
conda activate fdfi
Then install FDFI:
git clone https://github.com/jaydu1/FDFI.git
cd FDFI
pip install -e .
From Source with pip
git clone https://github.com/jaydu1/FDFI.git
cd FDFI
pip install -e .
Optional Dependencies
FDFI includes plotting dependencies in the base install. Optional dependency groups are available for heavier workflows:
Flow matching models (PyTorch, torchdiffeq):
pip install -e ".[flow]"
Development tools (pytest, black, flake8, mypy):
pip install -e ".[dev]"
Documentation building (Sphinx, RTD theme):
pip install -e ".[docs]"
All optional dependencies:
pip install -e ".[all]"
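Several extras can be combined in a single install using pip's comma-separated syntax, for example `.[flow,dev]`. As a minimal sketch, the hypothetical helper below builds such a command from the group names listed above and rejects names that are not among them; only the group names come from this page, the helper itself is illustrative.

```python
# Extras groups documented for FDFI on this page.
KNOWN_EXTRAS = {"flow", "dev", "docs", "all"}


def editable_install_command(extras):
    """Build a pip editable-install command for the given extras (hypothetical helper)."""
    unknown = set(extras) - KNOWN_EXTRAS
    if unknown:
        raise ValueError(f"unknown extras: {sorted(unknown)}")
    if not extras:
        return "pip install -e ."
    # Sort so the command is deterministic; pip accepts a comma-separated list of extras.
    joined = ",".join(sorted(extras))
    return f'pip install -e ".[{joined}]"'


if __name__ == "__main__":
    print(editable_install_command(["flow", "dev"]))
```

Quoting the `".[...]"` part keeps shells such as zsh from interpreting the square brackets as a glob pattern.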
Using environment.yml
You can also use the provided conda environment file:
conda env create -f environment.yml
conda activate fdfi
Requirements
Core requirements:
Python >= 3.8
NumPy >= 1.20.0
SciPy >= 1.7.0
Optional requirements:
matplotlib >= 3.5.0 (for plotting)
seaborn >= 0.12.0 (for plotting)
torch >= 2.0.0 (for flow matching)
torchdiffeq >= 0.2.3 (for flow matching)
scikit-learn (for mixture models in utilities)
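To check an existing environment against the core minimums listed above, a minimal sketch follows. It reads installed versions with the standard-library `importlib.metadata` (available from Python 3.8) and compares only the leading numeric components, so suffixes such as `rc1` are ignored; a real project would more likely use `packaging.version`. The function names are hypothetical, not part of FDFI.

```python
import re
import sys
from importlib import metadata

# Minimum versions copied from the core requirements list.
CORE_MINIMUMS = {"numpy": "1.20.0", "scipy": "1.7.0"}


def version_tuple(version):
    """Turn '1.26.4' into (1, 26, 4); '2.0.0rc1' becomes (2, 0, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.group() != piece:
            # Stop at the first component carrying a suffix such as 'rc1'.
            break
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare versions after padding with zeros, so '1.20' equals '1.20.0'."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_core(minimums=CORE_MINIMUMS):
    """Return a list of problems; an empty list means the core requirements are met."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append(f"Python {sys.version.split()[0]} < 3.8")
    for dist, minimum in minimums.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist} is not installed")
            continue
        if not meets_minimum(installed, minimum):
            problems.append(f"{dist} {installed} < {minimum}")
    return problems


if __name__ == "__main__":
    issues = check_core()
    print("\n".join(issues) if issues else "Core requirements satisfied.")
```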
Verifying Installation
After installation, verify that FDFI is working:
import numpy as np
import fdfi
from fdfi.explainers import OTExplainer
print(fdfi.__version__)
# Test basic functionality with a toy model whose prediction is the row sum of X
def model(X):
    return X.sum(axis=1)
X = np.random.randn(50, 5)  # 50 samples with 5 features
explainer = OTExplainer(model, data=X, nsamples=20)
results = explainer(X[:5])  # explain the first 5 rows
print("Installation successful!")
Troubleshooting
ImportError for torch or torchdiffeq
If you see import errors related to PyTorch, you need to install the flow dependencies:
pip install -e ".[flow]"
Alternatively, pass fit_flow=False when creating an explainer to disable flow matching:
explainer = Explainer(model, data=X, fit_flow=False)
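When the same script has to run both with and without the `[flow]` extra, one option is to try importing the two flow-matching packages and derive `fit_flow` from the result. A minimal sketch under that assumption: `flow_dependencies_available` is a hypothetical helper, and the commented-out call reuses the `Explainer`, `model`, and `X` names from the line above rather than a confirmed FDFI signature.

```python
import importlib


def flow_dependencies_available():
    """Return True only if both torch and torchdiffeq import cleanly (hypothetical helper)."""
    for name in ("torch", "torchdiffeq"):
        try:
            importlib.import_module(name)
        except ImportError:
            # Covers ModuleNotFoundError, a subclass of ImportError.
            return False
    return True


fit_flow = flow_dependencies_available()
print(f"fit_flow={fit_flow}")
# Assumes Explainer, model and X as in the snippet above:
# explainer = Explainer(model, data=X, fit_flow=fit_flow)
```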
Matplotlib backend issues
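On headless servers (no display), matplotlib can fail when it tries to use an interactive GUI backend. Select the non-interactive Agg backend before importing pyplot: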
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
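Hard-coding `matplotlib.use('Agg')` also disables interactive windows on desktop machines. A minimal sketch, assuming a Linux-style server where the absence of both `DISPLAY` and `WAYLAND_DISPLAY` indicates no graphical session, switches to Agg only in that case. The heuristic and both helper names are assumptions, not FDFI behavior.

```python
import os
import sys


def is_headless(environ=None, platform=None):
    """Heuristic: on Linux, no DISPLAY and no WAYLAND_DISPLAY means no GUI session."""
    environ = os.environ if environ is None else environ
    platform = sys.platform if platform is None else platform
    if not platform.startswith("linux"):
        # macOS and Windows normally provide a GUI without these variables.
        return False
    return not (environ.get("DISPLAY") or environ.get("WAYLAND_DISPLAY"))


def configure_backend():
    """Select Agg on headless machines; call before importing matplotlib.pyplot."""
    import matplotlib

    if is_headless():
        matplotlib.use("Agg")
    return matplotlib.get_backend()
```

Call `configure_backend()` once at the top of a script, then `import matplotlib.pyplot as plt` as usual.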