Skip to content

Commit 7d67487

Browse files
committed
Add LIME explanation plot
1 parent 78b177c commit 7d67487

File tree

2 files changed

+57
-27
lines changed

2 files changed

+57
-27
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
JPype1==1.3.0
1+
JPype1==1.3.0
2+
matplotlib==3.5.1

trustyai/explainers.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Explainers module"""
22
# pylint: disable = import-error, too-few-public-methods
33
from typing import Dict, Optional, List
4+
import matplotlib.pyplot as plt
45

56
from jpype import JInt
67
from org.kie.kogito.explainability.local.counterfactual import (
@@ -45,61 +46,89 @@ def __init__(self, steps=10_000):
4546
)
4647
self._solver_config = (
4748
SolverConfigBuilder.builder()
48-
.withTerminationConfig(self._termination_config)
49-
.build()
49+
.withTerminationConfig(self._termination_config)
50+
.build()
5051
)
5152
self._cf_config = CounterfactualConfig().withSolverConfig(self._solver_config)
5253

5354
self._explainer = _CounterfactualExplainer(self._cf_config)
5455

5556
def explain(
    self, prediction: CounterfactualPrediction, model: PredictionProvider
) -> CounterfactualResult:
    """Request for a counterfactual explanation given a prediction and a model"""
    # explainAsync returns a Java CompletableFuture; block until the
    # counterfactual search completes and hand back its result.
    pending_result = self._explainer.explainAsync(prediction, model)
    return pending_result.get()
6061

6162

63+
class LimeExplanation:
    """Encapsulate LIME results.

    Wraps the decision -> Saliency mapping returned by the Java LIME
    explainer and offers text (`show`) and matplotlib (`plot`) views of
    the per-feature saliency scores.
    """

    def __init__(self, saliencies: Dict[str, Saliency]):
        self._saliencies = saliencies

    def _feature_scores(self, decision: str):
        """Return (feature name, score) pairs for *decision*.

        Raises:
            KeyError: if no saliency exists for *decision*.  The original
                code used ``dict.get`` and crashed later with a confusing
                ``AttributeError`` on ``None``; fail fast instead.
        """
        saliency = self._saliencies.get(decision)
        if saliency is None:
            raise KeyError(f"No saliency found for decision '{decision}'")
        return [
            (pfi.getFeature().name, pfi.getScore())
            for pfi in saliency.getPerFeatureImportance()
        ]

    def show(self, decision: str) -> str:
        """Return a human-readable saliency summary for *decision*."""
        result = f"Saliencies for '{decision}':\n"
        for name, score in self._feature_scores(decision):
            result += f'\t{name}: {score}\n'
        return result

    def map(self):
        """Return the raw decision -> Saliency mapping."""
        return self._saliencies

    def plot(self, decision: str):
        """Draw a horizontal bar chart of feature saliencies for *decision*.

        Does not call ``plt.show()``; the caller decides whether to display
        or save the figure.
        """
        scores = dict(self._feature_scores(decision))
        # Negative contributions in red, positive (or zero) in green.
        colours = ['r' if score < 0 else 'g' for score in scores.values()]
        plt.title(f"LIME explanation for '{decision}'")
        plt.barh(range(len(scores)), scores.values(), align='center', color=colours)
        plt.yticks(range(len(scores)), list(scores.keys()))
        plt.tight_layout()
6291
# pylint: disable=too-many-arguments
class LimeExplainer:
    """Wrapper for TrustyAI's LIME explainer"""

    def __init__(
        self,
        perturbations=1,
        seed=0,
        samples=10,
        penalise_sparse_balance=True,
        normalise_weights=True,
    ):
        # Seed the Java RNG so repeated runs produce the same explanation.
        self._jrandom = Random()
        self._jrandom.setSeed(seed)

        # Assemble the LIME configuration one option at a time.
        config = LimeConfig()
        config = config.withNormalizeWeights(normalise_weights)
        config = config.withPerturbationContext(
            PerturbationContext(self._jrandom, perturbations)
        )
        config = config.withSamples(samples)
        self._lime_config = config.withPenalizeBalanceSparse(
            penalise_sparse_balance
        )

        self._explainer = _LimeExplainer(self._lime_config)

    def explain(self, prediction, model: PredictionProvider) -> LimeExplanation:
        """Request for a LIME explanation given a prediction and a model"""
        # Block on the Java CompletableFuture and wrap the saliency map.
        pending = self._explainer.explainAsync(prediction, model)
        return LimeExplanation(pending.get())
91120

92121

93122
class SHAPExplainer:
94123
"""Wrapper for TrustyAI's SHAP explainer"""
95124

96125
def __init__(
97-
self,
98-
background: List[_PredictionInput],
99-
samples=100,
100-
seed=0,
101-
perturbations=0,
102-
link_type: Optional[_ShapConfig.LinkType] = None,
126+
self,
127+
background: List[_PredictionInput],
128+
samples=100,
129+
seed=0,
130+
perturbations=0,
131+
link_type: Optional[_ShapConfig.LinkType] = None,
103132
):
104133
if not link_type:
105134
link_type = _ShapConfig.LinkType.IDENTITY
@@ -108,11 +137,11 @@ def __init__(
108137
perturbation_context = PerturbationContext(self._jrandom, perturbations)
109138
self._config = (
110139
_ShapConfig.builder()
111-
.withLink(link_type)
112-
.withPC(perturbation_context)
113-
.withBackground(background)
114-
.withNSamples(JInt(samples))
115-
.build()
140+
.withLink(link_type)
141+
.withPC(perturbation_context)
142+
.withBackground(background)
143+
.withNSamples(JInt(samples))
144+
.build()
116145
)
117146
self._explainer = _ShapKernelExplainer(self._config)
118147

0 commit comments

Comments
 (0)