4040    PredictionProvider ,
4141    Saliency ,
4242    PerturbationContext ,
43+     PredictionInputsDataDistribution ,
4344)
4445
4546from  java .util  import  Random 
@@ -233,14 +234,11 @@ class LimeExplainer:
233234    feature that describe how strongly said feature contributed to the model's output. 
234235    """ 
235236
236-     def  __init__ (self , samples = 10 ,  ** kwargs ):
237+     def  __init__ (self , ** kwargs ):
237238        r"""Initialize the :class:`LimeExplainer`. 
238239
239240        Parameters 
240241        ---------- 
241-         samples: int 
242-             Number of samples to be generated for the local linear model training. 
243- 
244242        Keyword Arguments: 
245243            * penalise_sparse_balance : bool 
246244                (default= ``True``) Whether to penalise features that are likely to produce linearly 
@@ -260,21 +258,58 @@ def __init__(self, samples=10, **kwargs):
260258                process. 
261259            * trackCounterfactuals : bool 
262260                (default= ``False``) Keep track of produced byproduct counterfactuals during LIME run. 
261+             * samples: int 
262+                 (default= ``300``) Number of samples to be generated for the local linear model training. 
263+             * encoding_params: Union[list, tuple] 
264+                 (default= ``(0.07, 0.3)``) Lime encoding parameters, as a tuple/list of two float numbers: 
265+                 - encoding_params[0] is the width of the Gaussian filter for clustering number features. 
266+                 - encoding_params[1] is the threshold for clustering number features. 
267+             * data_distribution: PredictionInputsDataDistribution 
268+                 (default= ``PredictionInputsDataDistribution([])``) Data distribution used to find better feature perturbations 
269+             * features: int 
270+                 (default= ``6``) Number of feature to select from the original set of input features 
271+             * retries: int 
272+                 (default= ``3``) Number of retries performed by LIME to find a separable dataset 
273+             * dataset_minimum: int 
274+                 (default= ``10``) Minimum number of samples retained by the proximity filter to be acceptable 
275+             * separable_dataset_ratio: float 
276+                 (default= ``0.1``) Minimum portion of the encoded dataset that needs to have a different label 
277+             *  kernel_width: float 
278+                 (default= ``0.5``) Width of the proximity kernel 
279+             * proximity_threshold: float 
280+                 (default= ``0.83``) Proximity threshold used to retain close samples 
281+             * adapt_dataset_variance: bool 
282+                 (default= ``True``) Whether LIME should try to increase the perturbation variance in subsequent retries 
283+             * feature_selection: bool 
284+                 (default= ``True``) Whether LIME should generate saliency for to the most important features only 
285+             * filter_interpretable: bool 
286+                 (default= ``False``) Whether the proximity filter should happen in the interpretable space 
263287
264288        """ 
265289        self ._jrandom  =  Random ()
266290        self ._jrandom .setSeed (kwargs .get ("seed" , 0 ))
267- 
291+          ep   =   kwargs . get ( "encoding_params" , ( 0.07 ,  0.3 )) 
268292        self ._lime_config  =  (
269293            LimeConfig ()
270294            .withNormalizeWeights (kwargs .get ("normalise_weights" , False ))
271295            .withPerturbationContext (
272296                PerturbationContext (self ._jrandom , kwargs .get ("perturbations" , 1 ))
273297            )
274-             .withSamples (samples )
275-             .withEncodingParams (EncodingParams (0.07 , 0.3 ))
276-             .withAdaptiveVariance (True )
298+             .withSamples (kwargs .get ("samples" , 300 ))
299+             .withDataDistribution (
300+                 kwargs .get ("data_distribution" , PredictionInputsDataDistribution ([]))
301+             )
302+             .withNoOfFeatures (kwargs .get ("features" , 6 ))
303+             .withRetries (kwargs .get ("retries" , 3 ))
304+             .withProximityFilteredDatasetMinimum (kwargs .get ("dataset_minimum" , 10 ))
305+             .withSeparableDatasetRatio (kwargs .get ("separable_dataset_ratio" , 0.1 ))
306+             .withProximityKernelWidth (kwargs .get ("kernel_width" , 0.5 ))
307+             .withProximityThreshold (kwargs .get ("proximity_threshold" , 0.83 ))
308+             .withEncodingParams (EncodingParams (ep [0 ], ep [1 ]))
309+             .withAdaptiveVariance (kwargs .get ("adapt_dataset_variance" , True ))
310+             .withFeatureSelection (kwargs .get ("feature_selection" , True ))
277311            .withPenalizeBalanceSparse (kwargs .get ("penalise_sparse_balance" , True ))
312+             .withFilterInterpretable (kwargs .get ("filter_interpretable" , False ))
278313            .withUseWLRLinearModel (kwargs .get ("use_wlr_model" , True ))
279314            .withTrackCounterfactuals (kwargs .get ("track_counterfactuals" , False ))
280315        )
0 commit comments