11"""Explainers module"""
22# pylint: disable = import-error, too-few-public-methods
33from typing import Dict , Optional , List
4+ import matplotlib .pyplot as plt
45
56from jpype import JInt
67from org .kie .kogito .explainability .local .counterfactual import (
@@ -45,61 +46,89 @@ def __init__(self, steps=10_000):
4546 )
4647 self ._solver_config = (
4748 SolverConfigBuilder .builder ()
48- .withTerminationConfig (self ._termination_config )
49- .build ()
49+ .withTerminationConfig (self ._termination_config )
50+ .build ()
5051 )
5152 self ._cf_config = CounterfactualConfig ().withSolverConfig (self ._solver_config )
5253
5354 self ._explainer = _CounterfactualExplainer (self ._cf_config )
5455
5556 def explain (
56- self , prediction : CounterfactualPrediction , model : PredictionProvider
57+ self , prediction : CounterfactualPrediction , model : PredictionProvider
5758 ) -> CounterfactualResult :
5859 """Request for a counterfactual explanation given a prediction and a model"""
5960 return self ._explainer .explainAsync (prediction , model ).get ()
6061
6162
63+ class LimeExplanation :
64+ """Encapsulate LIME results"""
65+
66+ def __init__ (self , saliencies : Dict [str , Saliency ]):
67+ self ._saliencies = saliencies
68+
69+ def show (self , decision : str ) -> str :
70+ """Return saliencies for a decision"""
71+ result = f"Saliencies for '{ decision } ':\n "
72+ for f in self ._saliencies .get (decision ).getPerFeatureImportance ():
73+ result += f'\t { f .getFeature ().name } : { f .getScore ()} \n '
74+ return result
75+
76+ def map (self ):
77+ return self ._saliencies
78+
79+ def plot (self , decision : str ):
80+ d = {}
81+ for f in self ._saliencies .get (decision ).getPerFeatureImportance ():
82+ d [f .getFeature ().name ] = f .getScore ()
83+
84+ colours = ['r' if i < 0 else 'g' for i in d .values ()]
85+ plt .title (f"LIME explanation for '{ decision } '" )
86+ plt .barh (range (len (d )), d .values (), align = 'center' , color = colours )
87+ plt .yticks (range (len (d )), list (d .keys ()))
88+ plt .tight_layout ()
89+
90+
6291# pylint: disable=too-many-arguments
6392class LimeExplainer :
6493 """Wrapper for TrustyAI's LIME explainer"""
6594
6695 def __init__ (
67- self ,
68- perturbations = 1 ,
69- seed = 0 ,
70- samples = 10 ,
71- penalise_sparse_balance = True ,
72- normalise_weights = True ,
96+ self ,
97+ perturbations = 1 ,
98+ seed = 0 ,
99+ samples = 10 ,
100+ penalise_sparse_balance = True ,
101+ normalise_weights = True ,
73102 ):
74103 # build LIME configuration
75104 self ._jrandom = Random ()
76105 self ._jrandom .setSeed (seed )
77106
78107 self ._lime_config = (
79108 LimeConfig ()
80- .withNormalizeWeights (normalise_weights )
81- .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
82- .withSamples (samples )
83- .withPenalizeBalanceSparse (penalise_sparse_balance )
109+ .withNormalizeWeights (normalise_weights )
110+ .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
111+ .withSamples (samples )
112+ .withPenalizeBalanceSparse (penalise_sparse_balance )
84113 )
85114
86115 self ._explainer = _LimeExplainer (self ._lime_config )
87116
88- def explain (self , prediction , model : PredictionProvider ) -> Dict [ str , Saliency ] :
117+ def explain (self , prediction , model : PredictionProvider ) -> LimeExplanation :
89118 """Request for a LIME explanation given a prediction and a model"""
90- return self ._explainer .explainAsync (prediction , model ).get ()
119+ return LimeExplanation ( self ._explainer .explainAsync (prediction , model ).get () )
91120
92121
93122class SHAPExplainer :
94123 """Wrapper for TrustyAI's SHAP explainer"""
95124
96125 def __init__ (
97- self ,
98- background : List [_PredictionInput ],
99- samples = 100 ,
100- seed = 0 ,
101- perturbations = 0 ,
102- link_type : Optional [_ShapConfig .LinkType ] = None ,
126+ self ,
127+ background : List [_PredictionInput ],
128+ samples = 100 ,
129+ seed = 0 ,
130+ perturbations = 0 ,
131+ link_type : Optional [_ShapConfig .LinkType ] = None ,
103132 ):
104133 if not link_type :
105134 link_type = _ShapConfig .LinkType .IDENTITY
@@ -108,11 +137,11 @@ def __init__(
108137 perturbation_context = PerturbationContext (self ._jrandom , perturbations )
109138 self ._config = (
110139 _ShapConfig .builder ()
111- .withLink (link_type )
112- .withPC (perturbation_context )
113- .withBackground (background )
114- .withNSamples (JInt (samples ))
115- .build ()
140+ .withLink (link_type )
141+ .withPC (perturbation_context )
142+ .withBackground (background )
143+ .withNSamples (JInt (samples ))
144+ .build ()
116145 )
117146 self ._explainer = _ShapKernelExplainer (self ._config )
118147
0 commit comments