|
1 | 1 | """Explainers module""" |
2 | 2 | # pylint: disable = import-error, too-few-public-methods |
3 | | -from typing import Dict |
| 3 | +from typing import Dict, Optional, List |
4 | 4 |
|
| 5 | +from jpype import JInt |
5 | 6 | from org.kie.kogito.explainability.local.counterfactual import ( |
6 | 7 | CounterfactualExplainer as _CounterfactualExplainer, |
7 | 8 | CounterfactualResult, |
|
13 | 14 | LimeExplainer as _LimeExplainer, |
14 | 15 | ) |
15 | 16 |
|
| 17 | +from org.kie.kogito.explainability.local.shap import ( |
| 18 | + ShapConfig as _ShapConfig, |
| 19 | + ShapResults, |
| 20 | + ShapKernelExplainer as _ShapKernelExplainer, |
| 21 | +) |
| 22 | + |
16 | 23 | from org.kie.kogito.explainability.model import ( |
17 | 24 | CounterfactualPrediction, |
18 | 25 | PredictionProvider, |
19 | 26 | Saliency, |
20 | 27 | PerturbationContext, |
| 28 | + PredictionInput as _PredictionInput, |
21 | 29 | ) |
22 | 30 | from org.optaplanner.core.config.solver.termination import TerminationConfig |
23 | 31 | from java.lang import Long |
@@ -80,3 +88,34 @@ def __init__( |
80 | 88 | def explain(self, prediction, model: PredictionProvider) -> Dict[str, Saliency]: |
81 | 89 | """Request for a LIME explanation given a prediction and a model""" |
82 | 90 | return self._explainer.explainAsync(prediction, model).get() |
| 91 | + |
| 92 | + |
| 93 | +class SHAPExplainer: |
| 94 | + """Wrapper for TrustyAI's SHAP explainer""" |
| 95 | + |
| 96 | + def __init__( |
| 97 | + self, |
| 98 | + background: List[_PredictionInput], |
| 99 | + samples=100, |
| 100 | + seed=0, |
| 101 | + perturbations=0, |
| 102 | + link_type: Optional[_ShapConfig.LinkType] = None, |
| 103 | + ): |
| 104 | + if not link_type: |
| 105 | + link_type = _ShapConfig.LinkType.IDENTITY |
| 106 | + self._jrandom = Random() |
| 107 | + self._jrandom.setSeed(seed) |
| 108 | + perturbation_context = PerturbationContext(self._jrandom, perturbations) |
| 109 | + self._config = ( |
| 110 | + _ShapConfig.builder() |
| 111 | + .withLink(link_type) |
| 112 | + .withPC(perturbation_context) |
| 113 | + .withBackground(background) |
| 114 | + .withNSamples(JInt(samples)) |
| 115 | + .build() |
| 116 | + ) |
| 117 | + self._explainer = _ShapKernelExplainer(self._config) |
| 118 | + |
| 119 | + def explain(self, prediction, model: PredictionProvider) -> List[ShapResults]: |
| 120 | + """Request for a SHAP explanation given a prediction and a model""" |
| 121 | + return self._explainer.explainAsync(prediction, model).get() |
0 commit comments