Skip to content

Commit 7578393

Browse files
committed
Fix CF notebook
1 parent 9c03d78 commit 7578393

File tree

1 file changed

+141
-44
lines changed

1 file changed

+141
-44
lines changed

notebooks/Counterfactuals.ipynb

Lines changed: 141 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
"\n",
2020
"trustyai.init(\n",
2121
" path=[\n",
22-
" \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/*\",\n",
22+
" \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/*\",\n",
23+
"# \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/explainability-core-2.0.0-SNAPSHOT.jar\",\n",
24+
"# \"../dep/org/kie/kogito/explainability-core/1.8.0.Final/explainability-core-1.8.0.Final-tests.jar\",\n",
2325
" \"../dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar\",\n",
2426
" \"../dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar\",\n",
2527
" \"../dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar\",\n",
@@ -139,12 +141,12 @@
139141
"name": "stdout",
140142
"output_type": "stream",
141143
"text": [
142-
"Feature x1 has value 5.154409874945087\n",
143-
"Feature x2 has value 1.730733062926768\n",
144-
"Feature x3 has value 9.000708957897421\n",
145-
"Feature x4 has value 4.539576608157114\n",
144+
"Feature x1 has value 1.4465457485394606\n",
145+
"Feature x2 has value 9.904958794943276\n",
146+
"Feature x3 has value 8.632408661102822\n",
147+
"Feature x4 has value 2.666374576834393\n",
146148
"\n",
147-
"Features sum is 20.42542850392639\n"
149+
"Features sum is 22.65028778141995\n"
148150
]
149151
}
150152
],
@@ -376,12 +378,12 @@
376378
"name": "stdout",
377379
"output_type": "stream",
378380
"text": [
379-
"java.lang.DoubleFeature{value=5.154409874945087, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n",
380-
"java.lang.DoubleFeature{value=480.69885432449246, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
381-
"java.lang.DoubleFeature{value=9.000708957897421, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
382-
"java.lang.DoubleFeature{value=4.2932049776029935, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n",
381+
"java.lang.DoubleFeature{value=1.547375925562755, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x1'}\n",
382+
"java.lang.DoubleFeature{value=486.3570827969562, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
383+
"java.lang.DoubleFeature{value=8.632408661102822, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
384+
"java.lang.DoubleFeature{value=2.666374576834393, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x4'}\n",
383385
"\n",
384-
"Feature sum is 499.147178134938\n"
386+
"Feature sum is 499.2032419604562\n"
385387
]
386388
}
387389
],
@@ -470,13 +472,13 @@
470472
"name": "stdout",
471473
"output_type": "stream",
472474
"text": [
473-
"Original x1: 5.154409874945087\n",
474-
"Original x4: 4.539576608157114\n",
475+
"Original x1: 1.4465457485394606\n",
476+
"Original x4: 2.666374576834393\n",
475477
"\n",
476-
"java.lang.DoubleFeature{value=5.154409874945087, intRangeMinimum=5.154409874945087, intRangeMaximum=5.154409874945087, id='x1'}\n",
477-
"java.lang.DoubleFeature{value=480.69885432449246, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
478-
"java.lang.DoubleFeature{value=9.291426877845232, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
479-
"java.lang.DoubleFeature{value=4.539576608157114, intRangeMinimum=4.539576608157114, intRangeMaximum=4.539576608157114, id='x4'}\n"
478+
"java.lang.DoubleFeature{value=1.4465457485394606, intRangeMinimum=1.4465457485394606, intRangeMaximum=1.4465457485394606, id='x1'}\n",
479+
"java.lang.DoubleFeature{value=486.3570827969562, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x2'}\n",
480+
"java.lang.DoubleFeature{value=8.538028490590378, intRangeMinimum=0.0, intRangeMaximum=1000.0, id='x3'}\n",
481+
"java.lang.DoubleFeature{value=2.666374576834393, intRangeMinimum=2.666374576834393, intRangeMaximum=2.666374576834393, id='x4'}\n"
480482
]
481483
}
482484
],
@@ -511,7 +513,7 @@
511513
},
512514
{
513515
"cell_type": "code",
514-
"execution_count": 57,
516+
"execution_count": 20,
515517
"id": "3f64510a",
516518
"metadata": {},
517519
"outputs": [
@@ -588,7 +590,7 @@
588590
},
589591
{
590592
"cell_type": "code",
591-
"execution_count": 58,
593+
"execution_count": 21,
592594
"id": "603d909e",
593595
"metadata": {},
594596
"outputs": [],
@@ -598,7 +600,7 @@
598600
},
599601
{
600602
"cell_type": "code",
601-
"execution_count": 59,
603+
"execution_count": 22,
602604
"id": "8baeb746",
603605
"metadata": {},
604606
"outputs": [
@@ -620,22 +622,25 @@
620622
},
621623
{
622624
"cell_type": "code",
623-
"execution_count": 75,
625+
"execution_count": 23,
624626
"id": "2080ef92",
625627
"metadata": {},
626628
"outputs": [],
627629
"source": [
628-
"def predict(inputs):\n",
629-
" values = [feature.getValue().asNumber() for feature in inputs.getFeatures()]\n",
630+
"from typing import List\n",
631+
"from trustyai.utils import toJList\n",
632+
"\n",
633+
"def predict(inputs: List[PredictionInput]) -> List[PredictionOutput]:\n",
634+
" values = [feature.getValue().asNumber() for feature in inputs.get(0).getFeatures()]\n",
630635
" result = xg_model.predict(np.array([values]))\n",
631636
" value = False if result[0]==0 else True\n",
632637
" output = Output(\"PaidLoan\", Type.BOOLEAN, Value(value), 0.0)\n",
633-
" return PredictionOutput([output])"
638+
" return toJList([PredictionOutput([output])])"
634639
]
635640
},
636641
{
637642
"cell_type": "code",
638-
"execution_count": 76,
643+
"execution_count": 24,
639644
"id": "00347c63",
640645
"metadata": {},
641646
"outputs": [],
@@ -647,7 +652,7 @@
647652
},
648653
{
649654
"cell_type": "code",
650-
"execution_count": 77,
655+
"execution_count": 25,
651656
"id": "06d52535",
652657
"metadata": {},
653658
"outputs": [],
@@ -697,7 +702,7 @@
697702
},
698703
{
699704
"cell_type": "code",
700-
"execution_count": 78,
705+
"execution_count": 26,
701706
"id": "fa3e099b",
702707
"metadata": {},
703708
"outputs": [
@@ -716,25 +721,30 @@
716721
},
717722
{
718723
"cell_type": "code",
719-
"execution_count": 79,
724+
"execution_count": 27,
720725
"id": "2b279cae",
721726
"metadata": {},
722727
"outputs": [
723728
{
724-
"name": "stdout",
725-
"output_type": "stream",
726-
"text": [
727-
"Output{value=false, type=boolean, score=0.0, name='PaidLoan'}\n"
728-
]
729+
"data": {
730+
"text/plain": [
731+
"'Output{value=false, type=boolean, score=0.0, name='PaidLoan'}'"
732+
]
733+
},
734+
"execution_count": 27,
735+
"metadata": {},
736+
"output_type": "execute_result"
729737
}
730738
],
731739
"source": [
732-
"print(model.predictAsync(PredictionInput(features)).get().getOutputs()[0])"
740+
"from trustyai.utils import toJList\n",
741+
"\n",
742+
"model.predictAsync(toJList([PredictionInput(features)])).get()[0].getOutputs()[0].toString()"
733743
]
734744
},
735745
{
736746
"cell_type": "code",
737-
"execution_count": 80,
747+
"execution_count": 28,
738748
"id": "18fff350",
739749
"metadata": {},
740750
"outputs": [],
@@ -744,7 +754,7 @@
744754
},
745755
{
746756
"cell_type": "code",
747-
"execution_count": 81,
757+
"execution_count": 29,
748758
"id": "3a4815d9",
749759
"metadata": {},
750760
"outputs": [],
@@ -754,12 +764,12 @@
754764
},
755765
{
756766
"cell_type": "code",
757-
"execution_count": 85,
767+
"execution_count": 30,
758768
"id": "7277e246",
759769
"metadata": {},
760770
"outputs": [],
761771
"source": [
762-
"termination_config = TerminationConfig().withBestScoreFeasible(True).withSecondsSpentLimit(Long.valueOf(10))\n",
772+
"termination_config = TerminationConfig().withSecondsSpentLimit(Long.valueOf(20))\n",
763773
"\n",
764774
"solver_config = (\n",
765775
" CounterfactualConfigurationFactory.builder()\n",
@@ -770,7 +780,7 @@
770780
},
771781
{
772782
"cell_type": "code",
773-
"execution_count": 86,
783+
"execution_count": 31,
774784
"id": "338e61e9",
775785
"metadata": {},
776786
"outputs": [],
@@ -780,7 +790,7 @@
780790
},
781791
{
782792
"cell_type": "code",
783-
"execution_count": 87,
793+
"execution_count": 32,
784794
"id": "f9340354",
785795
"metadata": {},
786796
"outputs": [],
@@ -790,7 +800,7 @@
790800
},
791801
{
792802
"cell_type": "code",
793-
"execution_count": 88,
803+
"execution_count": 33,
794804
"id": "ef45bde0",
795805
"metadata": {},
796806
"outputs": [],
@@ -802,7 +812,7 @@
802812
},
803813
{
804814
"cell_type": "code",
805-
"execution_count": 89,
815+
"execution_count": 34,
806816
"id": "5884e49c",
807817
"metadata": {},
808818
"outputs": [],
@@ -812,13 +822,100 @@
812822
},
813823
{
814824
"cell_type": "code",
815-
"execution_count": 90,
825+
"execution_count": 37,
816826
"id": "e9ea7928",
817827
"metadata": {},
818828
"outputs": [],
819829
"source": [
820-
"explanation = explainer.explainAsync(prediction, model)"
830+
"explanation = explainer.explainAsync(prediction, model).get()"
831+
]
832+
},
833+
{
834+
"cell_type": "code",
835+
"execution_count": 38,
836+
"id": "852c395c",
837+
"metadata": {},
838+
"outputs": [
839+
{
840+
"name": "stdout",
841+
"output_type": "stream",
842+
"text": [
843+
"java.lang.BooleanFeature{value=false, id='NewCreditCustomer'}\n",
844+
"java.lang.DoubleFeature{value=2125.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='Amount'}\n",
845+
"java.lang.DoubleFeature{value=20.97, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='Interest'}\n",
846+
"java.lang.DoubleFeature{value=60.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='LoanDuration'}\n",
847+
"java.lang.DoubleFeature{value=4.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='Education'}\n",
848+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='NrOfDependants'}\n",
849+
"java.lang.DoubleFeature{value=6.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='EmploymentDurationCurrentEmployer'}\n",
850+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromPrincipalEmployer'}\n",
851+
"java.lang.DoubleFeature{value=301.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromPension'}\n",
852+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromFamilyAllowance'}\n",
853+
"java.lang.DoubleFeature{value=53.31125429433703, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromSocialWelfare'}\n",
854+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromLeavePay'}\n",
855+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeFromChildSupport'}\n",
856+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='IncomeOther'}\n",
857+
"java.lang.DoubleFeature{value=8.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='ExistingLiabilities'}\n",
858+
"java.lang.DoubleFeature{value=1.230474777192958, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='RefinanceLiabilities'}\n",
859+
"java.lang.DoubleFeature{value=26.29, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='DebtToIncome'}\n",
860+
"java.lang.DoubleFeature{value=10.92, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='FreeCash'}\n",
861+
"java.lang.DoubleFeature{value=1000.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='CreditScoreEeMini'}\n",
862+
"java.lang.DoubleFeature{value=1.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='NoOfPreviousLoansBeforeLoan'}\n",
863+
"java.lang.DoubleFeature{value=500.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='AmountOfPreviousLoansBeforeLoan'}\n",
864+
"java.lang.DoubleFeature{value=590.95, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='PreviousRepaymentsBeforeLoan'}\n",
865+
"java.lang.DoubleFeature{value=0.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='PreviousEarlyRepaymentsBefoleLoan'}\n",
866+
"java.lang.DoubleFeature{value=6.0, intRangeMinimum=0.0, intRangeMaximum=10000.0, id='PreviousEarlyRepaymentsCountBeforeLoan'}\n",
867+
"java.lang.BooleanFeature{value=false, id='Council_house'}\n",
868+
"java.lang.BooleanFeature{value=false, id='Homeless'}\n",
869+
"java.lang.BooleanFeature{value=false, id='Joint_ownership'}\n",
870+
"java.lang.BooleanFeature{value=false, id='Joint_tenant'}\n",
871+
"java.lang.BooleanFeature{value=false, id='Living_with_parents'}\n",
872+
"java.lang.BooleanFeature{value=false, id='Mortgage'}\n",
873+
"java.lang.BooleanFeature{value=false, id='Other'}\n",
874+
"java.lang.BooleanFeature{value=true, id='Owner'}\n",
875+
"java.lang.BooleanFeature{value=false, id='Owner_with_encumbrance'}\n",
876+
"java.lang.BooleanFeature{value=true, id='Tenant'}\n",
877+
"java.lang.BooleanFeature{value=false, id='Entrepreneur'}\n",
878+
"java.lang.BooleanFeature{value=false, id='Fully'}\n",
879+
"java.lang.BooleanFeature{value=false, id='Partially'}\n",
880+
"java.lang.BooleanFeature{value=true, id='Retiree'}\n",
881+
"java.lang.BooleanFeature{value=false, id='Self_employed'}\n"
882+
]
883+
}
884+
],
885+
"source": [
886+
"for entity in explanation.getEntities():\n",
887+
" print(entity)"
821888
]
889+
},
890+
{
891+
"cell_type": "code",
892+
"execution_count": 42,
893+
"id": "cad93dea",
894+
"metadata": {},
895+
"outputs": [
896+
{
897+
"data": {
898+
"text/plain": [
899+
"'Output{value=true, type=boolean, score=0.0, name='PaidLoan'}'"
900+
]
901+
},
902+
"execution_count": 42,
903+
"metadata": {},
904+
"output_type": "execute_result"
905+
}
906+
],
907+
"source": [
908+
"testf = [f.asFeature() for f in explanation.getEntities()]\n",
909+
"model.predictAsync(toJList([PredictionInput(testf)])).get()[0].getOutputs()[0].toString()"
910+
]
911+
},
912+
{
913+
"cell_type": "code",
914+
"execution_count": null,
915+
"id": "927fe8f9",
916+
"metadata": {},
917+
"outputs": [],
918+
"source": []
822919
}
823920
],
824921
"metadata": {

0 commit comments

Comments
 (0)