Commit b0ddc5d

Update notebook
1 parent 28e1897 commit b0ddc5d

1 file changed

+174
-45
lines changed


notebooks/Counterfactuals.ipynb

Lines changed: 174 additions & 45 deletions
@@ -32,7 +32,7 @@
3232
},
3333
{
3434
"cell_type": "markdown",
35-
"id": "512462ee",
35+
"id": "12645d02",
3636
"metadata": {},
3737
"source": [
3838
"## Simple example\n",
@@ -60,7 +60,7 @@
6060
{
6161
"cell_type": "code",
6262
"execution_count": 4,
63-
"id": "e4f89877",
63+
"id": "22ba9951",
6464
"metadata": {},
6565
"outputs": [],
6666
"source": [
@@ -74,7 +74,7 @@
7474
},
7575
{
7676
"cell_type": "markdown",
77-
"id": "f0bb1cc2",
77+
"id": "b80d0d68",
7878
"metadata": {},
7979
"source": [
8080
"Next we need to define a **goal**.\n",
@@ -100,35 +100,92 @@
100100
"goal = [Output(\"inside\", Type.BOOLEAN, Value(True), 0.0)]"
101101
]
102102
},
103+
{
104+
"cell_type": "markdown",
105+
"id": "4e7fb934",
106+
"metadata": {},
107+
"source": [
108+
"We will now define our initial features, $\\mathbf{x}$. Each feature can be instantiated using `FeatureFactory`. In this case we want numerical features, so we'll use `FeatureFactory.newNumericalFeature`."
109+
]
110+
},
103111
{
104112
"cell_type": "code",
105-
"execution_count": null,
113+
"execution_count": 11,
106114
"id": "6aa524ae",
107115
"metadata": {},
108116
"outputs": [],
109117
"source": [
110118
"import random\n",
111119
"from trustyai.model import FeatureFactory\n",
112120
"\n",
113-
"features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]\n",
114-
"\n",
121+
"features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]"
122+
]
123+
},
124+
{
125+
"cell_type": "markdown",
126+
"id": "db9c90ff",
127+
"metadata": {},
128+
"source": [
129+
"As we can see, the sum of the features will not be within $\\epsilon$ (1.0) of $\\mathbf{C}$ (500.0). As such, the model prediction will be `false`:"
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": 12,
135+
"id": "f0f07043",
136+
"metadata": {},
137+
"outputs": [
138+
{
139+
"name": "stdout",
140+
"output_type": "stream",
141+
"text": [
142+
"Feature f-num1 has value 9.344140417436046\n",
143+
"Feature f-num2 has value 2.101222990524685\n",
144+
"Feature f-num3 has value 5.759573701749472\n",
145+
"Feature f-num4 has value 0.8173260627331469\n",
146+
"\n",
147+
"Features sum is 18.02226317244335\n"
148+
]
149+
}
150+
],
151+
"source": [
152+
"feature_sum = 0.0\n",
115153
"for f in features:\n",
116-
" print(f\"Feature {f.getName()} has value {f.getValue()}\")"
154+
" value = f.getValue().asNumber()\n",
155+
" print(f\"Feature {f.getName()} has value {value}\")\n",
156+
" feature_sum += value\n",
157+
"print(f\"\\nFeatures sum is {feature_sum}\")"
158+
]
159+
},
160+
{
161+
"cell_type": "markdown",
162+
"id": "4773e71a",
163+
"metadata": {},
164+
"source": [
165+
"The next step is to specify the **constraints** of the features, i.e. which features can be changed and which should be fixed. Since we want all features to be able to change, we specify `False` for all of them:"
117166
]
118167
},
119168
{
120169
"cell_type": "code",
121-
"execution_count": null,
170+
"execution_count": 20,
122171
"id": "513d2e5a",
123172
"metadata": {},
124173
"outputs": [],
125174
"source": [
126175
"constraints = [False] * 4"
127176
]
128177
},
178+
{
179+
"cell_type": "markdown",
180+
"id": "1894c1d7",
181+
"metadata": {},
182+
"source": [
183+
"Finally, we specify the **bounds** for the counterfactual search. Typically these can be set using domain-specific knowledge or taken from the data. In this case we simply specify an arbitrary (but sensible) range: all the features can vary between `0` and `1000`."
184+
]
185+
},
129186
{
130187
"cell_type": "code",
131-
"execution_count": null,
188+
"execution_count": 13,
132189
"id": "30dcc15b",
133190
"metadata": {},
134191
"outputs": [],
@@ -139,43 +196,38 @@
139196
]
140197
},
141198
{
142-
"cell_type": "code",
143-
"execution_count": null,
144-
"id": "5047e075",
199+
"cell_type": "markdown",
200+
"id": "be0cdfe3",
145201
"metadata": {},
146-
"outputs": [],
147202
"source": [
148-
"from trustyai.model import DataDomain\n",
149-
"\n",
150-
"data_domain = DataDomain(feature_boundaries)"
203+
"In order to use the boundaries in the explainer we need to wrap all of them in a `DataDomain` class:"
151204
]
152205
},
153206
{
154207
"cell_type": "code",
155-
"execution_count": null,
156-
"id": "e1b0da83",
208+
"execution_count": 14,
209+
"id": "9cfe2a9d",
157210
"metadata": {},
158211
"outputs": [],
159212
"source": [
160-
"center = 500.0\n",
161-
"epsilon = 10.0"
213+
"from trustyai.model import DataDomain\n",
214+
"\n",
215+
"data_domain = DataDomain(feature_boundaries)"
162216
]
163217
},
164218
{
165-
"cell_type": "code",
166-
"execution_count": null,
167-
"id": "510b3b16",
219+
"cell_type": "markdown",
220+
"id": "e47d348e",
168221
"metadata": {},
169-
"outputs": [],
170222
"source": [
171-
"from trustyai.utils import TestUtils\n",
223+
"We can now instantiate the **explainer** itself.\n",
172224
"\n",
173-
"model = TestUtils.getSumThresholdModel(center, epsilon)"
225+
"To do so, we will need to configure the termination criteria. For this example we will specify that the counterfactual search should execute a maximum of 10,000 iterations before stopping and returning the best result found so far."
174226
]
175227
},
176228
{
177229
"cell_type": "code",
178-
"execution_count": null,
230+
"execution_count": 15,
179231
"id": "bcd25df0",
180232
"metadata": {},
181233
"outputs": [],
@@ -193,80 +245,157 @@
193245
" )"
194246
]
195247
},
248+
{
249+
"cell_type": "markdown",
250+
"id": "790e868f",
251+
"metadata": {},
252+
"source": [
253+
"We can now instantiate the explainer itself using `CounterfactualExplainer` and our `solver_config` configuration."
254+
]
255+
},
196256
{
197257
"cell_type": "code",
198-
"execution_count": null,
258+
"execution_count": 16,
199259
"id": "c2b76274",
200260
"metadata": {},
201-
"outputs": [],
261+
"outputs": [
262+
{
263+
"name": "stderr",
264+
"output_type": "stream",
265+
"text": [
266+
"SLF4J: Failed to load class \"org.slf4j.impl.StaticLoggerBinder\".\n",
267+
"SLF4J: Defaulting to no-operation (NOP) logger implementation\n",
268+
"SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.\n"
269+
]
270+
}
271+
],
202272
"source": [
203273
"from org.kie.kogito.explainability.local.counterfactual import CounterfactualExplainer\n",
204274
"\n",
205275
"explainer = CounterfactualExplainer.builder().withSolverConfig(solver_config).build()"
206276
]
207277
},
278+
{
279+
"cell_type": "markdown",
280+
"id": "292c136c",
281+
"metadata": {},
282+
"source": [
283+
"We will now express the counterfactual problem as defined above.\n",
284+
"\n",
285+
"- `original` represents our $\\mathbf{x}$, which we know gives a prediction of `False`\n",
286+
"- `goals` represents our $\\mathbf{y'}$, that is our desired prediction (`True`)\n",
287+
"- `domain` represents the boundaries for the counterfactual search"
288+
]
289+
},
208290
{
209291
"cell_type": "code",
210-
"execution_count": null,
211-
"id": "4cff79cd",
292+
"execution_count": 17,
293+
"id": "92356f76",
212294
"metadata": {},
213295
"outputs": [],
214296
"source": [
215297
"from trustyai.model import PredictionFeatureDomain, PredictionInput, PredictionOutput\n",
216298
"\n",
217-
"inputs = PredictionInput(features)\n",
218-
"outputs = PredictionOutput(goal)\n",
299+
"original = PredictionInput(features)\n",
300+
"goals = PredictionOutput(goal)\n",
219301
"domain = PredictionFeatureDomain(data_domain.getFeatureDomains())"
220302
]
221303
},
304+
{
305+
"cell_type": "markdown",
306+
"id": "00c09d95",
307+
"metadata": {},
308+
"source": [
309+
"We wrap these quantities in a `CounterfactualPrediction` (the UUID is simply to label the search instance):"
310+
]
311+
},
222312
{
223313
"cell_type": "code",
224-
"execution_count": null,
225-
"id": "98057ebd",
314+
"execution_count": 21,
315+
"id": "19a001ac",
226316
"metadata": {},
227317
"outputs": [],
228318
"source": [
229319
"import uuid\n",
230320
"from trustyai.model import CounterfactualPrediction\n",
231321
"\n",
232-
"prediction = CounterfactualPrediction(inputs, outputs, domain, constraints, None, uuid.uuid4())"
322+
"prediction = CounterfactualPrediction(original, goals, domain, constraints, None, uuid.uuid4())"
323+
]
324+
},
325+
{
326+
"cell_type": "markdown",
327+
"id": "6d593f4f",
328+
"metadata": {},
329+
"source": [
330+
"We now request the counterfactual $\\mathbf{x'}$ which is closest to $\\mathbf{x}$ and which satisfies $f(\\mathbf{x'}, \\epsilon, \\mathbf{C})=\\mathbf{y'}$:"
233331
]
234332
},
235333
{
236334
"cell_type": "code",
237-
"execution_count": null,
238-
"id": "910a250f",
335+
"execution_count": 22,
336+
"id": "e5783b3d",
239337
"metadata": {},
240338
"outputs": [],
241339
"source": [
242340
"explanation_async = explainer.explainAsync(prediction, model)"
243341
]
244342
},
343+
{
344+
"cell_type": "markdown",
345+
"id": "b2af6cb4",
346+
"metadata": {},
347+
"source": [
348+
"The counterfactual explainer API operates in an asynchronous way, so we need to `.get()` the result:"
349+
]
350+
},
245351
{
246352
"cell_type": "code",
247-
"execution_count": null,
248-
"id": "38774822",
353+
"execution_count": 23,
354+
"id": "cc2ad21e",
249355
"metadata": {},
250356
"outputs": [],
251357
"source": [
252358
"explanation = explanation_async.get()"
253359
]
254360
},
361+
{
362+
"cell_type": "markdown",
363+
"id": "7fcfb591",
364+
"metadata": {},
365+
"source": [
366+
"We can see that the counterfactual $\\mathbf{x'}$"
367+
]
368+
},
255369
{
256370
"cell_type": "code",
257-
"execution_count": null,
258-
"id": "7cb95b8c",
371+
"execution_count": 25,
372+
"id": "6f1e04c1",
259373
"metadata": {},
260-
"outputs": [],
374+
"outputs": [
375+
{
376+
"name": "stdout",
377+
"output_type": "stream",
378+
"text": [
379+
"java.lang.DoubleFeature{value=490.4373902874999, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num1'}\n",
380+
"java.lang.DoubleFeature{value=2.420079314517709, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num2'}\n",
381+
"java.lang.DoubleFeature{value=5.759573701749472, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num3'}\n",
382+
"java.lang.DoubleFeature{value=0.8173260627331469, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num4'}\n"
383+
]
384+
}
385+
],
261386
"source": [
387+
"feature_sum = 0.0\n",
262388
"for entity in explanation.getEntities():\n",
263-
" print(entity)"
389+
" print(entity)\n",
390+
" feature_sum += entity.getValue().asNumber()\n",
391+
" \n",
392+
"print(f\"\\nFeature sum is {feature_sum}\")"
264393
]
265394
},
266395
{
267396
"cell_type": "code",
268397
"execution_count": null,
269-
"id": "7a8587d1",
398+
"id": "b49d9c1c",
270399
"metadata": {},
271400
"outputs": [],
272401
"source": []
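
The notebook's markdown cells describe a model that returns `True` when the feature sum lies within $\epsilon$ of $\mathbf{C}$. As a rough, library-free sanity check of the outputs recorded in this diff, the sketch below re-implements that rule in plain Python. The function `inside_threshold` is a hypothetical stand-in, not the TrustyAI model; the inclusive comparison (`<=`) is an assumption, and the center (500.0) and epsilon (1.0) are taken from the markdown cell above rather than from the model construction code.

```python
# Illustrative stand-in for the sum-threshold rule described in the notebook.
# This is NOT the TrustyAI implementation; the inclusive bound is an assumption.
def inside_threshold(values, center=500.0, epsilon=1.0):
    """Return True when the sum of `values` lies within `epsilon` of `center`."""
    return abs(sum(values) - center) <= epsilon


# Feature values copied from the notebook outputs shown in this diff.
original = [
    9.344140417436046,
    2.101222990524685,
    5.759573701749472,
    0.8173260627331469,
]
counterfactual = [
    490.4373902874999,
    2.420079314517709,
    5.759573701749472,
    0.8173260627331469,
]

# Report the sum and the rule's verdict for each feature vector.
for label, values in (("original", original), ("counterfactual", counterfactual)):
    print(f"{label}: sum={sum(values):.3f}, inside={inside_threshold(values)}")
```

Under these assumptions the original features fall well outside the band around 500.0, while the counterfactual features fall inside it, which matches the `False` and `True` predictions described in the notebook text.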
