@@ -46,15 +46,15 @@ def __init__(self, steps=10_000):
4646 )
4747 self ._solver_config = (
4848 SolverConfigBuilder .builder ()
49- .withTerminationConfig (self ._termination_config )
50- .build ()
49+ .withTerminationConfig (self ._termination_config )
50+ .build ()
5151 )
5252 self ._cf_config = CounterfactualConfig ().withSolverConfig (self ._solver_config )
5353
5454 self ._explainer = _CounterfactualExplainer (self ._cf_config )
5555
5656 def explain (
57- self , prediction : CounterfactualPrediction , model : PredictionProvider
57+ self , prediction : CounterfactualPrediction , model : PredictionProvider
5858 ) -> CounterfactualResult :
5959 """Request for a counterfactual explanation given a prediction and a model"""
6060 return self ._explainer .explainAsync (prediction , model ).get ()
@@ -69,22 +69,32 @@ def __init__(self, saliencies: Dict[str, Saliency]):
6969 def show (self , decision : str ) -> str :
7070 """Return saliencies for a decision"""
7171 result = f"Saliencies for '{ decision } ':\n "
72- for f in self ._saliencies .get (decision ).getPerFeatureImportance ():
73- result += f'\t { f .getFeature ().name } : { f .getScore ()} \n '
72+ for feature_importance in self ._saliencies .get (
73+ decision
74+ ).getPerFeatureImportance ():
75+ result += f"\t { feature_importance .getFeature ().name } : { feature_importance .getScore ()} \n "
7476 return result
7577
7678 def map (self ):
79+ """Return saliencies map"""
7780 return self ._saliencies
7881
7982 def plot (self , decision : str ):
80- d = {}
81- for f in self ._saliencies .get (decision ).getPerFeatureImportance ():
82- d [f .getFeature ().name ] = f .getScore ()
83-
84- colours = ['r' if i < 0 else 'g' for i in d .values ()]
83+ """Plot saliencies"""
84+ dictionary = {}
85+ for feature_importance in self ._saliencies .get (
86+ decision
87+ ).getPerFeatureImportance ():
88+ dictionary [
89+ feature_importance .getFeature ().name
90+ ] = feature_importance .getScore ()
91+
92+ colours = ["r" if i < 0 else "g" for i in dictionary .values ()]
8593 plt .title (f"LIME explanation for '{ decision } '" )
86- plt .barh (range (len (d )), d .values (), align = 'center' , color = colours )
87- plt .yticks (range (len (d )), list (d .keys ()))
94+ plt .barh (
95+ range (len (dictionary )), dictionary .values (), align = "center" , color = colours
96+ )
97+ plt .yticks (range (len (dictionary )), list (dictionary .keys ()))
8898 plt .tight_layout ()
8999
90100
@@ -93,23 +103,23 @@ class LimeExplainer:
93103 """Wrapper for TrustyAI's LIME explainer"""
94104
95105 def __init__ (
96- self ,
97- perturbations = 1 ,
98- seed = 0 ,
99- samples = 10 ,
100- penalise_sparse_balance = True ,
101- normalise_weights = True ,
106+ self ,
107+ perturbations = 1 ,
108+ seed = 0 ,
109+ samples = 10 ,
110+ penalise_sparse_balance = True ,
111+ normalise_weights = True ,
102112 ):
103113 # build LIME configuration
104114 self ._jrandom = Random ()
105115 self ._jrandom .setSeed (seed )
106116
107117 self ._lime_config = (
108118 LimeConfig ()
109- .withNormalizeWeights (normalise_weights )
110- .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
111- .withSamples (samples )
112- .withPenalizeBalanceSparse (penalise_sparse_balance )
119+ .withNormalizeWeights (normalise_weights )
120+ .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
121+ .withSamples (samples )
122+ .withPenalizeBalanceSparse (penalise_sparse_balance )
113123 )
114124
115125 self ._explainer = _LimeExplainer (self ._lime_config )
@@ -123,12 +133,12 @@ class SHAPExplainer:
123133 """Wrapper for TrustyAI's SHAP explainer"""
124134
125135 def __init__ (
126- self ,
127- background : List [_PredictionInput ],
128- samples = 100 ,
129- seed = 0 ,
130- perturbations = 0 ,
131- link_type : Optional [_ShapConfig .LinkType ] = None ,
136+ self ,
137+ background : List [_PredictionInput ],
138+ samples = 100 ,
139+ seed = 0 ,
140+ perturbations = 0 ,
141+ link_type : Optional [_ShapConfig .LinkType ] = None ,
132142 ):
133143 if not link_type :
134144 link_type = _ShapConfig .LinkType .IDENTITY
@@ -137,11 +147,11 @@ def __init__(
137147 perturbation_context = PerturbationContext (self ._jrandom , perturbations )
138148 self ._config = (
139149 _ShapConfig .builder ()
140- .withLink (link_type )
141- .withPC (perturbation_context )
142- .withBackground (background )
143- .withNSamples (JInt (samples ))
144- .build ()
150+ .withLink (link_type )
151+ .withPC (perturbation_context )
152+ .withBackground (background )
153+ .withNSamples (JInt (samples ))
154+ .build ()
145155 )
146156 self ._explainer = _ShapKernelExplainer (self ._config )
147157
0 commit comments