11"""Explainers.shap module"""
22# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33# pylint: disable = unused-argument, consider-using-f-string, invalid-name
4- from typing import Dict , Optional , List , Union
4+ from typing import Dict , Optional
55import matplotlib .pyplot as plt
66import matplotlib as mpl
77from bokeh .models import ColumnDataSource , HoverTool
2121 output_html ,
2222 feature_html ,
2323)
24-
2524from trustyai .model import (
26- feature ,
27- Dataset ,
28- PredictionInput ,
2925 simple_prediction ,
30- PredUnionType ,
26+ )
27+ from trustyai .utils .data_conversions import (
28+ OneInputUnionType ,
29+ OneOutputUnionType ,
30+ ManyInputsUnionType ,
31+ many_inputs_convert ,
32+ data_conversion_docstring ,
3133)
3234
3335from org .kie .trustyai .explainability .local .shap import (
@@ -434,20 +436,19 @@ class SHAPExplainer:
434436 the outputs, as compared to the background inputs?*
435437 """
436438
439+ @data_conversion_docstring ("many_inputs" )
437440 def __init__ (
438441 self ,
439- background : Union [ np . ndarray , pd . DataFrame , List [ PredictionInput ]] ,
442+ background : ManyInputsUnionType ,
440443 link_type : Optional [_ShapConfig .LinkType ] = None ,
441444 ** kwargs ,
442445 ):
443446 r"""Initialize the :class:`SHAPxplainer`.
444447
445448 Parameters
446449 ----------
447- background : :class:`numpy.array`, :class:`Pandas.DataFrame`
448- or List[:class:`PredictionInput]
449- The set of background datapoints as an array, dataframe of shape
450- ``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
450+ background : {}
451+ The set of background datapoints as a: {}
451452 link_type : :obj:`~_ShapConfig.LinkType`
452453 A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
453454 or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
@@ -464,10 +465,11 @@ def __init__(
464465 (default=20) The number of batches passed to the PredictionProvider at once.
465466 When uusing :class:`~Model` with `arrow=False` this parameter has no effect.
466467 If `arrow=True`, `batch_sizes` of around
467- :math:`\frac{2000}{ \mathtt{len(background)}}` can produce significant
468+ :math:`\frac{{ 2000}}{{ \mathtt{{ len(background)}} }}` can produce significant
468469 performance gains.
469470 * trackCounterfactuals : bool
470471 (default=False) Keep track of produced byproduct counterfactuals during SHAP run.
472+
471473 Returns
472474 -------
473475 :class:`~SHAPResults`
@@ -477,19 +479,9 @@ def __init__(
477479 link_type = _ShapConfig .LinkType .IDENTITY
478480 self ._jrandom = Random ()
479481 self ._jrandom .setSeed (kwargs .get ("seed" , 0 ))
482+ self .background = many_inputs_convert (background )
480483 perturbation_context = PerturbationContext (self ._jrandom , 0 )
481484
482- if isinstance (background , np .ndarray ):
483- self .background = Dataset .numpy_to_prediction_object (background , feature )
484- elif isinstance (background , pd .DataFrame ):
485- self .background = Dataset .df_to_prediction_object (background , feature )
486- elif isinstance (background [0 ], PredictionInput ):
487- self .background = background
488- else :
489- raise AttributeError (
490- "Unsupported background type: {}" .format (type (background ))
491- )
492-
493485 self ._configbuilder = (
494486 _ShapConfig .builder ()
495487 .withLink (link_type )
@@ -503,32 +495,22 @@ def __init__(
503495 self ._config = self ._configbuilder .build ()
504496 self ._explainer = _ShapKernelExplainer (self ._config )
505497
498+ @data_conversion_docstring ("one_input" , "one_output" )
506499 def explain (
507- self , inputs : PredUnionType , outputs : PredUnionType , model : PredictionProvider
500+ self ,
501+ inputs : OneInputUnionType ,
502+ outputs : OneOutputUnionType ,
503+ model : PredictionProvider ,
508504 ) -> SHAPResults :
509505 """Produce a SHAP explanation.
510506
511507 Parameters
512508 ----------
513- inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
514- The input features to the model, as a:
515-
516- * Numpy array of shape ``[1, n_features]``
517- * Pandas DataFrame with 1 row and ``n_features`` columns
518- * A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
519- * A TrustyAI :class:`PredictionInput`
520-
521- outputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
509+ inputs : {}
510+ The input features to the model, as a: {}
511+ outputs : {}
522512 The corresponding model outputs for the provided features, that is,
523- ``outputs = model(input_features)``. These can take the form of a:
524-
525- * Numpy array of shape ``[1, n_outputs]``
526- * Pandas DataFrame with 1 row and ``n_outputs`` columns
527- * A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
528- * A TrustyAI :class:`PredictionOutput`
529- model : :obj:`~trustyai.model.PredictionProvider`
530- The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` or
531- :class:`~trustyai.model.ArrowModel`.
513+ ``outputs = model(input_features)``. These can take the form of a: {}
532514
533515 Returns
534516 -------
0 commit comments