Skip to content

Commit ff43bd5

Browse files
committed
Update examples
1 parent 648d9af commit ff43bd5

File tree

2 files changed

+201
-138
lines changed

2 files changed

+201
-138
lines changed

examples/Counterfactuals.ipynb

Lines changed: 36 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,8 @@
1616
"outputs": [],
1717
"source": [
1818
"import trustyai\n",
19-
"import os\n",
2019
"\n",
21-
"pwd = os.path.abspath('')\n",
22-
"\n",
23-
"trustyai.init(\n",
24-
" path = [\n",
25-
" pwd + \"/../dep/org/kie/kogito/explainability-core/1.16.0.Final/*\",\n",
26-
" pwd + \"/../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar\",\n",
27-
" pwd + \"/../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar\",\n",
28-
" pwd + \"/../dep/org/optaplanner/optaplanner-core/8.16.0.Final/*\",\n",
29-
" pwd + \"/../dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar\",\n",
30-
" pwd + \"/../dep/org/kie/kie-api/8.16.0.Beta/*\",\n",
31-
" pwd + \"/../dep/io/micrometer/micrometer-core/1.8.1/*\",\n",
32-
" ]\n",
33-
")"
20+
"trustyai.init()"
3421
]
3522
},
3623
{
@@ -145,12 +132,12 @@
145132
"name": "stdout",
146133
"output_type": "stream",
147134
"text": [
148-
"Feature x1 has value 1.0380081180194922\n",
149-
"Feature x2 has value 9.63798897368338\n",
150-
"Feature x3 has value 3.5581213895741595\n",
151-
"Feature x4 has value 5.670566321767616\n",
135+
"Feature x1 has value 8.383112183880774\n",
136+
"Feature x2 has value 1.5296907450266672\n",
137+
"Feature x3 has value 6.324409811461678\n",
138+
"Feature x4 has value 3.0125007391588077\n",
152139
"\n",
153-
"Features sum is 19.904684803044645\n"
140+
"Features sum is 19.249713479527923\n"
154141
]
155142
}
156143
],
@@ -266,8 +253,7 @@
266253
"metadata": {},
267254
"outputs": [],
268255
"source": [
269-
"import uuid\n",
270-
"from trustyai.local.counterfactual import counterfactual_prediction\n",
256+
"from trustyai.model import counterfactual_prediction\n",
271257
"\n",
272258
"prediction = counterfactual_prediction(\n",
273259
" input_features=features,\n",
@@ -311,12 +297,12 @@
311297
"name": "stdout",
312298
"output_type": "stream",
313299
"text": [
314-
"java.lang.DoubleEntity{value=1.0380081180194922, rangeMinimum=0.0, rangeMaximum=1000.0, id='x1'}\n",
315-
"java.lang.DoubleEntity{value=484.86792394543664, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n",
316-
"java.lang.DoubleEntity{value=3.5581213895741595, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n",
317-
"java.lang.DoubleEntity{value=9.63798897368338, rangeMinimum=0.0, rangeMaximum=1000.0, id='x4'}\n",
300+
"java.lang.DoubleEntity{value=8.383112183880774, rangeMinimum=0.0, rangeMaximum=1000.0, id='x1'}\n",
301+
"java.lang.DoubleEntity{value=481.32000081209515, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n",
302+
"java.lang.DoubleEntity{value=6.495107736558037, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n",
303+
"java.lang.DoubleEntity{value=3.0125007391588077, rangeMinimum=0.0, rangeMaximum=1000.0, id='x4'}\n",
318304
"\n",
319-
"Feature sum is 499.1020424267137\n"
305+
"Feature sum is 499.21072147169275\n"
320306
]
321307
}
322308
],
@@ -409,13 +395,13 @@
409395
"name": "stdout",
410396
"output_type": "stream",
411397
"text": [
412-
"Original x1: 1.0380081180194922\n",
413-
"Original x4: 5.670566321767616\n",
398+
"Original x1: 8.383112183880774\n",
399+
"Original x4: 3.0125007391588077\n",
414400
"\n",
415-
"java.lang.DoubleEntity{value=1.0380081180194922, rangeMinimum=1.0380081180194922, rangeMaximum=1.0380081180194922, id='x1'}\n",
416-
"java.lang.DoubleEntity{value=488.87828084645156, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n",
417-
"java.lang.DoubleEntity{value=4.2932049776029935, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n",
418-
"java.lang.DoubleEntity{value=5.670566321767616, rangeMinimum=5.670566321767616, rangeMaximum=5.670566321767616, id='x4'}\n"
401+
"java.lang.DoubleEntity{value=8.383112183880774, rangeMinimum=8.383112183880774, rangeMaximum=8.383112183880774, id='x1'}\n",
402+
"java.lang.DoubleEntity{value=481.32000081209515, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n",
403+
"java.lang.DoubleEntity{value=6.495107736558037, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n",
404+
"java.lang.DoubleEntity{value=3.0125007391588077, rangeMinimum=3.0125007391588077, rangeMaximum=3.0125007391588077, id='x4'}\n"
419405
]
420406
}
421407
],
@@ -1283,7 +1269,7 @@
12831269
},
12841270
{
12851271
"cell_type": "code",
1286-
"execution_count": 52,
1272+
"execution_count": 40,
12871273
"outputs": [],
12881274
"source": [
12891275
"from sklearn.datasets import make_blobs\n",
@@ -1299,7 +1285,7 @@
12991285
},
13001286
{
13011287
"cell_type": "code",
1302-
"execution_count": 53,
1288+
"execution_count": 41,
13031289
"outputs": [
13041290
{
13051291
"data": {
@@ -1340,13 +1326,13 @@
13401326
},
13411327
{
13421328
"cell_type": "code",
1343-
"execution_count": 54,
1329+
"execution_count": 42,
13441330
"outputs": [
13451331
{
13461332
"data": {
13471333
"text/plain": "KNeighborsClassifier(n_neighbors=3)"
13481334
},
1349-
"execution_count": 54,
1335+
"execution_count": 42,
13501336
"metadata": {},
13511337
"output_type": "execute_result"
13521338
}
@@ -1378,7 +1364,7 @@
13781364
},
13791365
{
13801366
"cell_type": "code",
1381-
"execution_count": 56,
1367+
"execution_count": 43,
13821368
"outputs": [
13831369
{
13841370
"data": {
@@ -1409,13 +1395,13 @@
14091395
},
14101396
{
14111397
"cell_type": "code",
1412-
"execution_count": 57,
1398+
"execution_count": 44,
14131399
"outputs": [
14141400
{
14151401
"data": {
14161402
"text/plain": "array([[0., 0., 1.]])"
14171403
},
1418-
"execution_count": 57,
1404+
"execution_count": 44,
14191405
"metadata": {},
14201406
"output_type": "execute_result"
14211407
}
@@ -1447,7 +1433,7 @@
14471433
},
14481434
{
14491435
"cell_type": "code",
1450-
"execution_count": 58,
1436+
"execution_count": 45,
14511437
"outputs": [],
14521438
"source": [
14531439
"def knn_classify(inputs):\n",
@@ -1478,7 +1464,7 @@
14781464
},
14791465
{
14801466
"cell_type": "code",
1481-
"execution_count": 59,
1467+
"execution_count": 46,
14821468
"outputs": [],
14831469
"source": [
14841470
"from trustyai.model import Model\n",
@@ -1506,13 +1492,13 @@
15061492
},
15071493
{
15081494
"cell_type": "code",
1509-
"execution_count": 60,
1495+
"execution_count": 47,
15101496
"outputs": [
15111497
{
15121498
"data": {
15131499
"text/plain": "'Output{value=2, type=number, score=1.0, name='cluster'}'"
15141500
},
1515-
"execution_count": 60,
1501+
"execution_count": 47,
15161502
"metadata": {},
15171503
"output_type": "execute_result"
15181504
}
@@ -1542,7 +1528,7 @@
15421528
},
15431529
{
15441530
"cell_type": "code",
1545-
"execution_count": 61,
1531+
"execution_count": 48,
15461532
"outputs": [],
15471533
"source": [
15481534
"goal = [output(name=\"cluster\", dtype=\"number\", value=1)]"
@@ -1568,7 +1554,7 @@
15681554
},
15691555
{
15701556
"cell_type": "code",
1571-
"execution_count": 62,
1557+
"execution_count": 49,
15721558
"outputs": [],
15731559
"source": [
15741560
"prediction = counterfactual_prediction(\n",
@@ -1599,13 +1585,13 @@
15991585
},
16001586
{
16011587
"cell_type": "code",
1602-
"execution_count": 63,
1588+
"execution_count": 50,
16031589
"outputs": [
16041590
{
16051591
"data": {
16061592
"text/plain": "[2.501921601686025, 2.6401536249215436]"
16071593
},
1608-
"execution_count": 63,
1594+
"execution_count": 50,
16091595
"metadata": {},
16101596
"output_type": "execute_result"
16111597
}
@@ -1635,13 +1621,13 @@
16351621
},
16361622
{
16371623
"cell_type": "code",
1638-
"execution_count": 64,
1624+
"execution_count": 51,
16391625
"outputs": [
16401626
{
16411627
"data": {
16421628
"text/plain": "'Output{value=1, type=number, score=1.0, name='cluster'}'"
16431629
},
1644-
"execution_count": 64,
1630+
"execution_count": 51,
16451631
"metadata": {},
16461632
"output_type": "execute_result"
16471633
}
@@ -1671,7 +1657,7 @@
16711657
},
16721658
{
16731659
"cell_type": "code",
1674-
"execution_count": 65,
1660+
"execution_count": 52,
16751661
"outputs": [
16761662
{
16771663
"data": {
@@ -1700,18 +1686,6 @@
17001686
"name": "#%%\n"
17011687
}
17021688
}
1703-
},
1704-
{
1705-
"cell_type": "code",
1706-
"execution_count": null,
1707-
"outputs": [],
1708-
"source": [],
1709-
"metadata": {
1710-
"collapsed": false,
1711-
"pycharm": {
1712-
"name": "#%%\n"
1713-
}
1714-
}
17151689
}
17161690
],
17171691
"metadata": {

0 commit comments

Comments
 (0)