|
10 | 10 | }, |
11 | 11 | { |
12 | 12 | "cell_type": "code", |
13 | | - "execution_count": 2, |
| 13 | + "execution_count": 1, |
14 | 14 | "id": "569777b3", |
15 | 15 | "metadata": {}, |
16 | 16 | "outputs": [], |
|
59 | 59 | }, |
60 | 60 | { |
61 | 61 | "cell_type": "code", |
62 | | - "execution_count": 3, |
| 62 | + "execution_count": 2, |
63 | 63 | "id": "22ba9951", |
64 | 64 | "metadata": {}, |
65 | 65 | "outputs": [], |
66 | 66 | "source": [ |
67 | 67 | "from trustyai.utils import TestUtils\n", |
68 | 68 | "\n", |
69 | 69 | "center = 500.0\n", |
70 | | - "epsilon = 10.0\n", |
| 70 | + "epsilon = 1.0\n", |
71 | 71 | "\n", |
72 | 72 | "model = TestUtils.getSumThresholdModel(center, epsilon)" |
73 | 73 | ] |
|
90 | 90 | }, |
91 | 91 | { |
92 | 92 | "cell_type": "code", |
93 | | - "execution_count": 4, |
| 93 | + "execution_count": 3, |
94 | 94 | "id": "5bcb0105", |
95 | 95 | "metadata": {}, |
96 | 96 | "outputs": [], |
|
110 | 110 | }, |
111 | 111 | { |
112 | 112 | "cell_type": "code", |
113 | | - "execution_count": 5, |
| 113 | + "execution_count": 4, |
114 | 114 | "id": "6aa524ae", |
115 | 115 | "metadata": {}, |
116 | 116 | "outputs": [], |
117 | 117 | "source": [ |
118 | 118 | "import random\n", |
119 | 119 | "from trustyai.model import FeatureFactory\n", |
120 | 120 | "\n", |
121 | | - "features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]" |
| 121 | + "features = [FeatureFactory.newNumericalFeature(f\"x{i+1}\", random.random()*10.0) for i in range(4)]" |
122 | 122 | ] |
123 | 123 | }, |
124 | 124 | { |
|
131 | 131 | }, |
132 | 132 | { |
133 | 133 | "cell_type": "code", |
134 | | - "execution_count": 6, |
| 134 | + "execution_count": 5, |
135 | 135 | "id": "f0f07043", |
136 | 136 | "metadata": {}, |
137 | 137 | "outputs": [ |
138 | 138 | { |
139 | 139 | "name": "stdout", |
140 | 140 | "output_type": "stream", |
141 | 141 | "text": [ |
142 | | - "Feature f-num1 has value 9.35846230523286\n", |
143 | | - "Feature f-num2 has value 7.791718241742139\n", |
144 | | - "Feature f-num3 has value 0.30365991828529393\n", |
145 | | - "Feature f-num4 has value 2.8165353533668114\n", |
| 142 | + "Feature x1 has value 6.953686434260184\n", |
| 143 | + "Feature x2 has value 0.6895992287088226\n", |
| 144 | + "Feature x3 has value 9.429677348990124\n", |
| 145 | + "Feature x4 has value 3.6853630123991774\n", |
146 | 146 | "\n", |
147 | | - "Features sum is 20.27037581862711\n" |
| 147 | + "Features sum is 20.75832602435831\n" |
148 | 148 | ] |
149 | 149 | } |
150 | 150 | ], |
|
167 | 167 | }, |
168 | 168 | { |
169 | 169 | "cell_type": "code", |
170 | | - "execution_count": 7, |
| 170 | + "execution_count": 6, |
171 | 171 | "id": "513d2e5a", |
172 | 172 | "metadata": {}, |
173 | 173 | "outputs": [], |
|
185 | 185 | }, |
186 | 186 | { |
187 | 187 | "cell_type": "code", |
188 | | - "execution_count": 8, |
| 188 | + "execution_count": 7, |
189 | 189 | "id": "30dcc15b", |
190 | 190 | "metadata": {}, |
191 | 191 | "outputs": [], |
|
205 | 205 | }, |
206 | 206 | { |
207 | 207 | "cell_type": "code", |
208 | | - "execution_count": 9, |
| 208 | + "execution_count": 8, |
209 | 209 | "id": "9cfe2a9d", |
210 | 210 | "metadata": {}, |
211 | 211 | "outputs": [], |
|
227 | 227 | }, |
228 | 228 | { |
229 | 229 | "cell_type": "code", |
230 | | - "execution_count": 10, |
| 230 | + "execution_count": 9, |
231 | 231 | "id": "bcd25df0", |
232 | 232 | "metadata": {}, |
233 | 233 | "outputs": [], |
|
255 | 255 | }, |
256 | 256 | { |
257 | 257 | "cell_type": "code", |
258 | | - "execution_count": 11, |
| 258 | + "execution_count": 10, |
259 | 259 | "id": "c2b76274", |
260 | 260 | "metadata": {}, |
261 | 261 | "outputs": [ |
|
289 | 289 | }, |
290 | 290 | { |
291 | 291 | "cell_type": "code", |
292 | | - "execution_count": 12, |
| 292 | + "execution_count": 11, |
293 | 293 | "id": "92356f76", |
294 | 294 | "metadata": {}, |
295 | 295 | "outputs": [], |
|
311 | 311 | }, |
312 | 312 | { |
313 | 313 | "cell_type": "code", |
314 | | - "execution_count": 13, |
| 314 | + "execution_count": 12, |
315 | 315 | "id": "19a001ac", |
316 | 316 | "metadata": {}, |
317 | 317 | "outputs": [], |
|
332 | 332 | }, |
333 | 333 | { |
334 | 334 | "cell_type": "code", |
335 | | - "execution_count": 14, |
| 335 | + "execution_count": 13, |
336 | 336 | "id": "e5783b3d", |
337 | 337 | "metadata": {}, |
338 | 338 | "outputs": [], |
|
350 | 350 | }, |
351 | 351 | { |
352 | 352 | "cell_type": "code", |
353 | | - "execution_count": 15, |
| 353 | + "execution_count": 14, |
354 | 354 | "id": "cc2ad21e", |
355 | 355 | "metadata": {}, |
356 | 356 | "outputs": [], |
|
368 | 368 | }, |
369 | 369 | { |
370 | 370 | "cell_type": "code", |
371 | | - "execution_count": 19, |
| 371 | + "execution_count": 15, |
372 | 372 | "id": "6f1e04c1", |
373 | 373 | "metadata": {}, |
374 | 374 | "outputs": [ |
375 | 375 | { |
376 | 376 | "name": "stdout", |
377 | 377 | "output_type": "stream", |
378 | 378 | "text": [ |
379 | | - "java.lang.DoubleFeature{value=9.35846230523286, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num1'}\n", |
380 | | - "java.lang.DoubleFeature{value=7.868907841140915, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num2'}\n", |
381 | | - "java.lang.DoubleFeature{value=0.30365991828529393, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num3'}\n", |
382 | | - "java.lang.DoubleFeature{value=481.78810412867466, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num4'}\n", |
| 379 | + "java.lang.DoubleFeature{value=485.4101987057185, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n", |
| 380 | + "java.lang.DoubleFeature{value=0.6895992287088226, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n", |
| 381 | + "java.lang.DoubleFeature{value=9.291426877845232, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n", |
| 382 | + "java.lang.DoubleFeature{value=3.6853630123991774, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n", |
383 | 383 | "\n", |
384 | | - "Feature sum is 499.31913419333375\n" |
| 384 | + "Feature sum is 499.0765878246718\n" |
385 | 385 | ] |
386 | 386 | } |
387 | 387 | ], |
|
394 | 394 | "print(f\"\\nFeature sum is {feature_sum}\")" |
395 | 395 | ] |
396 | 396 | }, |
| 397 | + { |
| 398 | + "cell_type": "markdown", |
| 399 | + "id": "17c7d05f", |
| 400 | + "metadata": {}, |
| 401 | + "source": [ |
| 402 | + "## Constrained features\n", |
| 403 | + "\n", |
| 404 | + "As we've seen, it is possible to constraint a specific feature $x_i$ by setting the _constraints_ list corresponding element to `True`.\n", |
| 405 | + "\n", |
| 406 | + "In this example, we know want to fix $x_1$ and $x_4$. That is, these features should have the same value in the counterfactual $\\mathbf{x'}$ as in the original $\\mathbf{x}$." |
| 407 | + ] |
| 408 | + }, |
| 409 | + { |
| 410 | + "cell_type": "code", |
| 411 | + "execution_count": 16, |
| 412 | + "id": "919c1b51", |
| 413 | + "metadata": {}, |
| 414 | + "outputs": [], |
| 415 | + "source": [ |
| 416 | + "constraints = [True, False, False, True] # x1, x2, x3 and x4" |
| 417 | + ] |
| 418 | + }, |
| 419 | + { |
| 420 | + "cell_type": "markdown", |
| 421 | + "id": "1e7071d7", |
| 422 | + "metadata": {}, |
| 423 | + "source": [ |
| 424 | + "We simply need to wrap the previous quantities with the new constraints:" |
| 425 | + ] |
| 426 | + }, |
| 427 | + { |
| 428 | + "cell_type": "code", |
| 429 | + "execution_count": 17, |
| 430 | + "id": "42e88d35", |
| 431 | + "metadata": {}, |
| 432 | + "outputs": [], |
| 433 | + "source": [ |
| 434 | + "prediction = CounterfactualPrediction(original, goals, domain, constraints, None, uuid.uuid4())" |
| 435 | + ] |
| 436 | + }, |
| 437 | + { |
| 438 | + "cell_type": "markdown", |
| 439 | + "id": "6321b0d9", |
| 440 | + "metadata": {}, |
| 441 | + "source": [ |
| 442 | + "And request a new counterfactual explanation" |
| 443 | + ] |
| 444 | + }, |
| 445 | + { |
| 446 | + "cell_type": "code", |
| 447 | + "execution_count": 18, |
| 448 | + "id": "197fc2ea", |
| 449 | + "metadata": {}, |
| 450 | + "outputs": [], |
| 451 | + "source": [ |
| 452 | + "explanation = explainer.explainAsync(prediction, model).get()" |
| 453 | + ] |
| 454 | + }, |
| 455 | + { |
| 456 | + "cell_type": "markdown", |
| 457 | + "id": "c5d5c0d1", |
| 458 | + "metadata": {}, |
| 459 | + "source": [ |
| 460 | + "We can see that $x_1$ and $x_4$ has the same value as the original and the model satisfies the conditions." |
| 461 | + ] |
| 462 | + }, |
| 463 | + { |
| 464 | + "cell_type": "code", |
| 465 | + "execution_count": 19, |
| 466 | + "id": "7e373cf6", |
| 467 | + "metadata": {}, |
| 468 | + "outputs": [ |
| 469 | + { |
| 470 | + "name": "stdout", |
| 471 | + "output_type": "stream", |
| 472 | + "text": [ |
| 473 | + "Original x1: 6.953686434260184\n", |
| 474 | + "Original x4: 3.6853630123991774\n", |
| 475 | + "\n", |
| 476 | + "java.lang.DoubleFeature{value=6.953686434260184, intRangeMinimum=6.953686434260184, intRangeMaximum=6.953686434260184, id='x1'}\n", |
| 477 | + "java.lang.DoubleFeature{value=0.7810382333337529, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n", |
| 478 | + "java.lang.DoubleFeature{value=488.03303690921916, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n", |
| 479 | + "java.lang.DoubleFeature{value=3.6853630123991774, intRangeMinimum=3.6853630123991774, intRangeMaximum=3.6853630123991774, id='x4'}\n" |
| 480 | + ] |
| 481 | + } |
| 482 | + ], |
| 483 | + "source": [ |
| 484 | + "print(f\"Original x1: {features[0].getValue()}\")\n", |
| 485 | + "print(f\"Original x4: {features[3].getValue()}\\n\")\n", |
| 486 | + "\n", |
| 487 | + "for entity in explanation.getEntities():\n", |
| 488 | + " print(entity)" |
| 489 | + ] |
| 490 | + }, |
397 | 491 | { |
398 | 492 | "cell_type": "code", |
399 | 493 | "execution_count": null, |
400 | | - "id": "b49d9c1c", |
| 494 | + "id": "ad71d609", |
401 | 495 | "metadata": {}, |
402 | 496 | "outputs": [], |
403 | 497 | "source": [] |
|
0 commit comments