@@ -46,15 +46,15 @@ def __init__(self, steps=10_000):
46
46
)
47
47
self ._solver_config = (
48
48
SolverConfigBuilder .builder ()
49
- .withTerminationConfig (self ._termination_config )
50
- .build ()
49
+ .withTerminationConfig (self ._termination_config )
50
+ .build ()
51
51
)
52
52
self ._cf_config = CounterfactualConfig ().withSolverConfig (self ._solver_config )
53
53
54
54
self ._explainer = _CounterfactualExplainer (self ._cf_config )
55
55
56
56
def explain (
57
- self , prediction : CounterfactualPrediction , model : PredictionProvider
57
+ self , prediction : CounterfactualPrediction , model : PredictionProvider
58
58
) -> CounterfactualResult :
59
59
"""Request for a counterfactual explanation given a prediction and a model"""
60
60
return self ._explainer .explainAsync (prediction , model ).get ()
@@ -69,22 +69,32 @@ def __init__(self, saliencies: Dict[str, Saliency]):
69
69
def show (self , decision : str ) -> str :
70
70
"""Return saliencies for a decision"""
71
71
result = f"Saliencies for '{ decision } ':\n "
72
- for f in self ._saliencies .get (decision ).getPerFeatureImportance ():
73
- result += f'\t { f .getFeature ().name } : { f .getScore ()} \n '
72
+ for feature_importance in self ._saliencies .get (
73
+ decision
74
+ ).getPerFeatureImportance ():
75
+ result += f"\t { feature_importance .getFeature ().name } : { feature_importance .getScore ()} \n "
74
76
return result
75
77
76
78
def map (self ):
79
+ """Return saliencies map"""
77
80
return self ._saliencies
78
81
79
82
def plot (self , decision : str ):
80
- d = {}
81
- for f in self ._saliencies .get (decision ).getPerFeatureImportance ():
82
- d [f .getFeature ().name ] = f .getScore ()
83
-
84
- colours = ['r' if i < 0 else 'g' for i in d .values ()]
83
+ """Plot saliencies"""
84
+ dictionary = {}
85
+ for feature_importance in self ._saliencies .get (
86
+ decision
87
+ ).getPerFeatureImportance ():
88
+ dictionary [
89
+ feature_importance .getFeature ().name
90
+ ] = feature_importance .getScore ()
91
+
92
+ colours = ["r" if i < 0 else "g" for i in dictionary .values ()]
85
93
plt .title (f"LIME explanation for '{ decision } '" )
86
- plt .barh (range (len (d )), d .values (), align = 'center' , color = colours )
87
- plt .yticks (range (len (d )), list (d .keys ()))
94
+ plt .barh (
95
+ range (len (dictionary )), dictionary .values (), align = "center" , color = colours
96
+ )
97
+ plt .yticks (range (len (dictionary )), list (dictionary .keys ()))
88
98
plt .tight_layout ()
89
99
90
100
@@ -93,23 +103,23 @@ class LimeExplainer:
93
103
"""Wrapper for TrustyAI's LIME explainer"""
94
104
95
105
def __init__ (
96
- self ,
97
- perturbations = 1 ,
98
- seed = 0 ,
99
- samples = 10 ,
100
- penalise_sparse_balance = True ,
101
- normalise_weights = True ,
106
+ self ,
107
+ perturbations = 1 ,
108
+ seed = 0 ,
109
+ samples = 10 ,
110
+ penalise_sparse_balance = True ,
111
+ normalise_weights = True ,
102
112
):
103
113
# build LIME configuration
104
114
self ._jrandom = Random ()
105
115
self ._jrandom .setSeed (seed )
106
116
107
117
self ._lime_config = (
108
118
LimeConfig ()
109
- .withNormalizeWeights (normalise_weights )
110
- .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
111
- .withSamples (samples )
112
- .withPenalizeBalanceSparse (penalise_sparse_balance )
119
+ .withNormalizeWeights (normalise_weights )
120
+ .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
121
+ .withSamples (samples )
122
+ .withPenalizeBalanceSparse (penalise_sparse_balance )
113
123
)
114
124
115
125
self ._explainer = _LimeExplainer (self ._lime_config )
@@ -123,12 +133,12 @@ class SHAPExplainer:
123
133
"""Wrapper for TrustyAI's SHAP explainer"""
124
134
125
135
def __init__ (
126
- self ,
127
- background : List [_PredictionInput ],
128
- samples = 100 ,
129
- seed = 0 ,
130
- perturbations = 0 ,
131
- link_type : Optional [_ShapConfig .LinkType ] = None ,
136
+ self ,
137
+ background : List [_PredictionInput ],
138
+ samples = 100 ,
139
+ seed = 0 ,
140
+ perturbations = 0 ,
141
+ link_type : Optional [_ShapConfig .LinkType ] = None ,
132
142
):
133
143
if not link_type :
134
144
link_type = _ShapConfig .LinkType .IDENTITY
@@ -137,11 +147,11 @@ def __init__(
137
147
perturbation_context = PerturbationContext (self ._jrandom , perturbations )
138
148
self ._config = (
139
149
_ShapConfig .builder ()
140
- .withLink (link_type )
141
- .withPC (perturbation_context )
142
- .withBackground (background )
143
- .withNSamples (JInt (samples ))
144
- .build ()
150
+ .withLink (link_type )
151
+ .withPC (perturbation_context )
152
+ .withBackground (background )
153
+ .withNSamples (JInt (samples ))
154
+ .build ()
145
155
)
146
156
self ._explainer = _ShapKernelExplainer (self ._config )
147
157
0 commit comments