@@ -37,15 +37,15 @@ def __init__(self, steps=10_000):
3737 )
3838 self ._solver_config = (
3939 SolverConfigBuilder .builder ()
40- .withTerminationConfig (self ._termination_config )
41- .build ()
40+ .withTerminationConfig (self ._termination_config )
41+ .build ()
4242 )
4343 self ._cf_config = CounterfactualConfig ().withSolverConfig (self ._solver_config )
4444
4545 self ._explainer = _CounterfactualExplainer (self ._cf_config )
4646
4747 def explain (
48- self , prediction : CounterfactualPrediction , model : PredictionProvider
48+ self , prediction : CounterfactualPrediction , model : PredictionProvider
4949 ) -> CounterfactualResult :
5050 """Request for a counterfactual explanation given a prediction and a model"""
5151 return self ._explainer .explainAsync (prediction , model ).get ()
@@ -56,27 +56,34 @@ class LimeExplainer:
5656 """Wrapper for TrustyAI's LIME explainer"""
5757
5858 def __init__ (
59- self ,
60- perturbations = 1 ,
61- seed = 0 ,
62- samples = 10 ,
63- penalise_sparse_balance = True ,
64- normalise_weights = True ,
59+ self ,
60+ perturbations = 1 ,
61+ seed = 0 ,
62+ samples = 10 ,
63+ penalise_sparse_balance = True ,
64+ normalise_weights = True ,
6565 ):
6666 # build LIME configuration
6767 self ._jrandom = Random ()
6868 self ._jrandom .setSeed (seed )
6969
7070 self ._lime_config = (
7171 LimeConfig ()
72- .withNormalizeWeights (normalise_weights )
73- .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
74- .withSamples (samples )
75- .withPenalizeBalanceSparse (penalise_sparse_balance )
72+ .withNormalizeWeights (normalise_weights )
73+ .withPerturbationContext (PerturbationContext (self ._jrandom , perturbations ))
74+ .withSamples (samples )
75+ .withPenalizeBalanceSparse (penalise_sparse_balance )
7676 )
7777
7878 self ._explainer = _LimeExplainer (self ._lime_config )
7979
8080 def explain (self , prediction , model : PredictionProvider ) -> Dict [str , Saliency ]:
8181 """Request for a LIME explanation given a prediction and a model"""
8282 return self ._explainer .explainAsync (prediction , model ).get ()
83+
84+
85+ class SHAPExplainer :
86+ """Wrapper for TrustyAI's SHAP explainer"""
87+
88+ def __init__ (self ):
89+ pass
0 commit comments