11"""Explainers.shap module"""
22# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
3- # pylint: disable = unused-argument, consider-using-f-string, invalid-name, too-many-arguments
3+ # pylint: disable = unused-argument, consider-using-f-string, invalid-name
44from typing import Dict , Optional , List , Union
55import matplotlib .pyplot as plt
66import matplotlib as mpl
@@ -434,11 +434,7 @@ class SHAPExplainer:
434434 def __init__ (
435435 self ,
436436 background : Union [np .ndarray , pd .DataFrame , List [PredictionInput ]],
437- samples = None ,
438- batch_size = 20 ,
439- seed = 0 ,
440437 link_type : Optional [_ShapConfig .LinkType ] = None ,
441- track_counterfactuals = False ,
442438 ** kwargs ,
443439 ):
444440 r"""Initialize the :class:`SHAPExplainer`.
@@ -449,23 +445,26 @@ def __init__(
449445 or List[:class:`PredictionInput]
450446 The set of background datapoints as an array, dataframe of shape
451447 ``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
452- samples: int
453- The number of samples to use when computing SHAP values. Higher values will increase
454- explanation accuracy, at the cost of runtime.
455- batch_size: int
456- The number of batches passed to the PredictionProvider at once. When using a
457- :class:`~Model` in the :func:`explain` function, this parameter has no effect. With an
458- :class:`~ArrowModel`, `batch_sizes` of around
459- :math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
460- performance gains.
461- seed: int
462- The random seed to be used when generating explanations.
463448 link_type : :obj:`~_ShapConfig.LinkType`
464449 A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
465450 or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
466451 probability, choosing the ``LOGIT`` link will rescale explanations into log-odds units.
467452 Otherwise, choose ``IDENTITY``.
468-
453+ Keyword Arguments:
454+ * samples: int
455+ (default=None) The number of samples to use when computing SHAP values. Higher
456+ values will increase explanation accuracy, at the cost of runtime. If None,
457+ samples will equal 2048 + 2*n_features
458+ * seed: int
459+ (default=0) The random seed to be used when generating explanations.
460+ * batch_size: int
461+ (default=20) The number of batches passed to the PredictionProvider at once.
462+ When using :class:`~Model` with `arrow=False` this parameter has no effect.
463+ If `arrow=True`, `batch_size` values of around
464+ :math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
465+ performance gains.
466+ * track_counterfactuals : bool
467+ (default=False) Keep track of produced byproduct counterfactuals during SHAP run.
469468 Returns
470469 -------
471470 :class:`~SHAPResults`
@@ -474,7 +473,7 @@ def __init__(
474473 if not link_type :
475474 link_type = _ShapConfig .LinkType .IDENTITY
476475 self ._jrandom = Random ()
477- self ._jrandom .setSeed (seed )
476+ self ._jrandom .setSeed (kwargs . get ( " seed" , 0 ) )
478477 perturbation_context = PerturbationContext (self ._jrandom , 0 )
479478
480479 if isinstance (background , np .ndarray ):
@@ -491,13 +490,13 @@ def __init__(
491490 self ._configbuilder = (
492491 _ShapConfig .builder ()
493492 .withLink (link_type )
494- .withBatchSize (batch_size )
493+ .withBatchSize (kwargs . get ( " batch_size" , 20 ) )
495494 .withPC (perturbation_context )
496495 .withBackground (self .background )
497- .withTrackCounterfactuals (track_counterfactuals )
496+ .withTrackCounterfactuals (kwargs . get ( " track_counterfactuals" , False ) )
498497 )
499- if samples is not None :
500- self ._configbuilder .withNSamples (JInt (samples ))
498+ if kwargs . get ( " samples" ) is not None :
499+ self ._configbuilder .withNSamples (JInt (kwargs [ " samples" ] ))
501500 self ._config = self ._configbuilder .build ()
502501 self ._explainer = _ShapKernelExplainer (self ._config )
503502
0 commit comments