1+ # pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801
2+ """Test suite for counterfactual explanations"""
13import os
24import sys
35import uuid
79
810import trustyai
911
10- trustyai .init (path = [
11- "./dep/org/kie/kogito/explainability-core/1.8.0.Final/*" ,
12- "./dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar" ,
13- "./dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar" ,
14- "./dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar" ,
15- "./dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar" ,
16- "./dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar" ,
17- "./dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar" ,
18- ])
19-
20- from trustyai .local .counterfactual import CounterfactualExplainer , CounterfactualConfigurationFactory
12+ trustyai .init (
13+ path = [
14+ "./dep/org/kie/kogito/explainability-core/1.8.0.Final/*" ,
15+ "./dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar" ,
16+ "./dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar" ,
17+ "./dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar" ,
18+ "./dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar" ,
19+ "./dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar" ,
20+ "./dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar" ,
21+ ]
22+ )
23+
24+ from trustyai .local .counterfactual import (
25+ CounterfactualExplainer ,
26+ CounterfactualConfigurationFactory ,
27+ )
28+ from trustyai .utils import TestUtils , Config
2129from trustyai .model .domain import NumericalFeatureDomain
2230from trustyai .model import (
2331 CounterfactualPrediction ,
3240)
3341from java .util import Random
3442from java .lang import Long
35- from trustyai . utils import TestUtils , Config
43+
3644from org .optaplanner .core .config .solver .termination import TerminationConfig
3745
3846jrandom = Random ()
3947jrandom .setSeed (0 )
4048
4149
42- def mockFeature (d ):
43- return FeatureFactory .newNumericalFeature ("f-num" , d )
44-
45-
46- def sumSkipModel (inputs ):
47- """SumSkip test model"""
48- prediction_outputs = []
49- for predictionInput in inputs :
50- features = predictionInput .getFeatures ()
51- result = 0.0
52- for i in range (features .size ()):
53- if i != 0 :
54- result += features .get (i ).getValue ().asNumber ()
55- o = [Output (f"sum-but0" , Type .NUMBER , Value (result ), 1.0 )]
56- prediction_output = PredictionOutput (o )
57- prediction_outputs .append (prediction_output )
58- return prediction_outputs
59-
60-
61- def runCounterfactualSearch (goal ,
62- constraints ,
63- dataDomain ,
64- features ,
65- model ):
66- terminationConfig = TerminationConfig ().withScoreCalculationCountLimit (Long .valueOf (10_000 ))
67- solverConfig = CounterfactualConfigurationFactory \
68- .builder ().withTerminationConfig (terminationConfig ).build ()
69-
70- explainer = CounterfactualExplainer \
71- .builder () \
72- .withSolverConfig (solverConfig ) \
50+ def run_counterfactual_search (goal , constraints , data_domain , features , model ):
51+ """Creates a CF explainer and returns a result"""
52+ termination_config = TerminationConfig ().withScoreCalculationCountLimit (
53+ Long .valueOf (10_000 )
54+ )
55+ solver_config = (
56+ CounterfactualConfigurationFactory .builder ()
57+ .withTerminationConfig (termination_config )
7358 .build ()
74- input = PredictionInput (features )
59+ )
60+
61+ explainer = (
62+ CounterfactualExplainer .builder ().withSolverConfig (solver_config ).build ()
63+ )
64+ input_ = PredictionInput (features )
7565 output = PredictionOutput (goal )
76- domain = PredictionFeatureDomain (dataDomain .getFeatureDomains ())
77- prediction = CounterfactualPrediction (input , output , domain , constraints , None , uuid .uuid4 ())
78- return explainer .explainAsync (prediction , model ) \
79- .get (Config .INSTANCE .getAsyncTimeout (), Config .INSTANCE .getAsyncTimeUnit ())
66+ domain = PredictionFeatureDomain (data_domain .getFeatureDomains ())
67+ prediction = CounterfactualPrediction (
68+ input_ , output , domain , constraints , None , uuid .uuid4 ()
69+ )
70+ return explainer .explainAsync (prediction , model ).get (
71+ Config .INSTANCE .getAsyncTimeout (), Config .INSTANCE .getAsyncTimeUnit ()
72+ )
8073
8174
82- def testNonEmptyInput ():
75+ def test_non_empty_input ():
8376 """Checks whether the returned CF entities are not null"""
84- termination_config = TerminationConfig ().withScoreCalculationCountLimit (Long .valueOf (1000 ))
85- solver_config = CounterfactualConfigurationFactory .builder ().withTerminationConfig (termination_config ).build ()
77+ termination_config = TerminationConfig ().withScoreCalculationCountLimit (
78+ Long .valueOf (1000 )
79+ )
80+ solver_config = (
81+ CounterfactualConfigurationFactory .builder ()
82+ .withTerminationConfig (termination_config )
83+ .build ()
84+ )
8685 n_features = 10
87- explainer = CounterfactualExplainer .builder ().withSolverConfig (solver_config ).build ()
88- goal = [Output (f"f-num{ i + 1 } " , Type .NUMBER , Value (10.0 ), 0.0 ) for i in range (n_features - 1 )]
89- features = [FeatureFactory .newNumericalFeature (f"f-num{ i } " , i * 2.0 ) for i in range (n_features )]
86+ explainer = (
87+ CounterfactualExplainer .builder ().withSolverConfig (solver_config ).build ()
88+ )
89+ goal = [
90+ Output (f"f-num{ i + 1 } " , Type .NUMBER , Value (10.0 ), 0.0 )
91+ for i in range (n_features - 1 )
92+ ]
93+ features = [
94+ FeatureFactory .newNumericalFeature (f"f-num{ i } " , i * 2.0 )
95+ for i in range (n_features )
96+ ]
9097 constraints = [False ] * n_features
9198 feature_boundaries = [NumericalFeatureDomain .create (0.0 , 1000.0 )] * n_features
9299
93100 model = TestUtils .getSumSkipModel (0 )
94101 _input = PredictionInput (features )
95102 output = PredictionOutput (goal )
96- prediction = CounterfactualPrediction (_input , output , PredictionFeatureDomain (feature_boundaries ), constraints ,
97- None ,
98- uuid .uuid4 ())
103+ prediction = CounterfactualPrediction (
104+ _input ,
105+ output ,
106+ PredictionFeatureDomain (feature_boundaries ),
107+ constraints ,
108+ None ,
109+ uuid .uuid4 (),
110+ )
99111
100112 counterfactual_result = explainer .explainAsync (prediction , model ).get ()
101113 for entity in counterfactual_result .getEntities ():
102114 print (entity )
103115 assert entity is not None
104116
105117
106- def testCounterfactualMatch ():
118+ def test_counterfactual_match ():
119+ """Test if there's a valid counterfactual"""
107120 goal = [Output ("inside" , Type .BOOLEAN , Value (True ), 0.0 )]
108121
109- features = [FeatureFactory .newNumericalFeature (f"f-num{ i + 1 } " , 10.0 ) for i in range (4 )]
122+ features = [
123+ FeatureFactory .newNumericalFeature (f"f-num{ i + 1 } " , 10.0 ) for i in range (4 )
124+ ]
110125 constraints = [False ] * 4
111126 feature_boundaries = [NumericalFeatureDomain .create (0.0 , 1000.0 )] * 4
112127
@@ -115,11 +130,13 @@ def testCounterfactualMatch():
115130 center = 500.0
116131 epsilon = 10.0
117132
118- result = \
119- runCounterfactualSearch (goal ,
120- constraints ,
121- data_domain , features ,
122- TestUtils .getSumThresholdModel (center , epsilon ))
133+ result = run_counterfactual_search (
134+ goal ,
135+ constraints ,
136+ data_domain ,
137+ features ,
138+ TestUtils .getSumThresholdModel (center , epsilon ),
139+ )
123140
124141 total_sum = 0
125142 for entity in result .getEntities ():
0 commit comments