
Commit 3137919

Bump release
Fix formatting and tests
1 parent ecfefc2 commit 3137919

File tree

README.md
setup.py
tests/test_limeexplainer.py
trustyai/explainers.py

4 files changed: +48 -38 lines

README.md

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-![version](https://img.shields.io/badge/version-0.0.8-green) ![TrustyAI](https://img.shields.io/badge/TrustyAI-1.17-green) [![Tests](https://github.com/trustyai-python/module/actions/workflows/workflow.yml/badge.svg)](https://github.com/trustyai-python/examples/actions/workflows/workflow.yml)
+![version](https://img.shields.io/badge/version-0.0.9-green) ![TrustyAI](https://img.shields.io/badge/TrustyAI-1.17-green) [![Tests](https://github.com/trustyai-python/module/actions/workflows/workflow.yml/badge.svg)](https://github.com/trustyai-python/examples/actions/workflows/workflow.yml)

 # python-trustyai

setup.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ def run(self):

 setup(
     name="trustyai",
-    version="0.0.8",
+    version="0.0.9",
     description="Python bindings to the TrustyAI explainability library",
     long_description=long_description,
     long_description_content_type="text/markdown",
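Only the package metadata changes here. A quick way to confirm which release is installed locally (assuming the distribution is installed under the name declared above):

# Minimal check of the installed distribution version (assumes trustyai is installed).
from importlib.metadata import version

print(version("trustyai"))  # expected to print "0.0.9" after this release bump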

tests/test_limeexplainer.py

Lines changed: 3 additions & 3 deletions

@@ -54,7 +54,7 @@ def test_sparse_balance(): # pylint: disable=too-many-locals

     saliency_map_no_penalty = lime_explainer_no_penalty.explain(
         prediction, model
-    )
+    ).map()

     assert saliency_map_no_penalty is not None

@@ -63,7 +63,7 @@ def test_sparse_balance(): # pylint: disable=too-many-locals

     lime_explainer = LimeExplainer(samples=100, penalise_sparse_balance=True)

-    saliency_map = lime_explainer.explain(prediction, model)
+    saliency_map = lime_explainer.explain(prediction, model).map()
     assert saliency_map is not None

     saliency = saliency_map.get(decision_name)
@@ -85,7 +85,7 @@ def test_normalized_weights():
     outputs = model.predict([features])[0].outputs
     prediction = simple_prediction(input_features=features, outputs=outputs)

-    saliency_map = lime_explainer.explain(prediction, model)
+    saliency_map = lime_explainer.explain(prediction, model).map()
     assert saliency_map is not None

     decision_name = "sum-but0"
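The only change in the tests is that explain() no longer yields the saliency map directly; callers now chain .map() on the returned result object. A minimal sketch of the updated pattern, reusing the names from the test above (`model`, `features` and `simple_prediction` are set up elsewhere in the test module and are not part of this diff):

# Hedged sketch of the new call pattern; `model`, `features` and
# `simple_prediction` are assumed to be set up as in the surrounding test module.
lime_explainer = LimeExplainer(samples=100, penalise_sparse_balance=True)

outputs = model.predict([features])[0].outputs
prediction = simple_prediction(input_features=features, outputs=outputs)

saliency_map = lime_explainer.explain(prediction, model).map()  # .map() is the new step
saliency = saliency_map.get("sum-but0")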

trustyai/explainers.py

Lines changed: 43 additions & 33 deletions

@@ -46,15 +46,15 @@ def __init__(self, steps=10_000):
         )
         self._solver_config = (
             SolverConfigBuilder.builder()
-                .withTerminationConfig(self._termination_config)
-                .build()
+            .withTerminationConfig(self._termination_config)
+            .build()
         )
         self._cf_config = CounterfactualConfig().withSolverConfig(self._solver_config)

         self._explainer = _CounterfactualExplainer(self._cf_config)

     def explain(
-            self, prediction: CounterfactualPrediction, model: PredictionProvider
+        self, prediction: CounterfactualPrediction, model: PredictionProvider
     ) -> CounterfactualResult:
         """Request for a counterfactual explanation given a prediction and a model"""
         return self._explainer.explainAsync(prediction, model).get()
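This first hunk only re-indents the counterfactual wrapper; the explain() signature shown here is unchanged. A minimal call sketch, assuming the surrounding class is the library's public counterfactual explainer (its name is not visible in this hunk) and that `prediction` and `model` are built elsewhere:

# Hedged sketch: the wrapper class is assumed to be named CounterfactualExplainer
# (only its __init__ and explain methods appear in this hunk); `prediction`
# (a CounterfactualPrediction) and `model` (a PredictionProvider) are assumed to exist.
cf_explainer = CounterfactualExplainer(steps=10_000)
result = cf_explainer.explain(prediction, model)  # blocks until explainAsync() resolves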
@@ -69,22 +69,32 @@ def __init__(self, saliencies: Dict[str, Saliency]):
     def show(self, decision: str) -> str:
         """Return saliencies for a decision"""
         result = f"Saliencies for '{decision}':\n"
-        for f in self._saliencies.get(decision).getPerFeatureImportance():
-            result += f'\t{f.getFeature().name}: {f.getScore()}\n'
+        for feature_importance in self._saliencies.get(
+            decision
+        ).getPerFeatureImportance():
+            result += f"\t{feature_importance.getFeature().name}: {feature_importance.getScore()}\n"
         return result

     def map(self):
+        """Return saliencies map"""
         return self._saliencies

     def plot(self, decision: str):
-        d = {}
-        for f in self._saliencies.get(decision).getPerFeatureImportance():
-            d[f.getFeature().name] = f.getScore()
-
-        colours = ['r' if i < 0 else 'g' for i in d.values()]
+        """Plot saliencies"""
+        dictionary = {}
+        for feature_importance in self._saliencies.get(
+            decision
+        ).getPerFeatureImportance():
+            dictionary[
+                feature_importance.getFeature().name
+            ] = feature_importance.getScore()
+
+        colours = ["r" if i < 0 else "g" for i in dictionary.values()]
         plt.title(f"LIME explanation for '{decision}'")
-        plt.barh(range(len(d)), d.values(), align='center', color=colours)
-        plt.yticks(range(len(d)), list(d.keys()))
+        plt.barh(
+            range(len(dictionary)), dictionary.values(), align="center", color=colours
+        )
+        plt.yticks(range(len(dictionary)), list(dictionary.keys()))
         plt.tight_layout()

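The hunk above only adds docstrings and spells out local names in the saliency wrapper; its behaviour is unchanged. A short usage sketch of the three methods, assuming `saliency_results` is the object returned by LimeExplainer.explain() (as exercised in the test changes earlier) and that "sum-but0" is a valid decision name:

# Hedged sketch: `saliency_results` is assumed to be the return value of
# LimeExplainer.explain(prediction, model); its construction is not part of this hunk.
print(saliency_results.show("sum-but0"))  # per-feature scores as a formatted string
saliency_results.plot("sum-but0")         # horizontal bar chart of feature scores
saliencies = saliency_results.map()       # the underlying {decision: Saliency} map

Depending on the matplotlib backend, a plt.show() call may be needed after plot() to display the figure.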
@@ -93,23 +103,23 @@ class LimeExplainer:
     """Wrapper for TrustyAI's LIME explainer"""

     def __init__(
-            self,
-            perturbations=1,
-            seed=0,
-            samples=10,
-            penalise_sparse_balance=True,
-            normalise_weights=True,
+        self,
+        perturbations=1,
+        seed=0,
+        samples=10,
+        penalise_sparse_balance=True,
+        normalise_weights=True,
     ):
         # build LIME configuration
         self._jrandom = Random()
         self._jrandom.setSeed(seed)

         self._lime_config = (
             LimeConfig()
-                .withNormalizeWeights(normalise_weights)
-                .withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
-                .withSamples(samples)
-                .withPenalizeBalanceSparse(penalise_sparse_balance)
+            .withNormalizeWeights(normalise_weights)
+            .withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
+            .withSamples(samples)
+            .withPenalizeBalanceSparse(penalise_sparse_balance)
         )

         self._explainer = _LimeExplainer(self._lime_config)
@@ -123,12 +133,12 @@ class SHAPExplainer:
     """Wrapper for TrustyAI's SHAP explainer"""

     def __init__(
-            self,
-            background: List[_PredictionInput],
-            samples=100,
-            seed=0,
-            perturbations=0,
-            link_type: Optional[_ShapConfig.LinkType] = None,
+        self,
+        background: List[_PredictionInput],
+        samples=100,
+        seed=0,
+        perturbations=0,
+        link_type: Optional[_ShapConfig.LinkType] = None,
     ):
         if not link_type:
             link_type = _ShapConfig.LinkType.IDENTITY
@@ -137,11 +147,11 @@ def __init__(
         perturbation_context = PerturbationContext(self._jrandom, perturbations)
         self._config = (
             _ShapConfig.builder()
-                .withLink(link_type)
-                .withPC(perturbation_context)
-                .withBackground(background)
-                .withNSamples(JInt(samples))
-                .build()
+            .withLink(link_type)
+            .withPC(perturbation_context)
+            .withBackground(background)
+            .withNSamples(JInt(samples))
+            .build()
         )
         self._explainer = _ShapKernelExplainer(self._config)

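The remaining hunks only re-indent the LimeExplainer and SHAPExplainer constructors to the formatter's style; their signatures are unchanged. A construction sketch for the SHAP wrapper, using only the parameters visible above and assuming `background` is a List[_PredictionInput] prepared elsewhere (its construction is outside this diff):

# Hedged sketch: only constructor parameters shown in the diff are used.
# `background` is assumed to be a List[_PredictionInput] built elsewhere.
shap_explainer = SHAPExplainer(
    background=background,
    samples=100,      # sample count forwarded via withNSamples(JInt(samples))
    seed=0,           # seeds the explainer's random number generator
    perturbations=0,  # forwarded to PerturbationContext
    link_type=None,   # None falls back to _ShapConfig.LinkType.IDENTITY
)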