|
19 | 19 | "\n", |
20 | 20 | "trustyai.init(\n", |
21 | 21 | " path=[\n", |
22 | | - " \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/*\",\n", |
| 22 | + " \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/*\",\n", |
| 23 | + "# \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/explainability-core-2.0.0-SNAPSHOT.jar\",\n", |
| 24 | + "# \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/explainability-core-1.8.0.Final-tests.jar\",\n", |
23 | 25 | " \"../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar\",\n", |
24 | 26 | " \"../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar\",\n", |
25 | 27 | " \"../dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar\",\n", |
|
139 | 141 | "name": "stdout", |
140 | 142 | "output_type": "stream", |
141 | 143 | "text": [ |
142 | | - "Feature x1 has value 5.154409874945087\n", |
143 | | - "Feature x2 has value 1.730733062926768\n", |
144 | | - "Feature x3 has value 9.000708957897421\n", |
145 | | - "Feature x4 has value 4.539576608157114\n", |
| 144 | + "Feature x1 has value 1.4465457485394606\n", |
| 145 | + "Feature x2 has value 9.904958794943276\n", |
| 146 | + "Feature x3 has value 8.632408661102822\n", |
| 147 | + "Feature x4 has value 2.666374576834393\n", |
146 | 148 | "\n", |
147 | | - "Features sum is 20.42542850392639\n" |
| 149 | + "Features sum is 22.65028778141995\n" |
148 | 150 | ] |
149 | 151 | } |
150 | 152 | ], |
|
376 | 378 | "name": "stdout", |
377 | 379 | "output_type": "stream", |
378 | 380 | "text": [ |
379 | | - "java.lang.DoubleFeature{value=5.154409874945087, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n", |
380 | | - "java.lang.DoubleFeature{value=480.69885432449246, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n", |
381 | | - "java.lang.DoubleFeature{value=9.000708957897421, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n", |
382 | | - "java.lang.DoubleFeature{value=4.2932049776029935, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n", |
| 381 | + "java.lang.DoubleFeature{value=1.547375925562755, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n", |
| 382 | + "java.lang.DoubleFeature{value=486.3570827969562, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n", |
| 383 | + "java.lang.DoubleFeature{value=8.632408661102822, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n", |
| 384 | + "java.lang.DoubleFeature{value=2.666374576834393, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n", |
383 | 385 | "\n", |
384 | | - "Feature sum is 499.147178134938\n" |
| 386 | + "Feature sum is 499.2032419604562\n" |
385 | 387 | ] |
386 | 388 | } |
387 | 389 | ], |
|
470 | 472 | "name": "stdout", |
471 | 473 | "output_type": "stream", |
472 | 474 | "text": [ |
473 | | - "Original x1: 5.154409874945087\n", |
474 | | - "Original x4: 4.539576608157114\n", |
| 475 | + "Original x1: 1.4465457485394606\n", |
| 476 | + "Original x4: 2.666374576834393\n", |
475 | 477 | "\n", |
476 | | - "java.lang.DoubleFeature{value=5.154409874945087, intRangeMinimum=5.154409874945087, intRangeMaximum=5.154409874945087, id='x1'}\n", |
477 | | - "java.lang.DoubleFeature{value=480.69885432449246, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n", |
478 | | - "java.lang.DoubleFeature{value=9.291426877845232, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n", |
479 | | - "java.lang.DoubleFeature{value=4.539576608157114, intRangeMinimum=4.539576608157114, intRangeMaximum=4.539576608157114, id='x4'}\n" |
| 478 | + "java.lang.DoubleFeature{value=1.4465457485394606, intRangeMinimum=1.4465457485394606, intRangeMaximum=1.4465457485394606, id='x1'}\n", |
| 479 | + "java.lang.DoubleFeature{value=486.3570827969562, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n", |
| 480 | + "java.lang.DoubleFeature{value=8.538028490590378, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n", |
| 481 | + "java.lang.DoubleFeature{value=2.666374576834393, intRangeMinimum=2.666374576834393, intRangeMaximum=2.666374576834393, id='x4'}\n" |
480 | 482 | ] |
481 | 483 | } |
482 | 484 | ], |
|
511 | 513 | }, |
512 | 514 | { |
513 | 515 | "cell_type": "code", |
514 | | - "execution_count": 57, |
| 516 | + "execution_count": 20, |
515 | 517 | "id": "3f64510a", |
516 | 518 | "metadata": {}, |
517 | 519 | "outputs": [ |
|
588 | 590 | }, |
589 | 591 | { |
590 | 592 | "cell_type": "code", |
591 | | - "execution_count": 58, |
| 593 | + "execution_count": 21, |
592 | 594 | "id": "603d909e", |
593 | 595 | "metadata": {}, |
594 | 596 | "outputs": [], |
|
598 | 600 | }, |
599 | 601 | { |
600 | 602 | "cell_type": "code", |
601 | | - "execution_count": 59, |
| 603 | + "execution_count": 22, |
602 | 604 | "id": "8baeb746", |
603 | 605 | "metadata": {}, |
604 | 606 | "outputs": [ |
|
620 | 622 | }, |
621 | 623 | { |
622 | 624 | "cell_type": "code", |
623 | | - "execution_count": 75, |
| 625 | + "execution_count": 23, |
624 | 626 | "id": "2080ef92", |
625 | 627 | "metadata": {}, |
626 | 628 | "outputs": [], |
627 | 629 | "source": [ |
628 | | - "def predict(inputs):\n", |
629 | | - " values = [feature.getValue().asNumber() for feature in inputs.getFeatures()]\n", |
| 630 | + "from typing import List\n", |
| 631 | + "from trustyai.utils import toJList\n", |
| 632 | + "\n", |
| 633 | + "def predict(inputs: List[PredictionInput]) -> List[PredictionOutput]:\n", |
| 634 | + " values = [feature.getValue().asNumber() for feature in inputs.get(0).getFeatures()]\n", |
630 | 635 | " result = xg_model.predict(np.array([values]))\n", |
631 | 636 | " value = False if result[0]==0 else True\n", |
632 | 637 | " output = Output(\"PaidLoan\", Type.BOOLEAN, Value(value), 0.0)\n", |
633 | | - " return PredictionOutput([output])" |
| 638 | + " return toJList([PredictionOutput([output])])" |
634 | 639 | ] |
635 | 640 | }, |
636 | 641 | { |
637 | 642 | "cell_type": "code", |
638 | | - "execution_count": 76, |
| 643 | + "execution_count": 24, |
639 | 644 | "id": "00347c63", |
640 | 645 | "metadata": {}, |
641 | 646 | "outputs": [], |
|
647 | 652 | }, |
648 | 653 | { |
649 | 654 | "cell_type": "code", |
650 | | - "execution_count": 77, |
| 655 | + "execution_count": 25, |
651 | 656 | "id": "06d52535", |
652 | 657 | "metadata": {}, |
653 | 658 | "outputs": [], |
|
697 | 702 | }, |
698 | 703 | { |
699 | 704 | "cell_type": "code", |
700 | | - "execution_count": 78, |
| 705 | + "execution_count": 26, |
701 | 706 | "id": "fa3e099b", |
702 | 707 | "metadata": {}, |
703 | 708 | "outputs": [ |
|
716 | 721 | }, |
717 | 722 | { |
718 | 723 | "cell_type": "code", |
719 | | - "execution_count": 79, |
| 724 | + "execution_count": 27, |
720 | 725 | "id": "2b279cae", |
721 | 726 | "metadata": {}, |
722 | 727 | "outputs": [ |
723 | 728 | { |
724 | | - "name": "stdout", |
725 | | - "output_type": "stream", |
726 | | - "text": [ |
727 | | - "Output{value=false, type=boolean, score=0.0, name='PaidLoan'}\n" |
728 | | - ] |
| 729 | + "data": { |
| 730 | + "text/plain": [ |
| 731 | + "'Output{value=false, type=boolean, score=0.0, name='PaidLoan'}'" |
| 732 | + ] |
| 733 | + }, |
| 734 | + "execution_count": 27, |
| 735 | + "metadata": {}, |
| 736 | + "output_type": "execute_result" |
729 | 737 | } |
730 | 738 | ], |
731 | 739 | "source": [ |
732 | | - "print(model.predictAsync(PredictionInput(features)).get().getOutputs()[0])" |
| 740 | + "from trustyai.utils import toJList\n", |
| 741 | + "\n", |
| 742 | + "model.predictAsync(toJList([PredictionInput(features)])).get()[0].getOutputs()[0].toString()" |
733 | 743 | ] |
734 | 744 | }, |
735 | 745 | { |
736 | 746 | "cell_type": "code", |
737 | | - "execution_count": 80, |
| 747 | + "execution_count": 28, |
738 | 748 | "id": "18fff350", |
739 | 749 | "metadata": {}, |
740 | 750 | "outputs": [], |
|
744 | 754 | }, |
745 | 755 | { |
746 | 756 | "cell_type": "code", |
747 | | - "execution_count": 81, |
| 757 | + "execution_count": 29, |
748 | 758 | "id": "3a4815d9", |
749 | 759 | "metadata": {}, |
750 | 760 | "outputs": [], |
|
754 | 764 | }, |
755 | 765 | { |
756 | 766 | "cell_type": "code", |
757 | | - "execution_count": 85, |
| 767 | + "execution_count": 30, |
758 | 768 | "id": "7277e246", |
759 | 769 | "metadata": {}, |
760 | 770 | "outputs": [], |
761 | 771 | "source": [ |
762 | | - "termination_config = TerminationConfig().withBestScoreFeasible(True).withSecondsSpentLimit(Long.valueOf(10))\n", |
| 772 | + "termination_config = TerminationConfig().withSecondsSpentLimit(Long.valueOf(20))\n", |
763 | 773 | "\n", |
764 | 774 | "solver_config = (\n", |
765 | 775 | " CounterfactualConfigurationFactory.builder()\n", |
|
770 | 780 | }, |
771 | 781 | { |
772 | 782 | "cell_type": "code", |
773 | | - "execution_count": 86, |
| 783 | + "execution_count": 31, |
774 | 784 | "id": "338e61e9", |
775 | 785 | "metadata": {}, |
776 | 786 | "outputs": [], |
|
780 | 790 | }, |
781 | 791 | { |
782 | 792 | "cell_type": "code", |
783 | | - "execution_count": 87, |
| 793 | + "execution_count": 32, |
784 | 794 | "id": "f9340354", |
785 | 795 | "metadata": {}, |
786 | 796 | "outputs": [], |
|
790 | 800 | }, |
791 | 801 | { |
792 | 802 | "cell_type": "code", |
793 | | - "execution_count": 88, |
| 803 | + "execution_count": 33, |
794 | 804 | "id": "ef45bde0", |
795 | 805 | "metadata": {}, |
796 | 806 | "outputs": [], |
|
802 | 812 | }, |
803 | 813 | { |
804 | 814 | "cell_type": "code", |
805 | | - "execution_count": 89, |
| 815 | + "execution_count": 34, |
806 | 816 | "id": "5884e49c", |
807 | 817 | "metadata": {}, |
808 | 818 | "outputs": [], |
|
812 | 822 | }, |
813 | 823 | { |
814 | 824 | "cell_type": "code", |
815 | | - "execution_count": 90, |
| 825 | + "execution_count": 37, |
816 | 826 | "id": "e9ea7928", |
817 | 827 | "metadata": {}, |
818 | 828 | "outputs": [], |
819 | 829 | "source": [ |
820 | | - "explanation = explainer.explainAsync(prediction, model)" |
| 830 | + "explanation = explainer.explainAsync(prediction, model).get()" |
| 831 | + ] |
| 832 | + }, |
| 833 | + { |
| 834 | + "cell_type": "code", |
| 835 | + "execution_count": 38, |
| 836 | + "id": "852c395c", |
| 837 | + "metadata": {}, |
| 838 | + "outputs": [ |
| 839 | + { |
| 840 | + "name": "stdout", |
| 841 | + "output_type": "stream", |
| 842 | + "text": [ |
| 843 | + "java.lang.BooleanFeature{value=false, id='NewCreditCustomer'}\n", |
| 844 | + "java.lang.DoubleFeature{value=2125.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='Amount'}\n", |
| 845 | + "java.lang.DoubleFeature{value=20.97, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='Interest'}\n", |
| 846 | + "java.lang.DoubleFeature{value=60.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='LoanDuration'}\n", |
| 847 | + "java.lang.DoubleFeature{value=4.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='Education'}\n", |
| 848 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='NrOfDependants'}\n", |
| 849 | + "java.lang.DoubleFeature{value=6.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='EmploymentDurationCurrentEmployer'}\n", |
| 850 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromPrincipalEmployer'}\n", |
| 851 | + "java.lang.DoubleFeature{value=301.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromPension'}\n", |
| 852 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromFamilyAllowance'}\n", |
| 853 | + "java.lang.DoubleFeature{value=53.31125429433703, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromSocialWelfare'}\n", |
| 854 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromLeavePay'}\n", |
| 855 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromChildSupport'}\n", |
| 856 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeOther'}\n", |
| 857 | + "java.lang.DoubleFeature{value=8.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='ExistingLiabilities'}\n", |
| 858 | + "java.lang.DoubleFeature{value=1.230474777192958, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='RefinanceLiabilities'}\n", |
| 859 | + "java.lang.DoubleFeature{value=26.29, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='DebtToIncome'}\n", |
| 860 | + "java.lang.DoubleFeature{value=10.92, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='FreeCash'}\n", |
| 861 | + "java.lang.DoubleFeature{value=1000.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='CreditScoreEeMini'}\n", |
| 862 | + "java.lang.DoubleFeature{value=1.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='NoOfPreviousLoansBeforeLoan'}\n", |
| 863 | + "java.lang.DoubleFeature{value=500.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='AmountOfPreviousLoansBeforeLoan'}\n", |
| 864 | + "java.lang.DoubleFeature{value=590.95, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='PreviousRepaymentsBeforeLoan'}\n", |
| 865 | + "java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='PreviousEarlyRepaymentsBefoleLoan'}\n", |
| 866 | + "java.lang.DoubleFeature{value=6.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='PreviousEarlyRepaymentsCountBeforeLoan'}\n", |
| 867 | + "java.lang.BooleanFeature{value=false, id='Council_house'}\n", |
| 868 | + "java.lang.BooleanFeature{value=false, id='Homeless'}\n", |
| 869 | + "java.lang.BooleanFeature{value=false, id='Joint_ownership'}\n", |
| 870 | + "java.lang.BooleanFeature{value=false, id='Joint_tenant'}\n", |
| 871 | + "java.lang.BooleanFeature{value=false, id='Living_with_parents'}\n", |
| 872 | + "java.lang.BooleanFeature{value=false, id='Mortgage'}\n", |
| 873 | + "java.lang.BooleanFeature{value=false, id='Other'}\n", |
| 874 | + "java.lang.BooleanFeature{value=true, id='Owner'}\n", |
| 875 | + "java.lang.BooleanFeature{value=false, id='Owner_with_encumbrance'}\n", |
| 876 | + "java.lang.BooleanFeature{value=true, id='Tenant'}\n", |
| 877 | + "java.lang.BooleanFeature{value=false, id='Entrepreneur'}\n", |
| 878 | + "java.lang.BooleanFeature{value=false, id='Fully'}\n", |
| 879 | + "java.lang.BooleanFeature{value=false, id='Partially'}\n", |
| 880 | + "java.lang.BooleanFeature{value=true, id='Retiree'}\n", |
| 881 | + "java.lang.BooleanFeature{value=false, id='Self_employed'}\n" |
| 882 | + ] |
| 883 | + } |
| 884 | + ], |
| 885 | + "source": [ |
| 886 | + "for entity in explanation.getEntities():\n", |
| 887 | + " print(entity)" |
821 | 888 | ] |
| 889 | + }, |
| 890 | + { |
| 891 | + "cell_type": "code", |
| 892 | + "execution_count": 42, |
| 893 | + "id": "cad93dea", |
| 894 | + "metadata": {}, |
| 895 | + "outputs": [ |
| 896 | + { |
| 897 | + "data": { |
| 898 | + "text/plain": [ |
| 899 | + "'Output{value=true, type=boolean, score=0.0, name='PaidLoan'}'" |
| 900 | + ] |
| 901 | + }, |
| 902 | + "execution_count": 42, |
| 903 | + "metadata": {}, |
| 904 | + "output_type": "execute_result" |
| 905 | + } |
| 906 | + ], |
| 907 | + "source": [ |
| 908 | + "testf = [f.asFeature() for f in explanation.getEntities()]\n", |
| 909 | + "model.predictAsync(toJList([PredictionInput(testf)])).get()[0].getOutputs()[0].toString()" |
| 910 | + ] |
| 911 | + }, |
| 912 | + { |
| 913 | + "cell_type": "code", |
| 914 | + "execution_count": null, |
| 915 | + "id": "927fe8f9", |
| 916 | + "metadata": {}, |
| 917 | + "outputs": [], |
| 918 | + "source": [] |
822 | 919 | } |
823 | 920 | ], |
824 | 921 | "metadata": { |
|
0 commit comments