# Certified bounds on SHAP values for neural networks using branch-and-bound with linear relaxation-based bound propagation
Abstract: Shapley additive explanations (SHAP) are widely recognised as computationally intractable for neural networks, since they induce an exponential search space over the input features. In this work, we take a first step towards scaling exact SHAP computation to higher-dimensional search spaces by introducing an algorithm that leverages recent advances in neural network verification to compute arbitrarily tight exact lower and upper bounds on SHAP values for neural networks, ultimately recovering the exact SHAP values. We demonstrate that our approach scales to substantially higher-dimensional search spaces than state-of-the-art exact methods. This provides an important first step towards exact SHAP computation and establishes a principled cornerstone for evaluating statistical approximation methods on higher-dimensional search spaces.
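To make "exact SHAP values" and the exponential search space concrete, the following is a minimal, self-contained sketch that enumerates every feature coalition for a toy ReLU model. The model, inputs, and helper names here are illustrative only and not part of this library; the point of the brute-force loop is precisely what this repository avoids — it visits all 2^(n-1) coalitions per feature, which is what makes exact SHAP intractable for high-dimensional inputs.

```python
from itertools import chain, combinations
from math import factorial

def model(x):
    # Toy stand-in for a network: ReLU of a single linear layer.
    return max(0.0, 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2])

sample = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
n = len(sample)

def value(coalition):
    # Baseline value function: features in the coalition keep the sample's
    # values; all other features are replaced by the baseline.
    x = [sample[i] if i in coalition else baseline[i] for i in range(n)]
    return model(x)

def exact_shap(i):
    # Shapley value of feature i: weighted average of its marginal
    # contribution over all 2^(n-1) coalitions of the remaining features.
    rest = [j for j in range(n) if j != i]
    subsets = chain.from_iterable(combinations(rest, k) for k in range(n))
    total = 0.0
    for s in subsets:
        weight = factorial(len(s)) * factorial(n - len(s) - 1) / factorial(n)
        total += weight * (value(set(s) | {i}) - value(set(s)))
    return total

print([round(exact_shap(i), 6) for i in range(n)])  # [1.0, 4.0, 9.0]
```

Because the toy model is linear on all the points visited, each SHAP value reduces to the feature's weight times its deviation from the baseline, and the values sum to `model(sample) - model(baseline)` (the efficiency property).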
```python
import jax.numpy as jnp
from jax import nn

from shap_bounds import multi_shap_bab, baseline_value

# Define or load your model in JAX
def model(x):
    return nn.relu(x @ jnp.array([[1.0], [2.0], [3.0]]))

# Create a value function for SHAP computation
sample = jnp.array([1.0, 2.0, 3.0])
baseline = jnp.zeros(3)
value_fn = baseline_value(model, sample, baseline, output=0)

# Compute SHAP bounds iteratively
for bounds in multi_shap_bab(value_fn, base_mask=jnp.zeros(3)):
    print(f"SHAP bounds: {bounds}")
    # bounds converge to exact SHAP values over iterations
```

Clone the repository and initialize the submodules:
```shell
git submodule update --init
```

Install dependencies using uv:
```shell
uv venv && uv sync --all-extras
source .venv/bin/activate
```

or using conda/pip:
```shell
conda create -n shap-bounds python=3.12
conda activate shap-bounds
pip install -e ".[all]"
```

All experiments are Python modules or Bash scripts under `experiments/`.
Compute SHAP bounds for image patches:

```shell
python -m experiments.vision_patches.bound \
    --model experiments/resources/mnist-cnn.eqx \
    --num-patches 7 \
    --input 43 --output-feature 4 \
    --shap-variant zero-baseline \
    --overlay-logger
```

Run with `-h` to see all available options:
```shell
python -m experiments.vision_patches.bound -h
```

Compute SHAP bounds for a tabular dataset:
```shell
python -m experiments.tabular.bound \
    --model experiments/resources/mushroom-mlp-8x1.eqx \
    --shap-variant marginal
```

Run full experiment suites using the provided shell scripts:
| Script | Description |
|---|---|
| `experiments/tabular/compare_to_exactshap.sh` | Compare VeriSHAP against ExactSHAP on tabular datasets |
| `experiments/tabular/compare_estimators.sh` | Benchmark SHAP estimators (KernelSHAP, PermutationSHAP, etc.) |
| `experiments/tabular/bab_vs_estimators.sh` | Compare branch-and-bound bounds against estimator convergence |
| `experiments/vision_patches/compute_bounds.sh` | Run MNIST patch experiments with varying patch counts |
Example:

```shell
# Compare to ExactSHAP with marginal SHAP variant
./experiments/tabular/compare_to_exactshap.sh marginal

# Run MNIST experiments (sample 43, output class 4, 5-7 patches)
./experiments/vision_patches/compute_bounds.sh 43 4 mean-baseline 4096
```
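Because the branch-and-bound iterator yields progressively tighter intervals, a caller can stop as soon as the certified gap is small enough rather than running to full convergence. The sketch below shows that anytime-usage pattern; `fake_bounds` is a hypothetical stand-in for the tightening `(lower, upper)` pairs (the exact shape of the values yielded by `multi_shap_bab` is not specified here).

```python
def fake_bounds():
    # Hypothetical stand-in for a branch-and-bound run: yields a
    # (lower, upper) interval whose width halves every iteration,
    # converging to the exact value 2.0.
    lower, upper = -8.0, 8.0
    while True:
        yield lower, upper
        lower = lower / 2 + 1.0
        upper = upper / 2 + 1.0

def run_until_tight(bounds_iter, tolerance=1e-3):
    # Consume intervals until the certified gap is below the tolerance,
    # then return the step count and the final bounds.
    for step, (lower, upper) in enumerate(bounds_iter):
        if upper - lower <= tolerance:
            return step, lower, upper

step, lower, upper = run_until_tight(fake_bounds())
print(step, lower, upper)  # the exact value is certified to lie in [lower, upper]
```

The same loop applies unchanged to any anytime iterator: the early exit trades tightness for runtime, while the returned interval remains a sound enclosure of the exact value at every step.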