Skip to content

Commit 68e403a

Browse files
authored
Merge pull request #8 from ruivieira/shap-example
Add SHAP notebook
2 parents dfc60d7 + a7bd710 commit 68e403a

File tree

2 files changed

+332
-1
lines changed

2 files changed

+332
-1
lines changed

examples/SHAP.ipynb

Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "d2a08ef6",
6+
"metadata": {},
7+
"source": [
8+
"# SHAP explanations"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": null,
14+
"id": "767b003e",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import trustyai\n",
19+
"\n",
20+
"trustyai.init()"
21+
]
22+
},
23+
{
24+
"cell_type": "markdown",
25+
"id": "e194eb56",
26+
"metadata": {},
27+
"source": [
28+
"## Simple example\n",
29+
"\n",
30+
"We start by defining our black-box model, typically represented by\n",
31+
"\n",
32+
"$$\n",
33+
"f(\\mathbf{x}) = \\mathbf{y}\n",
34+
"$$\n",
35+
"\n",
36+
"Where $\\mathbf{x}=\\{x_1, x_2, \\dots,x_m\\}$ and $\\mathbf{y}=\\{y_1, y_2, \\dots,y_n\\}$.\n",
37+
"\n",
38+
"Our example toy model, in this case, takes an all-numerical input $\\mathbf{x}$ and return a $\\mathbf{y}$ of either `true` or `false` if the sum of the $\\mathbf{x}$ components is within a threshold $\\epsilon$ of a point $\\mathbf{C}$, that is:\n",
39+
"\n",
40+
"$$\n",
41+
"f(\\mathbf{x}, \\epsilon, \\mathbf{C})=\\begin{cases}\n",
42+
"\\text{true},\\qquad \\text{if}\\ \\mathbf{C}-\\epsilon<\\sum_{i=1}^m x_i <\\mathbf{C}+\\epsilon \\\\\n",
43+
"\\text{false},\\qquad \\text{otherwise}\n",
44+
"\\end{cases}\n",
45+
"$$\n",
46+
"\n",
47+
"This model is provided in the `TestUtils` module. We instantiate with a $\\mathbf{C}=500$ and $\\epsilon=1.0$."
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": 2,
53+
"id": "fd02e320",
54+
"metadata": {},
55+
"outputs": [],
56+
"source": [
57+
"from trustyai.utils import TestUtils\n",
58+
"\n",
59+
"center = 10.0\n",
60+
"epsilon = 2.0\n",
61+
"\n",
62+
"model = TestUtils.getSumThresholdModel(center, epsilon)"
63+
]
64+
},
65+
{
66+
"cell_type": "markdown",
67+
"id": "e4a15f8b",
68+
"metadata": {},
69+
"source": [
70+
"Next we need to define a **goal**.\n",
71+
"If our model is $f(\\mathbf{x'})=\\mathbf{y'}$ we are then defining our $\\mathbf{y'}$ and the counterfactual result will be the $\\mathbf{x'}$ which satisfies $f(\\mathbf{x'})=\\mathbf{y'}$.\n",
72+
"\n",
73+
"We will define our goal as `true`, that is, the sum is withing the vicinity of a (to be defined) point $\\mathbf{C}$. The goal is a list of `Output` which take the following parameters\n",
74+
"\n",
75+
"- The feature name\n",
76+
"- The feature type\n",
77+
"- The feature value (wrapped in `Value`)\n",
78+
"- A confidence threshold, which we will leave at zero (no threshold)"
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 3,
84+
"id": "bf3f4232",
85+
"metadata": {},
86+
"outputs": [],
87+
"source": [
88+
"from trustyai.model import output\n",
89+
"\n",
90+
"decision = \"inside\"\n",
91+
"goal = [output(name=decision, dtype=\"bool\", value=True, score=0.0)]"
92+
]
93+
},
94+
{
95+
"cell_type": "markdown",
96+
"id": "64349c3e",
97+
"metadata": {},
98+
"source": [
99+
"We will now define our initial features, $\\mathbf{x}$. Each feature can be instantiated by using `FeatureFactory` and in this case we want to use numerical features, so we'll use `FeatureFactory.newNumericalFeature`."
100+
]
101+
},
102+
{
103+
"cell_type": "code",
104+
"execution_count": 4,
105+
"id": "d688a7c8",
106+
"metadata": {},
107+
"outputs": [],
108+
"source": [
109+
"import random\n",
110+
"from trustyai.model import feature\n",
111+
"\n",
112+
"features = [feature(name=f\"x{i+1}\", dtype=\"number\", value=random.random()*10.0) for i in range(3)]"
113+
]
114+
},
115+
{
116+
"cell_type": "markdown",
117+
"id": "a562ef68",
118+
"metadata": {},
119+
"source": [
120+
"As we can see, the sum of of the features will not be within $\\epsilon$ (1.0) of $\\mathbf{C}$ (500.0). As such the model prediction will be `false`:"
121+
]
122+
},
123+
{
124+
"cell_type": "code",
125+
"execution_count": 5,
126+
"id": "48212d3f",
127+
"metadata": {},
128+
"outputs": [
129+
{
130+
"name": "stdout",
131+
"output_type": "stream",
132+
"text": [
133+
"Feature x1 has value 2.1516473114599046\n",
134+
"Feature x2 has value 0.8137674993709809\n",
135+
"Feature x3 has value 5.637541112355343\n",
136+
"\n",
137+
"Features sum is 8.60295592318623\n"
138+
]
139+
}
140+
],
141+
"source": [
142+
"feature_sum = 0.0\n",
143+
"for f in features:\n",
144+
" value = f.value.as_number()\n",
145+
" print(f\"Feature {f.name} has value {value}\")\n",
146+
" feature_sum += value\n",
147+
"print(f\"\\nFeatures sum is {feature_sum}\")"
148+
]
149+
},
150+
{
151+
"cell_type": "markdown",
152+
"id": "13001554",
153+
"metadata": {},
154+
"source": [
155+
"We execute the model on the generated input and collect the output"
156+
]
157+
},
158+
{
159+
"cell_type": "code",
160+
"execution_count": 6,
161+
"id": "0a45c0e0",
162+
"metadata": {
163+
"pycharm": {
164+
"name": "#%%\n"
165+
}
166+
},
167+
"outputs": [],
168+
"source": [
169+
"from org.kie.kogito.explainability.model import PredictionInput, PredictionOutput\n",
170+
"\n",
171+
"goals = model.predictAsync([PredictionInput(features)]).get()"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": 7,
177+
"id": "4483bf24",
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"background = []\n",
182+
"for i in range(10):\n",
183+
" _features = [feature(name=f\"x{i+1}\", dtype=\"number\", value=random.random()*10.0) for i in range(3)]\n",
184+
" background.append(PredictionInput(_features))"
185+
]
186+
},
187+
{
188+
"cell_type": "markdown",
189+
"id": "324cefdf",
190+
"metadata": {
191+
"pycharm": {
192+
"name": "#%% md\n"
193+
}
194+
},
195+
"source": [
196+
"We wrap these quantities in a `SimplePrediction`:"
197+
]
198+
},
199+
{
200+
"cell_type": "code",
201+
"execution_count": 8,
202+
"id": "8bb2aac1",
203+
"metadata": {
204+
"pycharm": {
205+
"name": "#%%\n"
206+
}
207+
},
208+
"outputs": [],
209+
"source": [
210+
"from trustyai.model import simple_prediction\n",
211+
"\n",
212+
"prediction = simple_prediction(input_features=features, outputs=goals[0].outputs)"
213+
]
214+
},
215+
{
216+
"cell_type": "markdown",
217+
"id": "9bb631f9",
218+
"metadata": {
219+
"pycharm": {
220+
"name": "#%% md\n"
221+
}
222+
},
223+
"source": [
224+
"We can now instantiate the **explainer** itself.\n"
225+
]
226+
},
227+
{
228+
"cell_type": "code",
229+
"execution_count": 9,
230+
"id": "115fa89c",
231+
"metadata": {},
232+
"outputs": [
233+
{
234+
"name": "stderr",
235+
"output_type": "stream",
236+
"text": [
237+
"SLF4J: Failed to load class \"org.slf4j.impl.StaticLoggerBinder\".\n",
238+
"SLF4J: Defaulting to no-operation (NOP) logger implementation\n",
239+
"SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.\n"
240+
]
241+
}
242+
],
243+
"source": [
244+
"from trustyai.explainers import SHAPExplainer\n",
245+
"\n",
246+
"explainer = SHAPExplainer(background=background)"
247+
]
248+
},
249+
{
250+
"cell_type": "markdown",
251+
"id": "7cd8b2b4",
252+
"metadata": {
253+
"pycharm": {
254+
"name": "#%% md\n"
255+
}
256+
},
257+
"source": [
258+
"We generate the **explanation** as a _dict : decision --> saliency_.\n"
259+
]
260+
},
261+
{
262+
"cell_type": "code",
263+
"execution_count": 10,
264+
"id": "b34e26d7",
265+
"metadata": {
266+
"pycharm": {
267+
"name": "#%%\n"
268+
}
269+
},
270+
"outputs": [],
271+
"source": [
272+
"explanation = explainer.explain(prediction, model)"
273+
]
274+
},
275+
{
276+
"cell_type": "markdown",
277+
"id": "d32e4272",
278+
"metadata": {
279+
"pycharm": {
280+
"name": "#%% md\n"
281+
}
282+
},
283+
"source": [
284+
"We inspect the saliency scores assigned by LIME to each feature"
285+
]
286+
},
287+
{
288+
"cell_type": "code",
289+
"execution_count": 16,
290+
"id": "2f0721fe",
291+
"metadata": {},
292+
"outputs": [
293+
{
294+
"name": "stdout",
295+
"output_type": "stream",
296+
"text": [
297+
"Saliency{output=Output{value=true, type=boolean, score=-0.39704407681377063, name='inside'}, perFeatureImportance=[FeatureImportance{feature=Feature{name='x1', type=number, value=2.1516473114599046}, score=0.4, confidence= +/-0.39264863227014996}, FeatureImportance{feature=Feature{name='x2', type=number, value=0.8137674993709809}, score=0.35, confidence= +/-0.39264863227014996}, FeatureImportance{feature=Feature{name='x3', type=number, value=5.637541112355343}, score=0.15000000000000002, confidence= +/-0.5552890210036922}]}\n"
298+
]
299+
}
300+
],
301+
"source": [
302+
"for saliency in explanation.getSaliencies():\n",
303+
" print(saliency)"
304+
]
305+
}
306+
],
307+
"metadata": {
308+
"interpreter": {
309+
"hash": "a0b19a0e0769482a3dd54d9b1f74632fb70b79784820162adf8976b9cad4acbb"
310+
},
311+
"kernelspec": {
312+
"display_name": "trustyai-python",
313+
"language": "python",
314+
"name": "python3"
315+
},
316+
"language_info": {
317+
"codemirror_mode": {
318+
"name": "ipython",
319+
"version": 3
320+
},
321+
"file_extension": ".py",
322+
"mimetype": "text/x-python",
323+
"name": "python",
324+
"nbconvert_exporter": "python",
325+
"pygments_lexer": "ipython3",
326+
"version": "3.9.10"
327+
}
328+
},
329+
"nbformat": 4,
330+
"nbformat_minor": 5
331+
}

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
trustyai==0.1.0
1+
trustyai==0.1.1

0 commit comments

Comments
 (0)