|
32 | 32 | }, |
33 | 33 | { |
34 | 34 | "cell_type": "markdown", |
35 | | - "id": "512462ee", |
| 35 | + "id": "12645d02", |
36 | 36 | "metadata": {}, |
37 | 37 | "source": [ |
38 | 38 | "## Simple example\n", |
|
60 | 60 | { |
61 | 61 | "cell_type": "code", |
62 | 62 | "execution_count": 4, |
63 | | - "id": "e4f89877", |
| 63 | + "id": "22ba9951", |
64 | 64 | "metadata": {}, |
65 | 65 | "outputs": [], |
66 | 66 | "source": [ |
|
74 | 74 | }, |
75 | 75 | { |
76 | 76 | "cell_type": "markdown", |
77 | | - "id": "f0bb1cc2", |
| 77 | + "id": "b80d0d68", |
78 | 78 | "metadata": {}, |
79 | 79 | "source": [ |
80 | 80 | "Next we need to define a **goal**.\n", |
|
100 | 100 | "goal = [Output(\"inside\", Type.BOOLEAN, Value(True), 0.0)]" |
101 | 101 | ] |
102 | 102 | }, |
| 103 | + { |
| 104 | + "cell_type": "markdown", |
| 105 | + "id": "4e7fb934", |
| 106 | + "metadata": {}, |
| 107 | + "source": [ |
| 108 | + "We will now define our initial features, $\\mathbf{x}$. Each feature can be instantiated by using `FeatureFactory` and in this case we want to use numerical features, so we'll use `FeatureFactory.newNumericalFeature`." |
| 109 | + ] |
| 110 | + }, |
103 | 111 | { |
104 | 112 | "cell_type": "code", |
105 | | - "execution_count": null, |
| 113 | + "execution_count": 11, |
106 | 114 | "id": "6aa524ae", |
107 | 115 | "metadata": {}, |
108 | 116 | "outputs": [], |
109 | 117 | "source": [ |
110 | 118 | "import random\n", |
111 | 119 | "from trustyai.model import FeatureFactory\n", |
112 | 120 | "\n", |
113 | | - "features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]\n", |
114 | | - "\n", |
| 121 | + "features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]" |
| 122 | + ] |
| 123 | + }, |
| 124 | + { |
| 125 | + "cell_type": "markdown", |
| 126 | + "id": "db9c90ff", |
| 127 | + "metadata": {}, |
| 128 | + "source": [ |
| 129 | + "As we can see, the sum of of the features will not be within $\\epsilon$ (1.0) of $\\mathbf{C}$ (500.0). As such the model prediction will be `false`:" |
| 130 | + ] |
| 131 | + }, |
| 132 | + { |
| 133 | + "cell_type": "code", |
| 134 | + "execution_count": 12, |
| 135 | + "id": "f0f07043", |
| 136 | + "metadata": {}, |
| 137 | + "outputs": [ |
| 138 | + { |
| 139 | + "name": "stdout", |
| 140 | + "output_type": "stream", |
| 141 | + "text": [ |
| 142 | + "Feature f-num1 has value 9.344140417436046\n", |
| 143 | + "Feature f-num2 has value 2.101222990524685\n", |
| 144 | + "Feature f-num3 has value 5.759573701749472\n", |
| 145 | + "Feature f-num4 has value 0.8173260627331469\n", |
| 146 | + "\n", |
| 147 | + "Features sum is 18.02226317244335\n" |
| 148 | + ] |
| 149 | + } |
| 150 | + ], |
| 151 | + "source": [ |
| 152 | + "feature_sum = 0.0\n", |
115 | 153 | "for f in features:\n", |
116 | | - " print(f\"Feature {f.getName()} has value {f.getValue()}\")" |
| 154 | + " value = f.getValue().asNumber()\n", |
| 155 | + " print(f\"Feature {f.getName()} has value {value}\")\n", |
| 156 | + " feature_sum += value\n", |
| 157 | + "print(f\"\\nFeatures sum is {feature_sum}\")" |
| 158 | + ] |
| 159 | + }, |
| 160 | + { |
| 161 | + "cell_type": "markdown", |
| 162 | + "id": "4773e71a", |
| 163 | + "metadata": {}, |
| 164 | + "source": [ |
| 165 | + "The next step is to specify the **constraints** of the features, i.e. which features can be changed and which should be fixed. Since we want all features to be able to change, we specify `False` for all of them:" |
117 | 166 | ] |
118 | 167 | }, |
119 | 168 | { |
120 | 169 | "cell_type": "code", |
121 | | - "execution_count": null, |
| 170 | + "execution_count": 20, |
122 | 171 | "id": "513d2e5a", |
123 | 172 | "metadata": {}, |
124 | 173 | "outputs": [], |
125 | 174 | "source": [ |
126 | 175 | "constraints = [False] * 4" |
127 | 176 | ] |
128 | 177 | }, |
| 178 | + { |
| 179 | + "cell_type": "markdown", |
| 180 | + "id": "1894c1d7", |
| 181 | + "metadata": {}, |
| 182 | + "source": [ |
| 183 | + "Finally, we also specify which are the **bounds** for the counterfactual search. Typically this can be set either using domain-specific knowledge or taken from the data. In this case we simply specify an arbitrary (sensible) value, e.g. all the features can vary between `0` and `1000`." |
| 184 | + ] |
| 185 | + }, |
129 | 186 | { |
130 | 187 | "cell_type": "code", |
131 | | - "execution_count": null, |
| 188 | + "execution_count": 13, |
132 | 189 | "id": "30dcc15b", |
133 | 190 | "metadata": {}, |
134 | 191 | "outputs": [], |
|
139 | 196 | ] |
140 | 197 | }, |
141 | 198 | { |
142 | | - "cell_type": "code", |
143 | | - "execution_count": null, |
144 | | - "id": "5047e075", |
| 199 | + "cell_type": "markdown", |
| 200 | + "id": "be0cdfe3", |
145 | 201 | "metadata": {}, |
146 | | - "outputs": [], |
147 | 202 | "source": [ |
148 | | - "from trustyai.model import DataDomain\n", |
149 | | - "\n", |
150 | | - "data_domain = DataDomain(feature_boundaries)" |
| 203 | + "In order to use the boundaries in the explainer we need to wrap all of them in a `DataDomain` class:" |
151 | 204 | ] |
152 | 205 | }, |
153 | 206 | { |
154 | 207 | "cell_type": "code", |
155 | | - "execution_count": null, |
156 | | - "id": "e1b0da83", |
| 208 | + "execution_count": 14, |
| 209 | + "id": "9cfe2a9d", |
157 | 210 | "metadata": {}, |
158 | 211 | "outputs": [], |
159 | 212 | "source": [ |
160 | | - "center = 500.0\n", |
161 | | - "epsilon = 10.0" |
| 213 | + "from trustyai.model import DataDomain\n", |
| 214 | + "\n", |
| 215 | + "data_domain = DataDomain(feature_boundaries)" |
162 | 216 | ] |
163 | 217 | }, |
164 | 218 | { |
165 | | - "cell_type": "code", |
166 | | - "execution_count": null, |
167 | | - "id": "510b3b16", |
| 219 | + "cell_type": "markdown", |
| 220 | + "id": "e47d348e", |
168 | 221 | "metadata": {}, |
169 | | - "outputs": [], |
170 | 222 | "source": [ |
171 | | - "from trustyai.utils import TestUtils\n", |
| 223 | + "We can now instantiate the **explainer** itself.\n", |
172 | 224 | "\n", |
173 | | - "model = TestUtils.getSumThresholdModel(center, epsilon)" |
| 225 | + "To do so, we will to configure the termination criteria. For this example we will specify that the counterfactual search should only execute a maximum of 10,000 iterations before stopping and returning whatever the best result is so far." |
174 | 226 | ] |
175 | 227 | }, |
176 | 228 | { |
177 | 229 | "cell_type": "code", |
178 | | - "execution_count": null, |
| 230 | + "execution_count": 15, |
179 | 231 | "id": "bcd25df0", |
180 | 232 | "metadata": {}, |
181 | 233 | "outputs": [], |
|
193 | 245 | " )" |
194 | 246 | ] |
195 | 247 | }, |
| 248 | + { |
| 249 | + "cell_type": "markdown", |
| 250 | + "id": "790e868f", |
| 251 | + "metadata": {}, |
| 252 | + "source": [ |
| 253 | + "We can can now instantiate the explainer itself using `CounterfactualExplainer` and our `solver_config` configuration." |
| 254 | + ] |
| 255 | + }, |
196 | 256 | { |
197 | 257 | "cell_type": "code", |
198 | | - "execution_count": null, |
| 258 | + "execution_count": 16, |
199 | 259 | "id": "c2b76274", |
200 | 260 | "metadata": {}, |
201 | | - "outputs": [], |
| 261 | + "outputs": [ |
| 262 | + { |
| 263 | + "name": "stderr", |
| 264 | + "output_type": "stream", |
| 265 | + "text": [ |
| 266 | + "SLF4J: Failed to load class \"org.slf4j.impl.StaticLoggerBinder\".\n", |
| 267 | + "SLF4J: Defaulting to no-operation (NOP) logger implementation\n", |
| 268 | + "SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.\n" |
| 269 | + ] |
| 270 | + } |
| 271 | + ], |
202 | 272 | "source": [ |
203 | 273 | "from org.kie.kogito.explainability.local.counterfactual import CounterfactualExplainer\n", |
204 | 274 | "\n", |
205 | 275 | "explainer = CounterfactualExplainer.builder().withSolverConfig(solver_config).build()" |
206 | 276 | ] |
207 | 277 | }, |
| 278 | + { |
| 279 | + "cell_type": "markdown", |
| 280 | + "id": "292c136c", |
| 281 | + "metadata": {}, |
| 282 | + "source": [ |
| 283 | + "We will now express the counterfactual problem as defined above.\n", |
| 284 | + "\n", |
| 285 | + "- `original` represents our $\\mathbf{x}$ which know gives a prediction of `False`\n", |
| 286 | + "- `goals` represents our $\\mathbf{y'}$, that is our desired prediction (`True`)\n", |
| 287 | + "- `domain` repreents the boundaries for the counterfactual search" |
| 288 | + ] |
| 289 | + }, |
208 | 290 | { |
209 | 291 | "cell_type": "code", |
210 | | - "execution_count": null, |
211 | | - "id": "4cff79cd", |
| 292 | + "execution_count": 17, |
| 293 | + "id": "92356f76", |
212 | 294 | "metadata": {}, |
213 | 295 | "outputs": [], |
214 | 296 | "source": [ |
215 | 297 | "from trustyai.model import PredictionFeatureDomain, PredictionInput, PredictionOutput\n", |
216 | 298 | "\n", |
217 | | - "inputs = PredictionInput(features)\n", |
218 | | - "outputs = PredictionOutput(goal)\n", |
| 299 | + "original = PredictionInput(features)\n", |
| 300 | + "goals = PredictionOutput(goal)\n", |
219 | 301 | "domain = PredictionFeatureDomain(data_domain.getFeatureDomains())" |
220 | 302 | ] |
221 | 303 | }, |
| 304 | + { |
| 305 | + "cell_type": "markdown", |
| 306 | + "id": "00c09d95", |
| 307 | + "metadata": {}, |
| 308 | + "source": [ |
| 309 | + "We wrap these quantities in a `CounterfactualPrediction` (the UUID is simply to label the search instance):" |
| 310 | + ] |
| 311 | + }, |
222 | 312 | { |
223 | 313 | "cell_type": "code", |
224 | | - "execution_count": null, |
225 | | - "id": "98057ebd", |
| 314 | + "execution_count": 21, |
| 315 | + "id": "19a001ac", |
226 | 316 | "metadata": {}, |
227 | 317 | "outputs": [], |
228 | 318 | "source": [ |
229 | 319 | "import uuid\n", |
230 | 320 | "from trustyai.model import CounterfactualPrediction\n", |
231 | 321 | "\n", |
232 | | - "prediction = CounterfactualPrediction(inputs, outputs, domain, constraints, None, uuid.uuid4())" |
| 322 | + "prediction = CounterfactualPrediction(original, goals, domain, constraints, None, uuid.uuid4())" |
| 323 | + ] |
| 324 | + }, |
| 325 | + { |
| 326 | + "cell_type": "markdown", |
| 327 | + "id": "6d593f4f", |
| 328 | + "metadata": {}, |
| 329 | + "source": [ |
| 330 | + "We now request the counterfactual $\\mathbf{x'}$ which is closest to $\\mathbf{x}$ and which satisfies $f(\\mathbf{x'}, \\epsilon, \\mathbf{C})=\\mathbf{y'}$:" |
233 | 331 | ] |
234 | 332 | }, |
235 | 333 | { |
236 | 334 | "cell_type": "code", |
237 | | - "execution_count": null, |
238 | | - "id": "910a250f", |
| 335 | + "execution_count": 22, |
| 336 | + "id": "e5783b3d", |
239 | 337 | "metadata": {}, |
240 | 338 | "outputs": [], |
241 | 339 | "source": [ |
242 | 340 | "explanation_async = explainer.explainAsync(prediction, model)" |
243 | 341 | ] |
244 | 342 | }, |
| 343 | + { |
| 344 | + "cell_type": "markdown", |
| 345 | + "id": "b2af6cb4", |
| 346 | + "metadata": {}, |
| 347 | + "source": [ |
| 348 | + "The counterfactual explainer API operates in a asynchronous way, so we need to `.get()` the result:" |
| 349 | + ] |
| 350 | + }, |
245 | 351 | { |
246 | 352 | "cell_type": "code", |
247 | | - "execution_count": null, |
248 | | - "id": "38774822", |
| 353 | + "execution_count": 23, |
| 354 | + "id": "cc2ad21e", |
249 | 355 | "metadata": {}, |
250 | 356 | "outputs": [], |
251 | 357 | "source": [ |
252 | 358 | "explanation = explanation_async.get()" |
253 | 359 | ] |
254 | 360 | }, |
| 361 | + { |
| 362 | + "cell_type": "markdown", |
| 363 | + "id": "7fcfb591", |
| 364 | + "metadata": {}, |
| 365 | + "source": [ |
| 366 | + "We can see that the counterfactual $\\mathbf{x'}$" |
| 367 | + ] |
| 368 | + }, |
255 | 369 | { |
256 | 370 | "cell_type": "code", |
257 | | - "execution_count": null, |
258 | | - "id": "7cb95b8c", |
| 371 | + "execution_count": 25, |
| 372 | + "id": "6f1e04c1", |
259 | 373 | "metadata": {}, |
260 | | - "outputs": [], |
| 374 | + "outputs": [ |
| 375 | + { |
| 376 | + "name": "stdout", |
| 377 | + "output_type": "stream", |
| 378 | + "text": [ |
| 379 | + "java.lang.DoubleFeature{value=490.4373902874999, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num1'}\n", |
| 380 | + "java.lang.DoubleFeature{value=2.420079314517709, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num2'}\n", |
| 381 | + "java.lang.DoubleFeature{value=5.759573701749472, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num3'}\n", |
| 382 | + "java.lang.DoubleFeature{value=0.8173260627331469, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='f-num4'}\n" |
| 383 | + ] |
| 384 | + } |
| 385 | + ], |
261 | 386 | "source": [ |
| 387 | + "feature_sum = 0.0\n", |
262 | 388 | "for entity in explanation.getEntities():\n", |
263 | | - " print(entity)" |
| 389 | + " print(entity)\n", |
| 390 | + " feature_sum += entity.getValue().asNumber()\n", |
| 391 | + " \n", |
| 392 | + "print(f\"\\nFeature sum is {fe}\")" |
264 | 393 | ] |
265 | 394 | }, |
266 | 395 | { |
267 | 396 | "cell_type": "code", |
268 | 397 | "execution_count": null, |
269 | | - "id": "7a8587d1", |
| 398 | + "id": "b49d9c1c", |
270 | 399 | "metadata": {}, |
271 | 400 | "outputs": [], |
272 | 401 | "source": [] |
|
0 commit comments