Skip to content

Commit 69a8d6d

Browse files
committed
Add SHAP stub
1 parent 2bafa5c commit 69a8d6d

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

trustyai/explainers.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ def __init__(self, steps=10_000):
3737
)
3838
self._solver_config = (
3939
SolverConfigBuilder.builder()
40-
.withTerminationConfig(self._termination_config)
41-
.build()
40+
.withTerminationConfig(self._termination_config)
41+
.build()
4242
)
4343
self._cf_config = CounterfactualConfig().withSolverConfig(self._solver_config)
4444

4545
self._explainer = _CounterfactualExplainer(self._cf_config)
4646

4747
def explain(
48-
self, prediction: CounterfactualPrediction, model: PredictionProvider
48+
self, prediction: CounterfactualPrediction, model: PredictionProvider
4949
) -> CounterfactualResult:
5050
"""Request for a counterfactual explanation given a prediction and a model"""
5151
return self._explainer.explainAsync(prediction, model).get()
@@ -56,27 +56,34 @@ class LimeExplainer:
5656
"""Wrapper for TrustyAI's LIME explainer"""
5757

5858
def __init__(
59-
self,
60-
perturbations=1,
61-
seed=0,
62-
samples=10,
63-
penalise_sparse_balance=True,
64-
normalise_weights=True,
59+
self,
60+
perturbations=1,
61+
seed=0,
62+
samples=10,
63+
penalise_sparse_balance=True,
64+
normalise_weights=True,
6565
):
6666
# build LIME configuration
6767
self._jrandom = Random()
6868
self._jrandom.setSeed(seed)
6969

7070
self._lime_config = (
7171
LimeConfig()
72-
.withNormalizeWeights(normalise_weights)
73-
.withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
74-
.withSamples(samples)
75-
.withPenalizeBalanceSparse(penalise_sparse_balance)
72+
.withNormalizeWeights(normalise_weights)
73+
.withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
74+
.withSamples(samples)
75+
.withPenalizeBalanceSparse(penalise_sparse_balance)
7676
)
7777

7878
self._explainer = _LimeExplainer(self._lime_config)
7979

8080
def explain(self, prediction, model: PredictionProvider) -> Dict[str, Saliency]:
8181
"""Request for a LIME explanation given a prediction and a model"""
8282
return self._explainer.explainAsync(prediction, model).get()
83+
84+
85+
class SHAPExplainer:
86+
"""Wrapper for TrustyAI's SHAP explainer"""
87+
88+
def __init__(self):
89+
pass

0 commit comments

Comments
 (0)