Skip to content

Commit d8f3c46

Browse files
committed
Update counterfactual notebook with constrained features
1 parent bc99230 commit d8f3c46

File tree

1 file changed

+122
-28
lines changed

1 file changed

+122
-28
lines changed

notebooks/Counterfactuals.ipynb

Lines changed: 122 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 2,
13+
"execution_count": 1,
1414
"id": "569777b3",
1515
"metadata": {},
1616
"outputs": [],
@@ -59,15 +59,15 @@
5959
},
6060
{
6161
"cell_type": "code",
62-
"execution_count": 3,
62+
"execution_count": 2,
6363
"id": "22ba9951",
6464
"metadata": {},
6565
"outputs": [],
6666
"source": [
6767
"from trustyai.utils import TestUtils\n",
6868
"\n",
6969
"center = 500.0\n",
70-
"epsilon = 10.0\n",
70+
"epsilon = 1.0\n",
7171
"\n",
7272
"model = TestUtils.getSumThresholdModel(center, epsilon)"
7373
]
@@ -90,7 +90,7 @@
9090
},
9191
{
9292
"cell_type": "code",
93-
"execution_count": 4,
93+
"execution_count": 3,
9494
"id": "5bcb0105",
9595
"metadata": {},
9696
"outputs": [],
@@ -110,15 +110,15 @@
110110
},
111111
{
112112
"cell_type": "code",
113-
"execution_count": 5,
113+
"execution_count": 4,
114114
"id": "6aa524ae",
115115
"metadata": {},
116116
"outputs": [],
117117
"source": [
118118
"import random\n",
119119
"from trustyai.model import FeatureFactory\n",
120120
"\n",
121-
"features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]"
121+
"features = [FeatureFactory.newNumericalFeature(f\"x{i+1}\", random.random()*10.0) for i in range(4)]"
122122
]
123123
},
124124
{
@@ -131,20 +131,20 @@
131131
},
132132
{
133133
"cell_type": "code",
134-
"execution_count": 6,
134+
"execution_count": 5,
135135
"id": "f0f07043",
136136
"metadata": {},
137137
"outputs": [
138138
{
139139
"name": "stdout",
140140
"output_type": "stream",
141141
"text": [
142-
"Feature f-num1 has value 9.35846230523286\n",
143-
"Feature f-num2 has value 7.791718241742139\n",
144-
"Feature f-num3 has value 0.30365991828529393\n",
145-
"Feature f-num4 has value 2.8165353533668114\n",
142+
"Feature x1 has value 6.953686434260184\n",
143+
"Feature x2 has value 0.6895992287088226\n",
144+
"Feature x3 has value 9.429677348990124\n",
145+
"Feature x4 has value 3.6853630123991774\n",
146146
"\n",
147-
"Features sum is 20.27037581862711\n"
147+
"Features sum is 20.75832602435831\n"
148148
]
149149
}
150150
],
@@ -167,7 +167,7 @@
167167
},
168168
{
169169
"cell_type": "code",
170-
"execution_count": 7,
170+
"execution_count": 6,
171171
"id": "513d2e5a",
172172
"metadata": {},
173173
"outputs": [],
@@ -185,7 +185,7 @@
185185
},
186186
{
187187
"cell_type": "code",
188-
"execution_count": 8,
188+
"execution_count": 7,
189189
"id": "30dcc15b",
190190
"metadata": {},
191191
"outputs": [],
@@ -205,7 +205,7 @@
205205
},
206206
{
207207
"cell_type": "code",
208-
"execution_count": 9,
208+
"execution_count": 8,
209209
"id": "9cfe2a9d",
210210
"metadata": {},
211211
"outputs": [],
@@ -227,7 +227,7 @@
227227
},
228228
{
229229
"cell_type": "code",
230-
"execution_count": 10,
230+
"execution_count": 9,
231231
"id": "bcd25df0",
232232
"metadata": {},
233233
"outputs": [],
@@ -255,7 +255,7 @@
255255
},
256256
{
257257
"cell_type": "code",
258-
"execution_count": 11,
258+
"execution_count": 10,
259259
"id": "c2b76274",
260260
"metadata": {},
261261
"outputs": [
@@ -289,7 +289,7 @@
289289
},
290290
{
291291
"cell_type": "code",
292-
"execution_count": 12,
292+
"execution_count": 11,
293293
"id": "92356f76",
294294
"metadata": {},
295295
"outputs": [],
@@ -311,7 +311,7 @@
311311
},
312312
{
313313
"cell_type": "code",
314-
"execution_count": 13,
314+
"execution_count": 12,
315315
"id": "19a001ac",
316316
"metadata": {},
317317
"outputs": [],
@@ -332,7 +332,7 @@
332332
},
333333
{
334334
"cell_type": "code",
335-
"execution_count": 14,
335+
"execution_count": 13,
336336
"id": "e5783b3d",
337337
"metadata": {},
338338
"outputs": [],
@@ -350,7 +350,7 @@
350350
},
351351
{
352352
"cell_type": "code",
353-
"execution_count": 15,
353+
"execution_count": 14,
354354
"id": "cc2ad21e",
355355
"metadata": {},
356356
"outputs": [],
@@ -368,20 +368,20 @@
368368
},
369369
{
370370
"cell_type": "code",
371-
"execution_count": 19,
371+
"execution_count": 15,
372372
"id": "6f1e04c1",
373373
"metadata": {},
374374
"outputs": [
375375
{
376376
"name": "stdout",
377377
"output_type": "stream",
378378
"text": [
379-
"java.lang.DoubleFeature{value=9.35846230523286, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num1'}\n",
380-
"java.lang.DoubleFeature{value=7.868907841140915, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num2'}\n",
381-
"java.lang.DoubleFeature{value=0.30365991828529393, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num3'}\n",
382-
"java.lang.DoubleFeature{value=481.78810412867466, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num4'}\n",
379+
"java.lang.DoubleFeature{value=485.4101987057185, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n",
380+
"java.lang.DoubleFeature{value=0.6895992287088226, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
381+
"java.lang.DoubleFeature{value=9.291426877845232, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
382+
"java.lang.DoubleFeature{value=3.6853630123991774, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n",
383383
"\n",
384-
"Feature sum is 499.31913419333375\n"
384+
"Feature sum is 499.0765878246718\n"
385385
]
386386
}
387387
],
@@ -394,10 +394,104 @@
394394
"print(f\"\\nFeature sum is {feature_sum}\")"
395395
]
396396
},
397+
{
398+
"cell_type": "markdown",
399+
"id": "17c7d05f",
400+
"metadata": {},
401+
"source": [
402+
"## Constrained features\n",
403+
"\n",
404+
"As we've seen, it is possible to constraint a specific feature $x_i$ by setting the _constraints_ list corresponding element to `True`.\n",
405+
"\n",
406+
"In this example, we know want to fix $x_1$ and $x_4$. That is, these features should have the same value in the counterfactual $\\mathbf{x'}$ as in the original $\\mathbf{x}$."
407+
]
408+
},
409+
{
410+
"cell_type": "code",
411+
"execution_count": 16,
412+
"id": "919c1b51",
413+
"metadata": {},
414+
"outputs": [],
415+
"source": [
416+
"constraints = [True, False, False, True] # x1, x2, x3 and x4"
417+
]
418+
},
419+
{
420+
"cell_type": "markdown",
421+
"id": "1e7071d7",
422+
"metadata": {},
423+
"source": [
424+
"We simply need to wrap the previous quantities with the new constraints:"
425+
]
426+
},
427+
{
428+
"cell_type": "code",
429+
"execution_count": 17,
430+
"id": "42e88d35",
431+
"metadata": {},
432+
"outputs": [],
433+
"source": [
434+
"prediction = CounterfactualPrediction(original, goals, domain, constraints, None, uuid.uuid4())"
435+
]
436+
},
437+
{
438+
"cell_type": "markdown",
439+
"id": "6321b0d9",
440+
"metadata": {},
441+
"source": [
442+
"And request a new counterfactual explanation"
443+
]
444+
},
445+
{
446+
"cell_type": "code",
447+
"execution_count": 18,
448+
"id": "197fc2ea",
449+
"metadata": {},
450+
"outputs": [],
451+
"source": [
452+
"explanation = explainer.explainAsync(prediction, model).get()"
453+
]
454+
},
455+
{
456+
"cell_type": "markdown",
457+
"id": "c5d5c0d1",
458+
"metadata": {},
459+
"source": [
460+
"We can see that $x_1$ and $x_4$ has the same value as the original and the model satisfies the conditions."
461+
]
462+
},
463+
{
464+
"cell_type": "code",
465+
"execution_count": 19,
466+
"id": "7e373cf6",
467+
"metadata": {},
468+
"outputs": [
469+
{
470+
"name": "stdout",
471+
"output_type": "stream",
472+
"text": [
473+
"Original x1: 6.953686434260184\n",
474+
"Original x4: 3.6853630123991774\n",
475+
"\n",
476+
"java.lang.DoubleFeature{value=6.953686434260184, intRangeMinimum=6.953686434260184, intRangeMaximum=6.953686434260184, id='x1'}\n",
477+
"java.lang.DoubleFeature{value=0.7810382333337529, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
478+
"java.lang.DoubleFeature{value=488.03303690921916, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
479+
"java.lang.DoubleFeature{value=3.6853630123991774, intRangeMinimum=3.6853630123991774, intRangeMaximum=3.6853630123991774, id='x4'}\n"
480+
]
481+
}
482+
],
483+
"source": [
484+
"print(f\"Original x1: {features[0].getValue()}\")\n",
485+
"print(f\"Original x4: {features[3].getValue()}\\n\")\n",
486+
"\n",
487+
"for entity in explanation.getEntities():\n",
488+
" print(entity)"
489+
]
490+
},
397491
{
398492
"cell_type": "code",
399493
"execution_count": null,
400-
"id": "b49d9c1c",
494+
"id": "ad71d609",
401495
"metadata": {},
402496
"outputs": [],
403497
"source": []

0 commit comments

Comments
 (0)