|
16 | 16 | "outputs": [],
|
17 | 17 | "source": [
|
18 | 18 | "import trustyai\n",
|
19 |
| - "import os\n", |
20 | 19 | "\n",
|
21 |
| - "pwd = os.path.abspath('')\n", |
22 |
| - "\n", |
23 |
| - "trustyai.init(\n", |
24 |
| - " path = [\n", |
25 |
| - " pwd + \"/../dep/org/kie/kogito/explainability-core/1.16.0.Final/*\",\n", |
26 |
| - " pwd + \"/../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar\",\n", |
27 |
| - " pwd + \"/../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar\",\n", |
28 |
| - " pwd + \"/../dep/org/optaplanner/optaplanner-core/8.16.0.Final/*\",\n", |
29 |
| - " pwd + \"/../dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar\",\n", |
30 |
| - " pwd + \"/../dep/org/kie/kie-api/8.16.0.Beta/*\",\n", |
31 |
| - " pwd + \"/../dep/io/micrometer/micrometer-core/1.8.1/*\",\n", |
32 |
| - " ]\n", |
33 |
| - ")" |
| 20 | + "trustyai.init()" |
34 | 21 | ]
|
35 | 22 | },
|
36 | 23 | {
|
|
145 | 132 | "name": "stdout",
|
146 | 133 | "output_type": "stream",
|
147 | 134 | "text": [
|
148 |
| - "Feature x1 has value 1.0380081180194922\n", |
149 |
| - "Feature x2 has value 9.63798897368338\n", |
150 |
| - "Feature x3 has value 3.5581213895741595\n", |
151 |
| - "Feature x4 has value 5.670566321767616\n", |
| 135 | + "Feature x1 has value 8.383112183880774\n", |
| 136 | + "Feature x2 has value 1.5296907450266672\n", |
| 137 | + "Feature x3 has value 6.324409811461678\n", |
| 138 | + "Feature x4 has value 3.0125007391588077\n", |
152 | 139 | "\n",
|
153 |
| - "Features sum is 19.904684803044645\n" |
| 140 | + "Features sum is 19.249713479527923\n" |
154 | 141 | ]
|
155 | 142 | }
|
156 | 143 | ],
|
|
266 | 253 | "metadata": {},
|
267 | 254 | "outputs": [],
|
268 | 255 | "source": [
|
269 |
| - "import uuid\n", |
270 |
| - "from trustyai.local.counterfactual import counterfactual_prediction\n", |
| 256 | + "from trustyai.model import counterfactual_prediction\n", |
271 | 257 | "\n",
|
272 | 258 | "prediction = counterfactual_prediction(\n",
|
273 | 259 | " input_features=features,\n",
|
|
311 | 297 | "name": "stdout",
|
312 | 298 | "output_type": "stream",
|
313 | 299 | "text": [
|
314 |
| - "java.lang.DoubleEntity{value=1.0380081180194922, rangeMinimum=0.0, rangeMaximum=1000.0, id='x1'}\n", |
315 |
| - "java.lang.DoubleEntity{value=484.86792394543664, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n", |
316 |
| - "java.lang.DoubleEntity{value=3.5581213895741595, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n", |
317 |
| - "java.lang.DoubleEntity{value=9.63798897368338, rangeMinimum=0.0, rangeMaximum=1000.0, id='x4'}\n", |
| 300 | + "java.lang.DoubleEntity{value=8.383112183880774, rangeMinimum=0.0, rangeMaximum=1000.0, id='x1'}\n", |
| 301 | + "java.lang.DoubleEntity{value=481.32000081209515, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n", |
| 302 | + "java.lang.DoubleEntity{value=6.495107736558037, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n", |
| 303 | + "java.lang.DoubleEntity{value=3.0125007391588077, rangeMinimum=0.0, rangeMaximum=1000.0, id='x4'}\n", |
318 | 304 | "\n",
|
319 |
| - "Feature sum is 499.1020424267137\n" |
| 305 | + "Feature sum is 499.21072147169275\n" |
320 | 306 | ]
|
321 | 307 | }
|
322 | 308 | ],
|
|
409 | 395 | "name": "stdout",
|
410 | 396 | "output_type": "stream",
|
411 | 397 | "text": [
|
412 |
| - "Original x1: 1.0380081180194922\n", |
413 |
| - "Original x4: 5.670566321767616\n", |
| 398 | + "Original x1: 8.383112183880774\n", |
| 399 | + "Original x4: 3.0125007391588077\n", |
414 | 400 | "\n",
|
415 |
| - "java.lang.DoubleEntity{value=1.0380081180194922, rangeMinimum=1.0380081180194922, rangeMaximum=1.0380081180194922, id='x1'}\n", |
416 |
| - "java.lang.DoubleEntity{value=488.87828084645156, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n", |
417 |
| - "java.lang.DoubleEntity{value=4.2932049776029935, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n", |
418 |
| - "java.lang.DoubleEntity{value=5.670566321767616, rangeMinimum=5.670566321767616, rangeMaximum=5.670566321767616, id='x4'}\n" |
| 401 | + "java.lang.DoubleEntity{value=8.383112183880774, rangeMinimum=8.383112183880774, rangeMaximum=8.383112183880774, id='x1'}\n", |
| 402 | + "java.lang.DoubleEntity{value=481.32000081209515, rangeMinimum=0.0, rangeMaximum=1000.0, id='x2'}\n", |
| 403 | + "java.lang.DoubleEntity{value=6.495107736558037, rangeMinimum=0.0, rangeMaximum=1000.0, id='x3'}\n", |
| 404 | + "java.lang.DoubleEntity{value=3.0125007391588077, rangeMinimum=3.0125007391588077, rangeMaximum=3.0125007391588077, id='x4'}\n" |
419 | 405 | ]
|
420 | 406 | }
|
421 | 407 | ],
|
|
1283 | 1269 | },
|
1284 | 1270 | {
|
1285 | 1271 | "cell_type": "code",
|
1286 |
| - "execution_count": 52, |
| 1272 | + "execution_count": 40, |
1287 | 1273 | "outputs": [],
|
1288 | 1274 | "source": [
|
1289 | 1275 | "from sklearn.datasets import make_blobs\n",
|
|
1299 | 1285 | },
|
1300 | 1286 | {
|
1301 | 1287 | "cell_type": "code",
|
1302 |
| - "execution_count": 53, |
| 1288 | + "execution_count": 41, |
1303 | 1289 | "outputs": [
|
1304 | 1290 | {
|
1305 | 1291 | "data": {
|
|
1340 | 1326 | },
|
1341 | 1327 | {
|
1342 | 1328 | "cell_type": "code",
|
1343 |
| - "execution_count": 54, |
| 1329 | + "execution_count": 42, |
1344 | 1330 | "outputs": [
|
1345 | 1331 | {
|
1346 | 1332 | "data": {
|
1347 | 1333 | "text/plain": "KNeighborsClassifier(n_neighbors=3)"
|
1348 | 1334 | },
|
1349 |
| - "execution_count": 54, |
| 1335 | + "execution_count": 42, |
1350 | 1336 | "metadata": {},
|
1351 | 1337 | "output_type": "execute_result"
|
1352 | 1338 | }
|
|
1378 | 1364 | },
|
1379 | 1365 | {
|
1380 | 1366 | "cell_type": "code",
|
1381 |
| - "execution_count": 56, |
| 1367 | + "execution_count": 43, |
1382 | 1368 | "outputs": [
|
1383 | 1369 | {
|
1384 | 1370 | "data": {
|
|
1409 | 1395 | },
|
1410 | 1396 | {
|
1411 | 1397 | "cell_type": "code",
|
1412 |
| - "execution_count": 57, |
| 1398 | + "execution_count": 44, |
1413 | 1399 | "outputs": [
|
1414 | 1400 | {
|
1415 | 1401 | "data": {
|
1416 | 1402 | "text/plain": "array([[0., 0., 1.]])"
|
1417 | 1403 | },
|
1418 |
| - "execution_count": 57, |
| 1404 | + "execution_count": 44, |
1419 | 1405 | "metadata": {},
|
1420 | 1406 | "output_type": "execute_result"
|
1421 | 1407 | }
|
|
1447 | 1433 | },
|
1448 | 1434 | {
|
1449 | 1435 | "cell_type": "code",
|
1450 |
| - "execution_count": 58, |
| 1436 | + "execution_count": 45, |
1451 | 1437 | "outputs": [],
|
1452 | 1438 | "source": [
|
1453 | 1439 | "def knn_classify(inputs):\n",
|
|
1478 | 1464 | },
|
1479 | 1465 | {
|
1480 | 1466 | "cell_type": "code",
|
1481 |
| - "execution_count": 59, |
| 1467 | + "execution_count": 46, |
1482 | 1468 | "outputs": [],
|
1483 | 1469 | "source": [
|
1484 | 1470 | "from trustyai.model import Model\n",
|
|
1506 | 1492 | },
|
1507 | 1493 | {
|
1508 | 1494 | "cell_type": "code",
|
1509 |
| - "execution_count": 60, |
| 1495 | + "execution_count": 47, |
1510 | 1496 | "outputs": [
|
1511 | 1497 | {
|
1512 | 1498 | "data": {
|
1513 | 1499 | "text/plain": "'Output{value=2, type=number, score=1.0, name='cluster'}'"
|
1514 | 1500 | },
|
1515 |
| - "execution_count": 60, |
| 1501 | + "execution_count": 47, |
1516 | 1502 | "metadata": {},
|
1517 | 1503 | "output_type": "execute_result"
|
1518 | 1504 | }
|
|
1542 | 1528 | },
|
1543 | 1529 | {
|
1544 | 1530 | "cell_type": "code",
|
1545 |
| - "execution_count": 61, |
| 1531 | + "execution_count": 48, |
1546 | 1532 | "outputs": [],
|
1547 | 1533 | "source": [
|
1548 | 1534 | "goal = [output(name=\"cluster\", dtype=\"number\", value=1)]"
|
|
1568 | 1554 | },
|
1569 | 1555 | {
|
1570 | 1556 | "cell_type": "code",
|
1571 |
| - "execution_count": 62, |
| 1557 | + "execution_count": 49, |
1572 | 1558 | "outputs": [],
|
1573 | 1559 | "source": [
|
1574 | 1560 | "prediction = counterfactual_prediction(\n",
|
|
1599 | 1585 | },
|
1600 | 1586 | {
|
1601 | 1587 | "cell_type": "code",
|
1602 |
| - "execution_count": 63, |
| 1588 | + "execution_count": 50, |
1603 | 1589 | "outputs": [
|
1604 | 1590 | {
|
1605 | 1591 | "data": {
|
1606 | 1592 | "text/plain": "[2.501921601686025, 2.6401536249215436]"
|
1607 | 1593 | },
|
1608 |
| - "execution_count": 63, |
| 1594 | + "execution_count": 50, |
1609 | 1595 | "metadata": {},
|
1610 | 1596 | "output_type": "execute_result"
|
1611 | 1597 | }
|
|
1635 | 1621 | },
|
1636 | 1622 | {
|
1637 | 1623 | "cell_type": "code",
|
1638 |
| - "execution_count": 64, |
| 1624 | + "execution_count": 51, |
1639 | 1625 | "outputs": [
|
1640 | 1626 | {
|
1641 | 1627 | "data": {
|
1642 | 1628 | "text/plain": "'Output{value=1, type=number, score=1.0, name='cluster'}'"
|
1643 | 1629 | },
|
1644 |
| - "execution_count": 64, |
| 1630 | + "execution_count": 51, |
1645 | 1631 | "metadata": {},
|
1646 | 1632 | "output_type": "execute_result"
|
1647 | 1633 | }
|
|
1671 | 1657 | },
|
1672 | 1658 | {
|
1673 | 1659 | "cell_type": "code",
|
1674 |
| - "execution_count": 65, |
| 1660 | + "execution_count": 52, |
1675 | 1661 | "outputs": [
|
1676 | 1662 | {
|
1677 | 1663 | "data": {
|
|
1700 | 1686 | "name": "#%%\n"
|
1701 | 1687 | }
|
1702 | 1688 | }
|
1703 |
| - }, |
1704 |
| - { |
1705 |
| - "cell_type": "code", |
1706 |
| - "execution_count": null, |
1707 |
| - "outputs": [], |
1708 |
| - "source": [], |
1709 |
| - "metadata": { |
1710 |
| - "collapsed": false, |
1711 |
| - "pycharm": { |
1712 |
| - "name": "#%%\n" |
1713 |
| - } |
1714 |
| - } |
1715 | 1689 | }
|
1716 | 1690 | ],
|
1717 | 1691 | "metadata": {
|
|
0 commit comments