Skip to content

Commit 644ca8f

Browse files
authored
Merge pull request #47 from ruivieira/deps/1.18
Update to new 1.18 Feature API
2 parents 407ba3d + b59edd1 commit 644ca8f

File tree

6 files changed

+65
-82
lines changed

6 files changed

+65
-82
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
![version](https://img.shields.io/badge/version-0.0.9-green) ![TrustyAI](https://img.shields.io/badge/TrustyAI-1.17-green) [![Tests](https://github.com/trustyai-python/module/actions/workflows/workflow.yml/badge.svg)](https://github.com/trustyai-python/examples/actions/workflows/workflow.yml)
1+
![version](https://img.shields.io/badge/version-0.0.9-green) ![TrustyAI](https://img.shields.io/badge/TrustyAI-1.18-green) [![Tests](https://github.com/trustyai-python/module/actions/workflows/workflow.yml/badge.svg)](https://github.com/trustyai-python/examples/actions/workflows/workflow.yml)
22

33
# python-trustyai
44

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
with open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
88
long_description = f.read()
99

10-
TRUSTY_VERSION = "1.17.0.Final"
10+
TRUSTY_VERSION = "1.18.0.Final"
1111

1212

1313
class PostInstall(install):
@@ -27,7 +27,7 @@ def run(self):
2727

2828
setup(
2929
name="trustyai",
30-
version="0.0.9",
30+
version="0.1.0",
3131
description="Python bindings to the TrustyAI explainability library",
3232
long_description=long_description,
3333
long_description_content_type="text/markdown",

tests/test_conversions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,21 @@ def test_feature_function():
7272
assert f3.name == "f-3"
7373
assert f3.value.as_string() == "foo"
7474
assert f3.type == Type.CATEGORICAL
75+
76+
77+
def test_feature_domains():
78+
"""Test domains"""
79+
f1 = feature(name="f-1", value=1.0, dtype="number")
80+
assert f1.name == "f-1"
81+
assert f1.value.as_number() == 1.0
82+
assert f1.type == Type.NUMBER
83+
assert f1.domain is None
84+
assert f1.is_constrained
85+
86+
f2 = feature(name="f-2", value=2.0, dtype="number", domain=(0.0, 10.0))
87+
assert f2.name == "f-2"
88+
assert f2.value.as_number() == 2.0
89+
assert f2.type == Type.NUMBER
90+
assert f2.domain
91+
print(f2.domain)
92+
assert not f2.is_constrained

tests/test_counterfactualexplainer.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801
22
"""Test suite for counterfactual explanations"""
33

4-
from common import *
5-
from pytest import approx
6-
74
from java.util import Random
5+
from pytest import approx
86

7+
from common import *
98
from trustyai.explainers import CounterfactualExplainer
109
from trustyai.model import (
1110
counterfactual_prediction,
@@ -24,17 +23,15 @@ def test_non_empty_input():
2423

2524
goal = [output(name="f-num1", dtype="number", value=10.0, score=0.0)]
2625
features = [
27-
feature(name=f"f-num{i}", value=i * 2.0, dtype="number")
26+
feature(name=f"f-num{i}", value=i * 2.0, dtype="number", domain=(0.0, 1000.0))
2827
for i in range(n_features)
2928
]
30-
domains = [(0.0, 1000.0)] * n_features
3129

3230
model = TestUtils.getSumSkipModel(0)
3331

3432
prediction = counterfactual_prediction(
3533
input_features=features,
3634
outputs=goal,
37-
domains=domains
3835
)
3936

4037
counterfactual_result = explainer.explain(prediction, model)
@@ -48,9 +45,8 @@ def test_counterfactual_match():
4845
goal = [output(name="inside", dtype="bool", value=True, score=0.0)]
4946

5047
features = [
51-
feature(name=f"f-num{i + 1}", value=10.0, dtype="number") for i in range(4)
48+
feature(name=f"f-num{i + 1}", value=10.0, dtype="number", domain=(0.0, 1000.0)) for i in range(4)
5249
]
53-
domains = [(0.0, 1000.0)] * 4
5450

5551
center = 500.0
5652
epsilon = 10.0
@@ -60,7 +56,6 @@ def test_counterfactual_match():
6056
prediction = counterfactual_prediction(
6157
input_features=features,
6258
outputs=goal,
63-
domains=domains
6459
)
6560
model = TestUtils.getSumThresholdModel(center, epsilon)
6661
result = explainer.explain(prediction, model)
@@ -86,55 +81,17 @@ def test_counterfactual_match_python_model():
8681
n_features = 5
8782

8883
features = [
89-
feature(name=f"f-num{i + 1}", value=10.0, dtype="number") for i in range(n_features)
84+
feature(name=f"f-num{i + 1}", value=10.0, dtype="number", domain=(0.0, 1000.0)) for i in range(n_features)
9085
]
91-
domains = [(0.0, 1000.0)] * n_features
9286

9387
explainer = CounterfactualExplainer(steps=1000)
9488

9589
prediction = counterfactual_prediction(
9690
input_features=features,
9791
outputs=goal,
98-
domains=domains
9992
)
10093

10194
model = Model(sum_skip_model)
10295

10396
result = explainer.explain(prediction, model)
10497
assert sum([entity.as_feature().value.as_number() for entity in result.entities]) == approx(GOAL_VALUE, rel=3)
105-
106-
107-
def test_default_constraints():
108-
goal = [output(name="sum-but-0", dtype="number", value=1000, score=1.0)]
109-
110-
n_features = 5
111-
112-
features = [
113-
feature(name=f"f-num{i + 1}", value=10.0, dtype="number") for i in range(n_features)
114-
]
115-
domains = [(0.0, 1000.0)] * n_features
116-
117-
prediction = counterfactual_prediction(
118-
input_features=features,
119-
outputs=goal,
120-
domains=domains
121-
)
122-
123-
assert len(prediction.constraints) == n_features
124-
125-
n_features = 10
126-
127-
features = [
128-
feature(name=f"f-num{i + 1}", value=10.0, dtype="number") for i in range(n_features)
129-
]
130-
domains = [(0.0, 1000.0)] * n_features
131-
132-
constaints = [False] * n_features
133-
prediction = counterfactual_prediction(
134-
input_features=features,
135-
outputs=goal,
136-
domains=domains,
137-
constraints=constaints
138-
)
139-
140-
assert len(prediction.constraints) == n_features

trustyai/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
# pylint: disable = import-error, import-outside-toplevel, dangerous-default-value, invalid-name, R0801
22
"""Main TrustyAI Python bindings"""
3-
from typing import List
4-
import site
53
import os
4+
import site
65
import uuid
6+
from typing import List
7+
78
import jpype
89
import jpype.imports
910
from jpype import _jcustomizer, _jclass
1011

11-
TRUSTY_VERSION = "1.17.0.Final"
12+
TRUSTY_VERSION = "1.18.0.Final"
1213
DEFAULT_DEP_PATH = os.path.join(site.getsitepackages()[0], "trustyai", "dep")
1314

1415
CORE_DEPS = [
1516
f"{DEFAULT_DEP_PATH}/org/kie/kogito/explainability-core/{TRUSTY_VERSION}/*",
1617
f"{DEFAULT_DEP_PATH}/org/slf4j/slf4j-api/1.7.30/slf4j-api-1.7.30.jar",
1718
f"{DEFAULT_DEP_PATH}/org/apache/commons/commons-lang3/3.12.0/commons-lang3-3.12.0.jar",
18-
f"{DEFAULT_DEP_PATH}/org/optaplanner/optaplanner-core/8.17.0.Final/"
19-
f"optaplanner-core-8.17.0.Final.jar",
19+
f"{DEFAULT_DEP_PATH}/org/optaplanner/optaplanner-core-impl/8.18.0.Final/"
20+
f"optaplanner-core-impl-8.18.0.Final.jar",
2021
f"{DEFAULT_DEP_PATH}/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar",
21-
f"{DEFAULT_DEP_PATH}/org/kie/kie-api/8.17.0.Beta/kie-api-8.17.0.Beta.jar",
22+
f"{DEFAULT_DEP_PATH}/org/kie/kie-api/8.18.0.Beta/kie-api-8.18.0.Beta.jar",
2223
f"{DEFAULT_DEP_PATH}/io/micrometer/micrometer-core/1.8.2/micrometer-core-1.8.2.jar",
2324
]
2425

trustyai/model/__init__.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# pylint: disable = import-error, too-few-public-methods, invalid-name, duplicate-code
22
"""General model classes"""
3-
from typing import List, Optional, Tuple
43
import uuid as _uuid
4+
from typing import List, Optional
55

66
from java.lang import Long
7-
from java.util.concurrent import CompletableFuture, ForkJoinPool
7+
from java.util.concurrent import CompletableFuture
88
from jpype import JImplements, JOverride, _jcustomizer, _jclass
9+
from org.kie.kogito.explainability.local.counterfactual.entities import (
10+
CounterfactualEntity,
11+
)
912
from org.kie.kogito.explainability.model import (
1013
CounterfactualPrediction as _CounterfactualPrediction,
1114
DataDistribution,
@@ -21,8 +24,9 @@
2124
Value as _Value,
2225
Type as _Type,
2326
)
24-
from org.kie.kogito.explainability.local.counterfactual.entities import (
25-
CounterfactualEntity,
27+
28+
from org.kie.kogito.explainability.model.domain import (
29+
EmptyFeatureDomain as _EmptyFeatureDomain,
2630
)
2731

2832
from trustyai.model.domain import feature_domain
@@ -165,6 +169,19 @@ def value(self):
165169
"""Return value"""
166170
return self.getValue()
167171

172+
@property
173+
def domain(self):
174+
"""Return domain"""
175+
_domain = self.getDomain()
176+
if isinstance(_domain, _EmptyFeatureDomain):
177+
return None
178+
return _domain
179+
180+
@property
181+
def is_constrained(self):
182+
"""Return contraint"""
183+
return self.isConstrained()
184+
168185

169186
@_jcustomizer.JImplementationFor("org.kie.kogito.explainability.model.Value")
170187
# pylint: disable=no-member
@@ -222,11 +239,6 @@ def output(self) -> PredictionOutput:
222239
"""Return input"""
223240
return self.getOutput()
224241

225-
@property
226-
def constraints(self):
227-
"""Return constraints"""
228-
return self.getConstraints()
229-
230242
@property
231243
def data_distribution(self):
232244
"""Return data distribution"""
@@ -266,16 +278,22 @@ def output(name, dtype, value=None, score=1.0) -> _Output:
266278
return _Output(name, _type, Value(value), score)
267279

268280

269-
def feature(name: str, dtype: str, value=None) -> Feature:
281+
def feature(name: str, dtype: str, value=None, domain=None) -> Feature:
270282
"""Helper method to build features"""
283+
271284
if dtype == "categorical":
272-
_feature = FeatureFactory.newCategoricalFeature(name, value)
285+
_factory = FeatureFactory.newCategoricalFeature
273286
elif dtype == "number":
274-
_feature = FeatureFactory.newNumericalFeature(name, value)
287+
_factory = FeatureFactory.newNumericalFeature
275288
elif dtype == "bool":
276-
_feature = FeatureFactory.newBooleanFeature(name, value)
289+
_factory = FeatureFactory.newBooleanFeature
277290
else:
278-
_feature = FeatureFactory.newObjectFeature(name, value)
291+
_factory = FeatureFactory.newObjectFeature
292+
293+
if domain:
294+
_feature = _factory(name, value, feature_domain(domain))
295+
else:
296+
_feature = _factory(name, value)
279297
return _feature
280298

281299

@@ -291,8 +309,6 @@ def simple_prediction(
291309
def counterfactual_prediction(
292310
input_features: List[Feature],
293311
outputs: List[Output],
294-
domains: List[Optional[Tuple]],
295-
constraints: Optional[List[bool]] = None,
296312
data_distribution: Optional[DataDistribution] = None,
297313
uuid: Optional[_uuid.UUID] = None,
298314
timeout: Optional[float] = None,
@@ -302,19 +318,10 @@ def counterfactual_prediction(
302318
uuid = _uuid.uuid4()
303319
if timeout:
304320
timeout = Long(timeout)
305-
if not constraints:
306-
constraints = [False] * len(input_features)
307-
308-
# build the feature domains from the Python tuples
309-
java_domains = _jclass.JClass("java.util.Arrays").asList(
310-
[feature_domain(domain) for domain in domains]
311-
)
312321

313322
return CounterfactualPrediction(
314323
PredictionInput(input_features),
315324
PredictionOutput(outputs),
316-
PredictionFeatureDomain(java_domains),
317-
constraints,
318325
data_distribution,
319326
uuid,
320327
timeout,

0 commit comments

Comments
 (0)