11"""Explainers module"""
22# pylint: disable = import-error, too-few-public-methods
3- from typing import Dict
3+ from typing import Dict , Optional , List
44
5+ from jpype import JInt
56from org .kie .kogito .explainability .local .counterfactual import (
67 CounterfactualExplainer as _CounterfactualExplainer ,
78 CounterfactualResult ,
1314 LimeExplainer as _LimeExplainer ,
1415)
1516
17+ from org .kie .kogito .explainability .local .shap import (
18+ ShapConfig as _ShapConfig ,
19+ ShapResults ,
20+ ShapKernelExplainer as _ShapKernelExplainer ,
21+ )
22+
1623from org .kie .kogito .explainability .model import (
1724 CounterfactualPrediction ,
1825 PredictionProvider ,
1926 Saliency ,
2027 PerturbationContext ,
28+ PredictionInput as _PredictionInput ,
2129)
2230from org .optaplanner .core .config .solver .termination import TerminationConfig
2331from java .lang import Long
@@ -37,15 +45,15 @@ def __init__(self, steps=10_000):
3745 )
3846 self ._solver_config = (
3947 SolverConfigBuilder .builder ()
40- .withTerminationConfig (self ._termination_config )
41- .build ()
48+ .withTerminationConfig (self ._termination_config )
49+ .build ()
4250 )
4351 self ._cf_config = CounterfactualConfig ().withSolverConfig (self ._solver_config )
4452
4553 self ._explainer = _CounterfactualExplainer (self ._cf_config )
4654
4755 def explain (
48- self , prediction : CounterfactualPrediction , model : PredictionProvider
56+ self , prediction : CounterfactualPrediction , model : PredictionProvider
4957 ) -> CounterfactualResult :
5058 """Request for a counterfactual explanation given a prediction and a model"""
5159 return self ._explainer .explainAsync (prediction , model ).get ()
@@ -56,23 +64,23 @@ class LimeExplainer:
5664 """Wrapper for TrustyAI's LIME explainer"""
5765
5866 def __init__ (
59- self ,
60- perturbations = 1 ,
61- seed = 0 ,
62- samples = 10 ,
63- penalise_sparse_balance = True ,
64- normalise_weights = True ,
67+ self ,
68+ perturbations = 1 ,
69+ seed = 0 ,
70+ samples = 10 ,
71+ penalise_sparse_balance = True ,
72+ normalise_weights = True ,
6573 ):
6674 # build LIME configuration
6775 self ._jrandom = Random ()
6876 self ._jrandom .setSeed (seed )
6977
7078 self ._lime_config = (
7179 LimeConfig ()
72- .withNormalizeWeights (normalise_weights )
73- .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
74- .withSamples (samples )
75- .withPenalizeBalanceSparse (penalise_sparse_balance )
80+ .withNormalizeWeights (normalise_weights )
81+ .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
82+ .withSamples (samples )
83+ .withPenalizeBalanceSparse (penalise_sparse_balance )
7684 )
7785
7886 self ._explainer = _LimeExplainer (self ._lime_config )
@@ -85,5 +93,29 @@ def explain(self, prediction, model: PredictionProvider) -> Dict[str, Saliency]:
8593class SHAPExplainer :
8694 """Wrapper for TrustyAI's SHAP explainer"""
8795
88- def __init__ (self ):
89- pass
96+ def __init__ (
97+ self ,
98+ background : List [_PredictionInput ],
99+ samples = 100 ,
100+ seed = 0 ,
101+ perturbations = 0 ,
102+ link_type : Optional [_ShapConfig .LinkType ] = None ,
103+ ):
104+ if not link_type :
105+ link_type = _ShapConfig .LinkType .IDENTITY
106+ self ._jrandom = Random ()
107+ self ._jrandom .setSeed (seed )
108+ perturbation_context = PerturbationContext (self ._jrandom , perturbations )
109+ self ._config = (
110+ _ShapConfig .builder ()
111+ .withLink (link_type )
112+ .withPC (perturbation_context )
113+ .withBackground (background )
114+ .withNSamples (JInt (samples ))
115+ .build ()
116+ )
117+ self ._explainer = _ShapKernelExplainer (self ._config )
118+
119+ def explain (self , prediction , model : PredictionProvider ) -> List [ShapResults ]:
120+ """Request for a SHAP explanation given a prediction and a model"""
121+ return self ._explainer .explainAsync (prediction , model ).get ()
0 commit comments