Skip to content

Commit 28e1897

Browse files
committed
Merge branch 'notebooks'
2 parents 9d34f02 + 09e684d commit 28e1897

File tree

5 files changed

+319
-1
lines changed

5 files changed

+319
-1
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@ You can see the `sumSkipModel` in the [LIME tests](./tests/test_limeexplainer.py
6262
## Examples
6363

6464
You can look at the [tests](./tests) for working examples.
65+
66+
There are also Jupyter notebooks available:
67+
68+
- [Counterfactual explanations](notebooks/Counterfactuals.ipynb)

notebooks/Counterfactuals.ipynb

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "8f0f2186",
6+
"metadata": {},
7+
"source": [
8+
"# Counterfactual explanations"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 1,
14+
"id": "569777b3",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import trustyai\n",
19+
"\n",
20+
"trustyai.init(\n",
21+
" path=[\n",
22+
" \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/*\",\n",
23+
" \"../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar\",\n",
24+
" \"../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar\",\n",
25+
" \"../dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar\",\n",
26+
" \"../dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar\",\n",
27+
" \"../dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar\",\n",
28+
" \"../dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar\",\n",
29+
" ]\n",
30+
")"
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"id": "512462ee",
36+
"metadata": {},
37+
"source": [
38+
"## Simple example\n",
39+
"\n",
40+
"We start by defining our black-box model, typically represented by\n",
41+
"\n",
42+
"$$\n",
43+
"f(\\mathbf{x}) = \\mathbf{y}\n",
44+
"$$\n",
45+
"\n",
46+
"Where $\\mathbf{x}=\\{x_1, x_2, \\dots,x_m\\}$ and $\\mathbf{y}=\\{y_1, y_2, \\dots,y_n\\}$.\n",
47+
"\n",
48+
"Our example toy model, in this case, takes an all-numerical input $\\mathbf{x}$ and return a $\\mathbf{y}$ of either `true` or `false` if the sum of the $\\mathbf{x}$ components is within a threshold $\\epsilon$ of a point $\\mathbf{C}$, that is:\n",
49+
"\n",
50+
"$$\n",
51+
"f(\\mathbf{x}, \\epsilon, \\mathbf{C})=\\begin{cases}\n",
52+
"\\text{true},\\qquad \\text{if}\\ \\mathbf{C}-\\epsilon<\\sum_{i=1}^m x_i <\\mathbf{C}+\\epsilon \\\\\n",
53+
"\\text{false},\\qquad \\text{otherwise}\n",
54+
"\\end{cases}\n",
55+
"$$\n",
56+
"\n",
57+
"This model is provided in the `TestUtils` module. We instantiate with a $\\mathbf{C}=500$ and $\\epsilon=1.0$."
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": 4,
63+
"id": "e4f89877",
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"from trustyai.utils import TestUtils\n",
68+
"\n",
69+
"center = 500.0\n",
70+
"epsilon = 10.0\n",
71+
"\n",
72+
"model = TestUtils.getSumThresholdModel(center, epsilon)"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"id": "f0bb1cc2",
78+
"metadata": {},
79+
"source": [
80+
"Next we need to define a **goal**.\n",
81+
"If our model is $f(\\mathbf{x'})=\\mathbf{y'}$ we are then defining our $\\mathbf{y'}$ and the counterfactual result will be the $\\mathbf{x'}$ which satisfies $f(\\mathbf{x'})=\\mathbf{y'}$.\n",
82+
"\n",
83+
"We will define our goal as `true`, that is, the sum is withing the vicinity of a (to be defined) point $\\mathbf{C}$. The goal is a list of `Output` which take the following parameters\n",
84+
"\n",
85+
"- The feature name\n",
86+
"- The feature type\n",
87+
"- The feature value (wrapped in `Value`)\n",
88+
"- A confidence threshold, which we will leave at zero (no threshold)"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": 5,
94+
"id": "5bcb0105",
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"from trustyai.model import Output, Type, Value\n",
99+
"\n",
100+
"goal = [Output(\"inside\", Type.BOOLEAN, Value(True), 0.0)]"
101+
]
102+
},
103+
{
104+
"cell_type": "code",
105+
"execution_count": null,
106+
"id": "6aa524ae",
107+
"metadata": {},
108+
"outputs": [],
109+
"source": [
110+
"import random\n",
111+
"from trustyai.model import FeatureFactory\n",
112+
"\n",
113+
"features = [FeatureFactory.newNumericalFeature(f\"f-num{i+1}\", random.random()*10.0) for i in range(4)]\n",
114+
"\n",
115+
"for f in features:\n",
116+
" print(f\"Feature {f.getName()} has value {f.getValue()}\")"
117+
]
118+
},
119+
{
120+
"cell_type": "code",
121+
"execution_count": null,
122+
"id": "513d2e5a",
123+
"metadata": {},
124+
"outputs": [],
125+
"source": [
126+
"constraints = [False] * 4"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"id": "30dcc15b",
133+
"metadata": {},
134+
"outputs": [],
135+
"source": [
136+
"from trustyai.model.domain import NumericalFeatureDomain\n",
137+
"\n",
138+
"feature_boundaries = [NumericalFeatureDomain.create(0.0, 1000.0)] * 4"
139+
]
140+
},
141+
{
142+
"cell_type": "code",
143+
"execution_count": null,
144+
"id": "5047e075",
145+
"metadata": {},
146+
"outputs": [],
147+
"source": [
148+
"from trustyai.model import DataDomain\n",
149+
"\n",
150+
"data_domain = DataDomain(feature_boundaries)"
151+
]
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": null,
156+
"id": "e1b0da83",
157+
"metadata": {},
158+
"outputs": [],
159+
"source": [
160+
"center = 500.0\n",
161+
"epsilon = 10.0"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": null,
167+
"id": "510b3b16",
168+
"metadata": {},
169+
"outputs": [],
170+
"source": [
171+
"from trustyai.utils import TestUtils\n",
172+
"\n",
173+
"model = TestUtils.getSumThresholdModel(center, epsilon)"
174+
]
175+
},
176+
{
177+
"cell_type": "code",
178+
"execution_count": null,
179+
"id": "bcd25df0",
180+
"metadata": {},
181+
"outputs": [],
182+
"source": [
183+
"from org.optaplanner.core.config.solver.termination import TerminationConfig\n",
184+
"from org.kie.kogito.explainability.local.counterfactual import CounterfactualConfigurationFactory\n",
185+
"from java.lang import Long\n",
186+
"\n",
187+
"termination_config = TerminationConfig().withScoreCalculationCountLimit(Long.valueOf(10_000))\n",
188+
"\n",
189+
"solver_config = (\n",
190+
" CounterfactualConfigurationFactory.builder()\n",
191+
" .withTerminationConfig(termination_config)\n",
192+
" .build()\n",
193+
" )"
194+
]
195+
},
196+
{
197+
"cell_type": "code",
198+
"execution_count": null,
199+
"id": "c2b76274",
200+
"metadata": {},
201+
"outputs": [],
202+
"source": [
203+
"from org.kie.kogito.explainability.local.counterfactual import CounterfactualExplainer\n",
204+
"\n",
205+
"explainer = CounterfactualExplainer.builder().withSolverConfig(solver_config).build()"
206+
]
207+
},
208+
{
209+
"cell_type": "code",
210+
"execution_count": null,
211+
"id": "4cff79cd",
212+
"metadata": {},
213+
"outputs": [],
214+
"source": [
215+
"from trustyai.model import PredictionFeatureDomain, PredictionInput, PredictionOutput\n",
216+
"\n",
217+
"inputs = PredictionInput(features)\n",
218+
"outputs = PredictionOutput(goal)\n",
219+
"domain = PredictionFeatureDomain(data_domain.getFeatureDomains())"
220+
]
221+
},
222+
{
223+
"cell_type": "code",
224+
"execution_count": null,
225+
"id": "98057ebd",
226+
"metadata": {},
227+
"outputs": [],
228+
"source": [
229+
"import uuid\n",
230+
"from trustyai.model import CounterfactualPrediction\n",
231+
"\n",
232+
"prediction = CounterfactualPrediction(inputs, outputs, domain, constraints, None, uuid.uuid4())"
233+
]
234+
},
235+
{
236+
"cell_type": "code",
237+
"execution_count": null,
238+
"id": "910a250f",
239+
"metadata": {},
240+
"outputs": [],
241+
"source": [
242+
"explanation_async = explainer.explainAsync(prediction, model)"
243+
]
244+
},
245+
{
246+
"cell_type": "code",
247+
"execution_count": null,
248+
"id": "38774822",
249+
"metadata": {},
250+
"outputs": [],
251+
"source": [
252+
"explanation = explanation_async.get()"
253+
]
254+
},
255+
{
256+
"cell_type": "code",
257+
"execution_count": null,
258+
"id": "7cb95b8c",
259+
"metadata": {},
260+
"outputs": [],
261+
"source": [
262+
"for entity in explanation.getEntities():\n",
263+
" print(entity)"
264+
]
265+
},
266+
{
267+
"cell_type": "code",
268+
"execution_count": null,
269+
"id": "7a8587d1",
270+
"metadata": {},
271+
"outputs": [],
272+
"source": []
273+
}
274+
],
275+
"metadata": {
276+
"kernelspec": {
277+
"display_name": "python-trustyai",
278+
"language": "python",
279+
"name": "python-trustyai"
280+
},
281+
"language_info": {
282+
"codemirror_mode": {
283+
"name": "ipython",
284+
"version": 3
285+
},
286+
"file_extension": ".py",
287+
"mimetype": "text/x-python",
288+
"name": "python",
289+
"nbconvert_exporter": "python",
290+
"pygments_lexer": "ipython3",
291+
"version": "3.9.5"
292+
}
293+
},
294+
"nbformat": 4,
295+
"nbformat_minor": 5
296+
}

requirements-dev.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ black
22
pandas
33
scikit-learn
44
pylint
5-
pytest
5+
pytest
6+
setuptools
7+
wheel

setup.cfg

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[build-system]
2+
requires = ["setuptools", "wheel"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[metadata]
6+
name = trustyai
7+
version = 0.0.1
8+
9+
[options]
10+
packages = find:
11+
include_package_data = False
12+
requires = ['Jpype1']
13+
python_requires = >=3.8

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import setuptools
2+
3+
setuptools.setup(packages=['trustyai', 'trustyai.local.counterfactual', 'trustyai.local.lime', 'trustyai.model', 'trustyai.utils'])

0 commit comments

Comments
 (0)