Skip to content

Commit 9d34f02

Browse files
committed
Fix CI errors
1 parent c83f33c commit 9d34f02

File tree

10 files changed

+232
-181
lines changed

10 files changed

+232
-181
lines changed

.github/workflows/workflow.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
./deps.sh
3535
- name: Lint
3636
run: |
37-
pylint $(find trustyai -type f -name "*.py")
37+
pylint --ignore-imports=yes $(find trustyai -type f -name "*.py")
3838
- name: Test with pytest
3939
run: |
4040
pytest

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[![Tests](https://github.com/ruivieira/python-trustyai/actions/workflows/workflow.yml/badge.svg)](https://github.com/ruivieira/python-trustyai/actions/workflows/workflow.yml)
2-
32
# python-trustyai
43

54
Python bindings to [TrustyAI](https://kogito.kie.org/trustyai/)'s explainability library.

tests/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# pylint: disable=R0801
2+
"""Common methods and models for tests"""
3+
from trustyai.model import (
4+
FeatureFactory,
5+
Output,
6+
PredictionOutput,
7+
Type,
8+
Value,
9+
)
10+
11+
12+
def mock_feature(value):
13+
"""Create a mock numerical feature"""
14+
return FeatureFactory.newNumericalFeature("f-num", value)
15+
16+
17+
def sum_skip_model(inputs):
18+
"""SumSkip test model"""
19+
prediction_outputs = []
20+
for prediction_input in inputs:
21+
features = prediction_input.getFeatures()
22+
result = 0.0
23+
for i in range(features.size()):
24+
if i != 0:
25+
result += features.get(i).getValue().asNumber()
26+
output = [Output("sum-but0", Type.NUMBER, Value(result), 1.0)]
27+
prediction_output = PredictionOutput(output)
28+
prediction_outputs.append(prediction_output)
29+
return prediction_outputs

tests/test_counterfactualexplainer.py

Lines changed: 81 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801
2+
"""Test suite for counterfactual explanations"""
13
import os
24
import sys
35
import uuid
@@ -7,17 +9,23 @@
79

810
import trustyai
911

10-
trustyai.init(path=[
11-
"./dep/org/kie/kogito/explainability-core/1.8.0.Final/*",
12-
"./dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar",
13-
"./dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar",
14-
"./dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar",
15-
"./dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar",
16-
"./dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar",
17-
"./dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar",
18-
])
19-
20-
from trustyai.local.counterfactual import CounterfactualExplainer, CounterfactualConfigurationFactory
12+
trustyai.init(
13+
path=[
14+
"./dep/org/kie/kogito/explainability-core/1.8.0.Final/*",
15+
"./dep/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar",
16+
"./dep/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar",
17+
"./dep/org/optaplanner/optaplanner-core/8.8.0.Final/optaplanner-core-8.8.0.Final.jar",
18+
"./dep/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar",
19+
"./dep/org/kie/kie-api/7.55.0.Final/kie-api-7.55.0.Final.jar",
20+
"./dep/io/micrometer/micrometer-core/1.6.6/micrometer-core-1.6.6.jar",
21+
]
22+
)
23+
24+
from trustyai.local.counterfactual import (
25+
CounterfactualExplainer,
26+
CounterfactualConfigurationFactory,
27+
)
28+
from trustyai.utils import TestUtils, Config
2129
from trustyai.model.domain import NumericalFeatureDomain
2230
from trustyai.model import (
2331
CounterfactualPrediction,
@@ -32,81 +40,88 @@
3240
)
3341
from java.util import Random
3442
from java.lang import Long
35-
from trustyai.utils import TestUtils, Config
43+
3644
from org.optaplanner.core.config.solver.termination import TerminationConfig
3745

3846
jrandom = Random()
3947
jrandom.setSeed(0)
4048

4149

42-
def mockFeature(d):
43-
return FeatureFactory.newNumericalFeature("f-num", d)
44-
45-
46-
def sumSkipModel(inputs):
47-
"""SumSkip test model"""
48-
prediction_outputs = []
49-
for predictionInput in inputs:
50-
features = predictionInput.getFeatures()
51-
result = 0.0
52-
for i in range(features.size()):
53-
if i != 0:
54-
result += features.get(i).getValue().asNumber()
55-
o = [Output(f"sum-but0", Type.NUMBER, Value(result), 1.0)]
56-
prediction_output = PredictionOutput(o)
57-
prediction_outputs.append(prediction_output)
58-
return prediction_outputs
59-
60-
61-
def runCounterfactualSearch(goal,
62-
constraints,
63-
dataDomain,
64-
features,
65-
model):
66-
terminationConfig = TerminationConfig().withScoreCalculationCountLimit(Long.valueOf(10_000))
67-
solverConfig = CounterfactualConfigurationFactory \
68-
.builder().withTerminationConfig(terminationConfig).build()
69-
70-
explainer = CounterfactualExplainer \
71-
.builder() \
72-
.withSolverConfig(solverConfig) \
50+
def run_counterfactual_search(goal, constraints, data_domain, features, model):
51+
"""Creates a CF explainer and returns a result"""
52+
termination_config = TerminationConfig().withScoreCalculationCountLimit(
53+
Long.valueOf(10_000)
54+
)
55+
solver_config = (
56+
CounterfactualConfigurationFactory.builder()
57+
.withTerminationConfig(termination_config)
7358
.build()
74-
input = PredictionInput(features)
59+
)
60+
61+
explainer = (
62+
CounterfactualExplainer.builder().withSolverConfig(solver_config).build()
63+
)
64+
input_ = PredictionInput(features)
7565
output = PredictionOutput(goal)
76-
domain = PredictionFeatureDomain(dataDomain.getFeatureDomains())
77-
prediction = CounterfactualPrediction(input, output, domain, constraints, None, uuid.uuid4())
78-
return explainer.explainAsync(prediction, model) \
79-
.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
66+
domain = PredictionFeatureDomain(data_domain.getFeatureDomains())
67+
prediction = CounterfactualPrediction(
68+
input_, output, domain, constraints, None, uuid.uuid4()
69+
)
70+
return explainer.explainAsync(prediction, model).get(
71+
Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()
72+
)
8073

8174

82-
def testNonEmptyInput():
75+
def test_non_empty_input():
8376
"""Checks whether the returned CF entities are not null"""
84-
termination_config = TerminationConfig().withScoreCalculationCountLimit(Long.valueOf(1000))
85-
solver_config = CounterfactualConfigurationFactory.builder().withTerminationConfig(termination_config).build()
77+
termination_config = TerminationConfig().withScoreCalculationCountLimit(
78+
Long.valueOf(1000)
79+
)
80+
solver_config = (
81+
CounterfactualConfigurationFactory.builder()
82+
.withTerminationConfig(termination_config)
83+
.build()
84+
)
8685
n_features = 10
87-
explainer = CounterfactualExplainer.builder().withSolverConfig(solver_config).build()
88-
goal = [Output(f"f-num{i + 1}", Type.NUMBER, Value(10.0), 0.0) for i in range(n_features - 1)]
89-
features = [FeatureFactory.newNumericalFeature(f"f-num{i}", i * 2.0) for i in range(n_features)]
86+
explainer = (
87+
CounterfactualExplainer.builder().withSolverConfig(solver_config).build()
88+
)
89+
goal = [
90+
Output(f"f-num{i + 1}", Type.NUMBER, Value(10.0), 0.0)
91+
for i in range(n_features - 1)
92+
]
93+
features = [
94+
FeatureFactory.newNumericalFeature(f"f-num{i}", i * 2.0)
95+
for i in range(n_features)
96+
]
9097
constraints = [False] * n_features
9198
feature_boundaries = [NumericalFeatureDomain.create(0.0, 1000.0)] * n_features
9299

93100
model = TestUtils.getSumSkipModel(0)
94101
_input = PredictionInput(features)
95102
output = PredictionOutput(goal)
96-
prediction = CounterfactualPrediction(_input, output, PredictionFeatureDomain(feature_boundaries), constraints,
97-
None,
98-
uuid.uuid4())
103+
prediction = CounterfactualPrediction(
104+
_input,
105+
output,
106+
PredictionFeatureDomain(feature_boundaries),
107+
constraints,
108+
None,
109+
uuid.uuid4(),
110+
)
99111

100112
counterfactual_result = explainer.explainAsync(prediction, model).get()
101113
for entity in counterfactual_result.getEntities():
102114
print(entity)
103115
assert entity is not None
104116

105117

106-
def testCounterfactualMatch():
118+
def test_counterfactual_match():
119+
"""Test if there's a valid counterfactual"""
107120
goal = [Output("inside", Type.BOOLEAN, Value(True), 0.0)]
108121

109-
features = [FeatureFactory.newNumericalFeature(f"f-num{i+1}", 10.0) for i in range(4)]
122+
features = [
123+
FeatureFactory.newNumericalFeature(f"f-num{i+1}", 10.0) for i in range(4)
124+
]
110125
constraints = [False] * 4
111126
feature_boundaries = [NumericalFeatureDomain.create(0.0, 1000.0)] * 4
112127

@@ -115,11 +130,13 @@ def testCounterfactualMatch():
115130
center = 500.0
116131
epsilon = 10.0
117132

118-
result = \
119-
runCounterfactualSearch(goal,
120-
constraints,
121-
data_domain, features,
122-
TestUtils.getSumThresholdModel(center, epsilon))
133+
result = run_counterfactual_search(
134+
goal,
135+
constraints,
136+
data_domain,
137+
features,
138+
TestUtils.getSumThresholdModel(center, epsilon),
139+
)
123140

124141
total_sum = 0
125142
for entity in result.getEntities():

tests/test_datautils.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import sys, os
2-
import pytest
1+
# pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name
2+
"""Data utils test suite"""
3+
import sys
4+
import os
35
from pytest import approx
4-
import math
56
import random
67

78
myPath = os.path.dirname(os.path.abspath(__file__))
@@ -12,23 +13,26 @@
1213
trustyai.init()
1314

1415
from trustyai.utils import DataUtils
15-
from trustyai.model import PerturbationContext, Feature, FeatureFactory
16+
from trustyai.model import PerturbationContext, FeatureFactory
1617
from java.util import Random
1718

1819
jrandom = Random()
1920

2021

21-
def test_GetMean():
22+
def test_get_mean():
23+
"""Test GetMean"""
2224
data = [2, 4, 3, 5, 1]
2325
assert DataUtils.getMean(data) == approx(3, 1e-6)
2426

2527

26-
def test_GetStdDev():
28+
def test_get_std_dev():
29+
"""Test GetStdDev"""
2730
data = [2, 4, 3, 5, 1]
2831
assert DataUtils.getStdDev(data, 3) == approx(1.41, 1e-2)
2932

3033

31-
def test_GaussianKernel():
34+
def test_gaussian_kernel():
35+
"""Test Gaussian Kernel"""
3236
x = 0.0
3337
k = DataUtils.gaussianKernel(x, 0, 1)
3438
assert k == approx(0.398, 1e-2)
@@ -37,28 +41,32 @@ def test_GaussianKernel():
3741
assert k == approx(0.389, 1e-2)
3842

3943

40-
def test_EuclideanDistance():
44+
def test_euclidean_distance():
45+
"""Test Euclidean distance"""
4146
x = [1, 1]
4247
y = [2, 3]
4348
distance = DataUtils.euclideanDistance(x, y)
44-
assert 2.236 == approx(distance, 1e-3)
49+
assert approx(distance, 1e-3) == 2.236
4550

4651

47-
def test_HammingDistanceDouble():
52+
def test_hamming_distance_double():
53+
"""Test Hamming distance for doubles"""
4854
x = [2, 1]
4955
y = [2, 3]
5056
distance = DataUtils.hammingDistance(x, y)
5157
assert distance == approx(1, 1e-1)
5258

5359

54-
def test_HammingDistanceString():
60+
def test_hamming_distance_string():
61+
"""Test Hamming distance for strings"""
5562
x = "test1"
5663
y = "test2"
5764
distance = DataUtils.hammingDistance(x, y)
5865
assert distance == approx(1, 1e-1)
5966

6067

61-
def test_DoublesToFeatures():
68+
def test_doubles_to_features():
69+
"""Test doubles to features"""
6270
inputs = [1 if i % 2 == 0 else 0 for i in range(10)]
6371
features = DataUtils.doublesToFeatures(inputs)
6472
assert features is not None
@@ -69,36 +77,41 @@ def test_DoublesToFeatures():
6977
assert f.getValue() is not None
7078

7179

72-
def test_ExponentialSmoothingKernel():
80+
def test_exponential_smoothing_kernel():
81+
"""Test exponential smoothing kernel"""
7382
x = 0.218
7483
k = DataUtils.exponentialSmoothingKernel(x, 2)
7584
assert k == approx(0.994, 1e-3)
7685

7786

78-
def test_PerturbFeaturesEmpty():
87+
def test_perturb_features_empty():
88+
"""Test perturb empty features"""
7989
features = []
8090
perturbationContext = PerturbationContext(jrandom, 0)
8191
newFeatures = DataUtils.perturbFeatures(features, perturbationContext)
8292
assert newFeatures is not None
8393
assert len(features) == newFeatures.size()
8494

8595

86-
def testRandomDistributionGeneration():
96+
def test_random_distribution_generation():
97+
"""Test random distribution generation"""
8798
dataDistribution = DataUtils.generateRandomDataDistribution(10, 10, jrandom)
8899
assert dataDistribution is not None
89100
assert dataDistribution.asFeatureDistributions() is not None
90101
for featureDistribution in dataDistribution.asFeatureDistributions():
91102
assert featureDistribution is not None
92103

93104

94-
def testLinearizedNumericFeatures():
105+
def test_linearized_numeric_features():
106+
"""Test linearised numeric features"""
95107
f = FeatureFactory.newNumericalFeature("f-num", 1.0)
96108
features = [f]
97109
linearizedFeatures = DataUtils.getLinearizedFeatures(features)
98110
assert len(features) == linearizedFeatures.size()
99111

100112

101-
def testSampleWithReplacement():
113+
def test_sample_with_replacement():
114+
"""Test sample with replacement"""
102115
emptyValues = []
103116
emptySamples = DataUtils.sampleWithReplacement(emptyValues, 1, jrandom)
104117
assert emptySamples is not None

0 commit comments

Comments
 (0)