Skip to content

Commit 87e2e3a

Browse files
authored
Merge pull request #40 from ruivieira/SHAP
Add SHAP support
2 parents 2bafa5c + f1a3a90 commit 87e2e3a

File tree

3 files changed

+75
-2
lines changed

3 files changed

+75
-2
lines changed

tests/test_limeexplainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import pytest
77

8-
from common import mock_feature
98
from trustyai.explainers import LimeExplainer
109
from trustyai.local.counterfactual import simple_prediction
1110
from trustyai.utils import TestUtils

tests/test_shap.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code
2+
"""SHAP explainer test suite"""
3+
4+
from common import *
5+
6+
import pytest
7+
8+
from trustyai.explainers import SHAPExplainer
9+
from trustyai.local.counterfactual import simple_prediction
10+
from trustyai.model import feature, PredictionInput
11+
from trustyai.utils import TestUtils
12+
13+
14+
def test_no_variance_one_output():
15+
"""Check if the explanation returned is not null"""
16+
model = TestUtils.getSumSkipModel(0)
17+
18+
background = [PredictionInput([feature(name="f", value=value, dtype="number") for value in [1.0, 2.0, 3.0]]) for _
19+
in
20+
range(2)]
21+
22+
prediction_outputs = model.predictAsync(background).get()
23+
24+
predictions = [simple_prediction(input_features=background[i].features, outputs=prediction_outputs[i].outputs) for i
25+
in
26+
range(2)]
27+
28+
shap_explainer = SHAPExplainer(background=background)
29+
30+
explanations = [shap_explainer.explain(prediction, model) for prediction in predictions]
31+
32+
for explanation in explanations:
33+
for saliency in explanation.getSaliencies():
34+
for feature_importance in saliency.getPerFeatureImportance():
35+
assert feature_importance.getScore() == 0.0

trustyai/explainers.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Explainers module"""
22
# pylint: disable = import-error, too-few-public-methods
3-
from typing import Dict
3+
from typing import Dict, Optional, List
44

5+
from jpype import JInt
56
from org.kie.kogito.explainability.local.counterfactual import (
67
CounterfactualExplainer as _CounterfactualExplainer,
78
CounterfactualResult,
@@ -13,11 +14,18 @@
1314
LimeExplainer as _LimeExplainer,
1415
)
1516

17+
from org.kie.kogito.explainability.local.shap import (
18+
ShapConfig as _ShapConfig,
19+
ShapResults,
20+
ShapKernelExplainer as _ShapKernelExplainer,
21+
)
22+
1623
from org.kie.kogito.explainability.model import (
1724
CounterfactualPrediction,
1825
PredictionProvider,
1926
Saliency,
2027
PerturbationContext,
28+
PredictionInput as _PredictionInput,
2129
)
2230
from org.optaplanner.core.config.solver.termination import TerminationConfig
2331
from java.lang import Long
@@ -80,3 +88,34 @@ def __init__(
8088
def explain(self, prediction, model: PredictionProvider) -> Dict[str, Saliency]:
8189
"""Request for a LIME explanation given a prediction and a model"""
8290
return self._explainer.explainAsync(prediction, model).get()
91+
92+
93+
class SHAPExplainer:
94+
"""Wrapper for TrustyAI's SHAP explainer"""
95+
96+
def __init__(
97+
self,
98+
background: List[_PredictionInput],
99+
samples=100,
100+
seed=0,
101+
perturbations=0,
102+
link_type: Optional[_ShapConfig.LinkType] = None,
103+
):
104+
if not link_type:
105+
link_type = _ShapConfig.LinkType.IDENTITY
106+
self._jrandom = Random()
107+
self._jrandom.setSeed(seed)
108+
perturbation_context = PerturbationContext(self._jrandom, perturbations)
109+
self._config = (
110+
_ShapConfig.builder()
111+
.withLink(link_type)
112+
.withPC(perturbation_context)
113+
.withBackground(background)
114+
.withNSamples(JInt(samples))
115+
.build()
116+
)
117+
self._explainer = _ShapKernelExplainer(self._config)
118+
119+
def explain(self, prediction, model: PredictionProvider) -> List[ShapResults]:
120+
"""Request for a SHAP explanation given a prediction and a model"""
121+
return self._explainer.explainAsync(prediction, model).get()

0 commit comments

Comments
 (0)