4040 PredictionProvider ,
4141 Saliency ,
4242 PerturbationContext ,
43+ PredictionInputsDataDistribution ,
4344)
4445
4546from java .util import Random
@@ -233,14 +234,11 @@ class LimeExplainer:
233234 feature that describe how strongly said feature contributed to the model's output.
234235 """
235236
236- def __init__ (self , samples = 10 , ** kwargs ):
237+ def __init__ (self , ** kwargs ):
237238 r"""Initialize the :class:`LimeExplainer`.
238239
239240 Parameters
240241 ----------
241- samples: int
242- Number of samples to be generated for the local linear model training.
243-
244242 Keyword Arguments:
245243 * penalise_sparse_balance : bool
246244 (default= ``True``) Whether to penalise features that are likely to produce linearly
@@ -260,21 +258,58 @@ def __init__(self, samples=10, **kwargs):
260258 process.
261259 * trackCounterfactuals : bool
262260 (default= ``False``) Keep track of produced byproduct counterfactuals during LIME run.
261+ * samples: int
262+ (default= ``300``) Number of samples to be generated for the local linear model training.
263+ * encoding_params: Union[list, tuple]
264+ (default= ``(0.07, 0.3)``) Lime encoding parameters, as a tuple/list of two float numbers:
265+ - encoding_params[0] is the width of the Gaussian filter for clustering number features.
266+ - encoding_params[1] is the threshold for clustering number features.
267+ * data_distribution: PredictionInputsDataDistribution
268+ (default= ``PredictionInputsDataDistribution([])``) Data distribution used to find better feature perturbations
269+ * features: int
270+ (default= ``6``) Number of feature to select from the original set of input features
271+ * retries: int
272+ (default= ``3``) Number of retries performed by LIME to find a separable dataset
273+ * dataset_minimum: int
274+ (default= ``10``) Minimum number of samples retained by the proximity filter to be acceptable
275+ * separable_dataset_ratio: float
276+ (default= ``0.1``) Minimum portion of the encoded dataset that needs to have a different label
277+ * kernel_width: float
278+ (default= ``0.5``) Width of the proximity kernel
279+ * proximity_threshold: float
280+ (default= ``0.83``) Proximity threshold used to retain close samples
281+ * adapt_dataset_variance: bool
282+ (default= ``True``) Whether LIME should try to increase the perturbation variance in subsequent retries
283+ * feature_selection: bool
284+ (default= ``True``) Whether LIME should generate saliency for to the most important features only
285+ * filter_interpretable: bool
286+ (default= ``False``) Whether the proximity filter should happen in the interpretable space
263287
264288 """
265289 self ._jrandom = Random ()
266290 self ._jrandom .setSeed (kwargs .get ("seed" , 0 ))
267-
291+ ep = kwargs . get ( "encoding_params" , ( 0.07 , 0.3 ))
268292 self ._lime_config = (
269293 LimeConfig ()
270294 .withNormalizeWeights (kwargs .get ("normalise_weights" , False ))
271295 .withPerturbationContext (
272296 PerturbationContext (self ._jrandom , kwargs .get ("perturbations" , 1 ))
273297 )
274- .withSamples (samples )
275- .withEncodingParams (EncodingParams (0.07 , 0.3 ))
276- .withAdaptiveVariance (True )
298+ .withSamples (kwargs .get ("samples" , 300 ))
299+ .withDataDistribution (
300+ kwargs .get ("data_distribution" , PredictionInputsDataDistribution ([]))
301+ )
302+ .withNoOfFeatures (kwargs .get ("features" , 6 ))
303+ .withRetries (kwargs .get ("retries" , 3 ))
304+ .withProximityFilteredDatasetMinimum (kwargs .get ("dataset_minimum" , 10 ))
305+ .withSeparableDatasetRatio (kwargs .get ("separable_dataset_ratio" , 0.1 ))
306+ .withProximityKernelWidth (kwargs .get ("kernel_width" , 0.5 ))
307+ .withProximityThreshold (kwargs .get ("proximity_threshold" , 0.83 ))
308+ .withEncodingParams (EncodingParams (ep [0 ], ep [1 ]))
309+ .withAdaptiveVariance (kwargs .get ("adapt_dataset_variance" , True ))
310+ .withFeatureSelection (kwargs .get ("feature_selection" , True ))
277311 .withPenalizeBalanceSparse (kwargs .get ("penalise_sparse_balance" , True ))
312+ .withFilterInterpretable (kwargs .get ("filter_interpretable" , False ))
278313 .withUseWLRLinearModel (kwargs .get ("use_wlr_model" , True ))
279314 .withTrackCounterfactuals (kwargs .get ("track_counterfactuals" , False ))
280315 )
0 commit comments