Skip to content

Commit 1c93db9

Browse files
committed
Add proper confidence to XGBoost model prediction (closes #8)
1 parent b1ce350 commit 1c93db9

File tree

1 file changed

+93
-88
lines changed

1 file changed

+93
-88
lines changed

notebooks/Counterfactuals.ipynb

Lines changed: 93 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
"trustyai.init(\n",
2121
" path=[\n",
2222
" \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/*\",\n",
23-
"# \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/explainability-core-2.0.0-SNAPSHOT.jar\",\n",
24-
"# \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/explainability-core-1.8.0.Final-tests.jar\",\n",
2523
" \"../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar\",\n",
2624
" \"../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar\",\n",
2725
" \"../dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar\",\n",
@@ -141,12 +139,12 @@
141139
"name": "stdout",
142140
"output_type": "stream",
143141
"text": [
144-
"Feature x1 has value 5.89410604979949\n",
145-
"Feature x2 has value 4.6702081567068525\n",
146-
"Feature x3 has value 7.484002107104038\n",
147-
"Feature x4 has value 8.39980161062865\n",
142+
"Feature x1 has value 9.135811031480424\n",
143+
"Feature x2 has value 5.0454318078238245\n",
144+
"Feature x3 has value 1.1416494951836265\n",
145+
"Feature x4 has value 0.362431844924509\n",
148146
"\n",
149-
"Features sum is 26.44811792423903\n"
147+
"Features sum is 15.685324179412383\n"
150148
]
151149
}
152150
],
@@ -378,12 +376,12 @@
378376
"name": "stdout",
379377
"output_type": "stream",
380378
"text": [
381-
"java.lang.DoubleFeature{value=6.143270159874259, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n",
382-
"java.lang.DoubleFeature{value=4.6702081567068525, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
383-
"java.lang.DoubleFeature{value=7.484002107104038, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
384-
"java.lang.DoubleFeature{value=481.34447531284985, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n",
379+
"java.lang.DoubleFeature{value=9.135811031480424, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n",
380+
"java.lang.DoubleFeature{value=6.268698018465302, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
381+
"java.lang.DoubleFeature{value=484.55755944660217, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
382+
"java.lang.DoubleFeature{value=0.362431844924509, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n",
385383
"\n",
386-
"Feature sum is 499.641955736535\n"
384+
"Feature sum is 500.3245003414724\n"
387385
]
388386
}
389387
],
@@ -472,13 +470,13 @@
472470
"name": "stdout",
473471
"output_type": "stream",
474472
"text": [
475-
"Original x1: 5.89410604979949\n",
476-
"Original x4: 8.39980161062865\n",
473+
"Original x1: 9.135811031480424\n",
474+
"Original x4: 0.362431844924509\n",
477475
"\n",
478-
"java.lang.DoubleFeature{value=5.89410604979949, intRangeMinimum=5.89410604979949, intRangeMaximum=5.89410604979949, id='x1'}\n",
479-
"java.lang.DoubleFeature{value=4.467150339062309, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
480-
"java.lang.DoubleFeature{value=480.4495023914547, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
481-
"java.lang.DoubleFeature{value=8.39980161062865, intRangeMinimum=8.39980161062865, intRangeMaximum=8.39980161062865, id='x4'}\n"
476+
"java.lang.DoubleFeature{value=9.135811031480424, intRangeMinimum=9.135811031480424, intRangeMaximum=9.135811031480424, id='x1'}\n",
477+
"java.lang.DoubleFeature{value=5.522381982853708, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
478+
"java.lang.DoubleFeature{value=484.5070395184806, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
479+
"java.lang.DoubleFeature{value=0.362431844924509, intRangeMinimum=0.362431844924509, intRangeMaximum=0.362431844924509, id='x4'}\n"
482480
]
483481
}
484482
],
@@ -659,9 +657,13 @@
659657
"\n",
660658
"def predict(inputs: List[PredictionInput]) -> List[PredictionOutput]:\n",
661659
" values = [feature.getValue().asNumber() for feature in inputs.get(0).getFeatures()]\n",
662-
" result = xg_model.predict(np.array([values]))\n",
663-
" value = False if result[0]==0 else True\n",
664-
" output = Output(\"PaidLoan\", Type.BOOLEAN, Value(value), 0.0)\n",
660+
" result = xg_model.predict_proba(np.array([values]))\n",
661+
" false_prob, true_prob = result[0]\n",
662+
" if false_prob > true_prob:\n",
663+
" prediction = (False, false_prob)\n",
664+
" else:\n",
665+
" prediction = (True, true_prob)\n",
666+
" output = Output(\"PaidLoan\", Type.BOOLEAN, Value(prediction[0]), prediction[1])\n",
665667
" return toJList([PredictionOutput([output])])"
666668
]
667669
},
@@ -697,52 +699,56 @@
697699
},
698700
{
699701
"cell_type": "code",
700-
"execution_count": 25,
702+
"execution_count": 27,
701703
"id": "06d52535",
702704
"metadata": {},
703705
"outputs": [],
704706
"source": [
705-
"features = [\n",
706-
" FeatureFactory.newBooleanFeature(\"NewCreditCustomer\", False),\n",
707-
" FeatureFactory.newNumericalFeature(\"Amount\", 2125.0),\n",
708-
" FeatureFactory.newNumericalFeature(\"Interest\", 20.97),\n",
709-
" FeatureFactory.newNumericalFeature(\"LoanDuration\", 60.0),\n",
710-
" FeatureFactory.newNumericalFeature(\"Education\", 4.0),\n",
711-
" FeatureFactory.newNumericalFeature(\"NrOfDependants\", 0.0),\n",
712-
" FeatureFactory.newNumericalFeature(\"EmploymentDurationCurrentEmployer\", 6.0),\n",
713-
" FeatureFactory.newNumericalFeature(\"IncomeFromPrincipalEmployer\", 0.0),\n",
714-
" FeatureFactory.newNumericalFeature(\"IncomeFromPension\", 301.0),\n",
715-
" FeatureFactory.newNumericalFeature(\"IncomeFromFamilyAllowance\", 0.0),\n",
716-
" FeatureFactory.newNumericalFeature(\"IncomeFromSocialWelfare\", 53.0),\n",
717-
" FeatureFactory.newNumericalFeature(\"IncomeFromLeavePay\", 0.0),\n",
718-
" FeatureFactory.newNumericalFeature(\"IncomeFromChildSupport\", 0.0),\n",
719-
" FeatureFactory.newNumericalFeature(\"IncomeOther\", 0.0),\n",
720-
" FeatureFactory.newNumericalFeature(\"ExistingLiabilities\", 8.0),\n",
721-
" FeatureFactory.newNumericalFeature(\"RefinanceLiabilities\", 6.0),\n",
722-
" FeatureFactory.newNumericalFeature(\"DebtToIncome\", 26.29),\n",
723-
" FeatureFactory.newNumericalFeature(\"FreeCash\", 10.92),\n",
724-
" FeatureFactory.newNumericalFeature(\"CreditScoreEeMini\", 1000.0),\n",
725-
" FeatureFactory.newNumericalFeature(\"NoOfPreviousLoansBeforeLoan\", 1.0),\n",
726-
" FeatureFactory.newNumericalFeature(\"AmountOfPreviousLoansBeforeLoan\", 500.0),\n",
727-
" FeatureFactory.newNumericalFeature(\"PreviousRepaymentsBeforeLoan\", 590.95),\n",
728-
" FeatureFactory.newNumericalFeature(\"PreviousEarlyRepaymentsBefoleLoan\", 0.0),\n",
729-
" FeatureFactory.newNumericalFeature(\"PreviousEarlyRepaymentsCountBeforeLoan\", 0.0),\n",
730-
" FeatureFactory.newBooleanFeature(\"Council_house\", False),\n",
731-
" FeatureFactory.newBooleanFeature(\"Homeless\", False),\n",
732-
" FeatureFactory.newBooleanFeature(\"Joint_ownership\", False),\n",
733-
" FeatureFactory.newBooleanFeature(\"Joint_tenant\", False),\n",
734-
" FeatureFactory.newBooleanFeature(\"Living_with_parents\", False),\n",
735-
" FeatureFactory.newBooleanFeature(\"Mortgage\", False),\n",
736-
" FeatureFactory.newBooleanFeature(\"Other\", False),\n",
737-
" FeatureFactory.newBooleanFeature(\"Owner\", False),\n",
738-
" FeatureFactory.newBooleanFeature(\"Owner_with_encumbrance\", True),\n",
739-
" FeatureFactory.newBooleanFeature(\"Tenant\", True),\n",
740-
" FeatureFactory.newBooleanFeature(\"Entrepreneur\", False),\n",
741-
" FeatureFactory.newBooleanFeature(\"Fully\", False),\n",
742-
" FeatureFactory.newBooleanFeature(\"Partially\", False),\n",
743-
" FeatureFactory.newBooleanFeature(\"Retiree\", True),\n",
744-
" FeatureFactory.newBooleanFeature(\"Self_employed\", False), \n",
745-
"]"
707+
"def make_feature(name, value):\n",
708+
" if type(value) is bool:\n",
709+
" return FeatureFactory.newBooleanFeature(name, value)\n",
710+
" else:\n",
711+
" return FeatureFactory.newNumericalFeature(name, value)\n",
712+
"\n",
713+
"features = [make_feature(p[0], p[1]) for p in [(\"NewCreditCustomer\", False),\n",
714+
" (\"Amount\", 2125.0),\n",
715+
" (\"Interest\", 20.97),\n",
716+
" (\"LoanDuration\", 60.0),\n",
717+
" (\"Education\", 4.0),\n",
718+
" (\"NrOfDependants\", 0.0),\n",
719+
" (\"EmploymentDurationCurrentEmployer\", 6.0),\n",
720+
" (\"IncomeFromPrincipalEmployer\", 0.0),\n",
721+
" (\"IncomeFromPension\", 301.0),\n",
722+
" (\"IncomeFromFamilyAllowance\", 0.0),\n",
723+
" (\"IncomeFromSocialWelfare\", 53.0),\n",
724+
" (\"IncomeFromLeavePay\", 0.0),\n",
725+
" (\"IncomeFromChildSupport\", 0.0),\n",
726+
" (\"IncomeOther\", 0.0),\n",
727+
" (\"ExistingLiabilities\", 8.0),\n",
728+
" (\"RefinanceLiabilities\", 6.0),\n",
729+
" (\"DebtToIncome\", 26.29),\n",
730+
" (\"FreeCash\", 10.92),\n",
731+
" (\"CreditScoreEeMini\", 1000.0),\n",
732+
" (\"NoOfPreviousLoansBeforeLoan\", 1.0),\n",
733+
" (\"AmountOfPreviousLoansBeforeLoan\", 500.0),\n",
734+
" (\"PreviousRepaymentsBeforeLoan\", 590.95),\n",
735+
" (\"PreviousEarlyRepaymentsBefoleLoan\", 0.0),\n",
736+
" (\"PreviousEarlyRepaymentsCountBeforeLoan\", 0.0),\n",
737+
" (\"Council_house\", False),\n",
738+
" (\"Homeless\", False),\n",
739+
" (\"Joint_ownership\", False),\n",
740+
" (\"Joint_tenant\", False),\n",
741+
" (\"Living_with_parents\", False),\n",
742+
" (\"Mortgage\", False),\n",
743+
" (\"Other\", False),\n",
744+
" (\"Owner\", False),\n",
745+
" (\"Owner_with_encumbrance\", True),\n",
746+
" (\"Tenant\", True),\n",
747+
" (\"Entrepreneur\", False),\n",
748+
" (\"Fully\", False),\n",
749+
" (\"Partially\", False),\n",
750+
" (\"Retiree\", True),\n",
751+
" (\"Self_employed\", False)]]"
746752
]
747753
},
748754
{
@@ -755,17 +761,17 @@
755761
},
756762
{
757763
"cell_type": "code",
758-
"execution_count": 26,
764+
"execution_count": 28,
759765
"id": "2b279cae",
760766
"metadata": {},
761767
"outputs": [
762768
{
763769
"data": {
764770
"text/plain": [
765-
"'Output{value=false, type=boolean, score=0.0, name='PaidLoan'}'"
771+
"'Output{value=false, type=boolean, score=0.7835956811904907, name='PaidLoan'}'"
766772
]
767773
},
768-
"execution_count": 26,
774+
"execution_count": 28,
769775
"metadata": {},
770776
"output_type": "execute_result"
771777
}
@@ -788,7 +794,7 @@
788794
},
789795
{
790796
"cell_type": "code",
791-
"execution_count": 28,
797+
"execution_count": 29,
792798
"id": "18fff350",
793799
"metadata": {},
794800
"outputs": [],
@@ -808,7 +814,7 @@
808814
},
809815
{
810816
"cell_type": "code",
811-
"execution_count": 29,
817+
"execution_count": 30,
812818
"id": "3a4815d9",
813819
"metadata": {},
814820
"outputs": [],
@@ -826,7 +832,7 @@
826832
},
827833
{
828834
"cell_type": "code",
829-
"execution_count": 30,
835+
"execution_count": 31,
830836
"id": "7277e246",
831837
"metadata": {},
832838
"outputs": [],
@@ -850,7 +856,7 @@
850856
},
851857
{
852858
"cell_type": "code",
853-
"execution_count": 31,
859+
"execution_count": 32,
854860
"id": "24d45182",
855861
"metadata": {},
856862
"outputs": [],
@@ -868,7 +874,7 @@
868874
},
869875
{
870876
"cell_type": "code",
871-
"execution_count": 32,
877+
"execution_count": 33,
872878
"id": "f9340354",
873879
"metadata": {},
874880
"outputs": [],
@@ -886,7 +892,7 @@
886892
},
887893
{
888894
"cell_type": "code",
889-
"execution_count": 33,
895+
"execution_count": 34,
890896
"id": "ef45bde0",
891897
"metadata": {},
892898
"outputs": [],
@@ -908,7 +914,7 @@
908914
},
909915
{
910916
"cell_type": "code",
911-
"execution_count": 34,
917+
"execution_count": 35,
912918
"id": "e9ea7928",
913919
"metadata": {},
914920
"outputs": [],
@@ -926,17 +932,17 @@
926932
},
927933
{
928934
"cell_type": "code",
929-
"execution_count": 35,
935+
"execution_count": 36,
930936
"id": "3433536d",
931937
"metadata": {},
932938
"outputs": [
933939
{
934940
"data": {
935941
"text/plain": [
936-
"'Output{value=true, type=boolean, score=0.0, name='PaidLoan'}'"
942+
"'Output{value=true, type=boolean, score=0.6006738543510437, name='PaidLoan'}'"
937943
]
938944
},
939-
"execution_count": 35,
945+
"execution_count": 36,
940946
"metadata": {},
941947
"output_type": "execute_result"
942948
}
@@ -956,7 +962,7 @@
956962
},
957963
{
958964
"cell_type": "code",
959-
"execution_count": 36,
965+
"execution_count": 37,
960966
"id": "aad1faa7",
961967
"metadata": {},
962968
"outputs": [
@@ -1008,7 +1014,7 @@
10081014
},
10091015
{
10101016
"cell_type": "code",
1011-
"execution_count": 37,
1017+
"execution_count": 38,
10121018
"id": "3cfa92da",
10131019
"metadata": {},
10141020
"outputs": [],
@@ -1070,7 +1076,7 @@
10701076
},
10711077
{
10721078
"cell_type": "code",
1073-
"execution_count": 38,
1079+
"execution_count": 39,
10741080
"id": "6e91cb5f",
10751081
"metadata": {},
10761082
"outputs": [],
@@ -1128,7 +1134,7 @@
11281134
},
11291135
{
11301136
"cell_type": "code",
1131-
"execution_count": 39,
1137+
"execution_count": 40,
11321138
"id": "0d39a0ab",
11331139
"metadata": {},
11341140
"outputs": [],
@@ -1150,7 +1156,7 @@
11501156
},
11511157
{
11521158
"cell_type": "code",
1153-
"execution_count": 40,
1159+
"execution_count": 41,
11541160
"id": "57a9b194",
11551161
"metadata": {},
11561162
"outputs": [],
@@ -1168,17 +1174,17 @@
11681174
},
11691175
{
11701176
"cell_type": "code",
1171-
"execution_count": 41,
1177+
"execution_count": 42,
11721178
"id": "f052a7f1",
11731179
"metadata": {},
11741180
"outputs": [
11751181
{
11761182
"data": {
11771183
"text/plain": [
1178-
"'Output{value=true, type=boolean, score=0.0, name='PaidLoan'}'"
1184+
"'Output{value=true, type=boolean, score=0.5038489103317261, name='PaidLoan'}'"
11791185
]
11801186
},
1181-
"execution_count": 41,
1187+
"execution_count": 42,
11821188
"metadata": {},
11831189
"output_type": "execute_result"
11841190
}
@@ -1198,7 +1204,7 @@
11981204
},
11991205
{
12001206
"cell_type": "code",
1201-
"execution_count": 42,
1207+
"execution_count": 43,
12021208
"id": "0557e545",
12031209
"metadata": {},
12041210
"outputs": [
@@ -1207,9 +1213,8 @@
12071213
"output_type": "stream",
12081214
"text": [
12091215
"Feature 'LoanDuration': 60.0 -> 56.947228037333545\n",
1210-
"Feature 'IncomeFromSocialWelfare': 53.0 -> 60.0\n",
1211-
"Feature 'FreeCash': 10.92 -> 10.914352713171315\n",
1212-
"Feature 'Tenant': true -> false\n"
1216+
"Feature 'IncomeFromSocialWelfare': 53.0 -> 59.6876474017064\n",
1217+
"Feature 'FreeCash': 10.92 -> 10.914352713171315\n"
12131218
]
12141219
}
12151220
],

0 commit comments

Comments
 (0)