
Commit f1a3a90

Add initial SHAP support

1 parent 69a8d6d

File tree

3 files changed: +83 -17 lines

tests/test_limeexplainer.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 
 import pytest
 
-from common import mock_feature
 from trustyai.explainers import LimeExplainer
 from trustyai.local.counterfactual import simple_prediction
 from trustyai.utils import TestUtils

tests/test_shap.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code
+"""SHAP explainer test suite"""
+
+from common import *
+
+import pytest
+
+from trustyai.explainers import SHAPExplainer
+from trustyai.local.counterfactual import simple_prediction
+from trustyai.model import feature, PredictionInput
+from trustyai.utils import TestUtils
+
+
+def test_no_variance_one_output():
+    """Check that a no-variance background yields all-zero SHAP importances"""
+    model = TestUtils.getSumSkipModel(0)
+
+    background = [PredictionInput([feature(name="f", value=value, dtype="number") for value in [1.0, 2.0, 3.0]]) for _
+                  in
+                  range(2)]
+
+    prediction_outputs = model.predictAsync(background).get()
+
+    predictions = [simple_prediction(input_features=background[i].features, outputs=prediction_outputs[i].outputs) for i
+                   in
+                   range(2)]
+
+    shap_explainer = SHAPExplainer(background=background)
+
+    explanations = [shap_explainer.explain(prediction, model) for prediction in predictions]
+
+    for explanation in explanations:
+        for saliency in explanation.getSaliencies():
+            for feature_importance in saliency.getPerFeatureImportance():
+                assert feature_importance.getScore() == 0.0
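Why the final assertion holds: the two background points are identical, so the background has zero variance; each explained prediction is made on one of those same points, its output equals the background expectation, and kernel SHAP attributes exactly zero to every feature. A minimal sketch of the same invariant, hypothetical and written against the reference shap package rather than this commit's API (sum_skip_first is a stand-in for TestUtils.getSumSkipModel(0)):

import numpy as np
import shap

def sum_skip_first(x: np.ndarray) -> np.ndarray:
    # Toy stand-in for TestUtils.getSumSkipModel(0): sum all features but the first.
    return x[:, 1:].sum(axis=1)

# Two identical background points: the background has zero variance.
background = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
explainer = shap.KernelExplainer(sum_skip_first, background)

# Explaining the background points themselves: f(x) equals the base value,
# so every per-feature attribution comes out (approximately) zero.
shap_values = explainer.shap_values(background)
assert np.allclose(shap_values, 0.0)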

trustyai/explainers.py

Lines changed: 48 additions & 16 deletions
@@ -1,7 +1,8 @@
 """Explainers module"""
 # pylint: disable = import-error, too-few-public-methods
-from typing import Dict
+from typing import Dict, Optional, List
 
+from jpype import JInt
 from org.kie.kogito.explainability.local.counterfactual import (
     CounterfactualExplainer as _CounterfactualExplainer,
     CounterfactualResult,
@@ -13,11 +14,18 @@
     LimeExplainer as _LimeExplainer,
 )
 
+from org.kie.kogito.explainability.local.shap import (
+    ShapConfig as _ShapConfig,
+    ShapResults,
+    ShapKernelExplainer as _ShapKernelExplainer,
+)
+
 from org.kie.kogito.explainability.model import (
     CounterfactualPrediction,
     PredictionProvider,
     Saliency,
     PerturbationContext,
+    PredictionInput as _PredictionInput,
 )
 from org.optaplanner.core.config.solver.termination import TerminationConfig
 from java.lang import Long
@@ -37,15 +45,15 @@ def __init__(self, steps=10_000):
         )
         self._solver_config = (
             SolverConfigBuilder.builder()
-                .withTerminationConfig(self._termination_config)
-                .build()
+            .withTerminationConfig(self._termination_config)
+            .build()
         )
         self._cf_config = CounterfactualConfig().withSolverConfig(self._solver_config)
 
         self._explainer = _CounterfactualExplainer(self._cf_config)
 
     def explain(
-            self, prediction: CounterfactualPrediction, model: PredictionProvider
+        self, prediction: CounterfactualPrediction, model: PredictionProvider
     ) -> CounterfactualResult:
         """Request for a counterfactual explanation given a prediction and a model"""
         return self._explainer.explainAsync(prediction, model).get()
@@ -56,23 +64,23 @@ class LimeExplainer:
     """Wrapper for TrustyAI's LIME explainer"""
 
     def __init__(
-            self,
-            perturbations=1,
-            seed=0,
-            samples=10,
-            penalise_sparse_balance=True,
-            normalise_weights=True,
+        self,
+        perturbations=1,
+        seed=0,
+        samples=10,
+        penalise_sparse_balance=True,
+        normalise_weights=True,
     ):
         # build LIME configuration
         self._jrandom = Random()
         self._jrandom.setSeed(seed)
 
         self._lime_config = (
             LimeConfig()
-                .withNormalizeWeights(normalise_weights)
-                .withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
-                .withSamples(samples)
-                .withPenalizeBalanceSparse(penalise_sparse_balance)
+            .withNormalizeWeights(normalise_weights)
+            .withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
+            .withSamples(samples)
+            .withPenalizeBalanceSparse(penalise_sparse_balance)
         )
 
         self._explainer = _LimeExplainer(self._lime_config)
@@ -85,5 +93,29 @@ def explain(self, prediction, model: PredictionProvider) -> Dict[str, Saliency]:
 class SHAPExplainer:
     """Wrapper for TrustyAI's SHAP explainer"""
 
-    def __init__(self):
-        pass
+    def __init__(
+        self,
+        background: List[_PredictionInput],
+        samples=100,
+        seed=0,
+        perturbations=0,
+        link_type: Optional[_ShapConfig.LinkType] = None,
+    ):
+        if not link_type:
+            link_type = _ShapConfig.LinkType.IDENTITY
+        self._jrandom = Random()
+        self._jrandom.setSeed(seed)
+        perturbation_context = PerturbationContext(self._jrandom, perturbations)
+        self._config = (
+            _ShapConfig.builder()
+            .withLink(link_type)
+            .withPC(perturbation_context)
+            .withBackground(background)
+            .withNSamples(JInt(samples))
+            .build()
+        )
+        self._explainer = _ShapKernelExplainer(self._config)
+
+    def explain(self, prediction, model: PredictionProvider) -> List[ShapResults]:
+        """Request for a SHAP explanation given a prediction and a model"""
+        return self._explainer.explainAsync(prediction, model).get()