Skip to content

Commit 801b9d5

Browse files
authored
Merge pull request #42 from ruivieira/FAI-710
Move simple and counterfactual predictions to the model package
2 parents 87e2e3a + ab523eb commit 801b9d5

File tree

5 files changed

+51
-66
lines changed

5 files changed

+51
-66
lines changed

tests/test_counterfactualexplainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from java.util import Random
88

99
from trustyai.explainers import CounterfactualExplainer
10-
from trustyai.local.counterfactual import counterfactual_prediction
1110
from trustyai.model import (
12-
FeatureFactory,
11+
counterfactual_prediction,
1312
output, Model, feature,
1413
)
1514
from trustyai.utils import TestUtils

tests/test_limeexplainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
import pytest
77

88
from trustyai.explainers import LimeExplainer
9-
from trustyai.local.counterfactual import simple_prediction
109
from trustyai.utils import TestUtils
11-
from trustyai.model import feature
10+
from trustyai.model import feature, simple_prediction
1211

1312
from org.kie.kogito.explainability.local import (
1413
LocalExplanationException,

tests/test_shap.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import pytest
77

88
from trustyai.explainers import SHAPExplainer
9-
from trustyai.local.counterfactual import simple_prediction
10-
from trustyai.model import feature, PredictionInput
9+
from trustyai.model import feature, PredictionInput, simple_prediction
1110
from trustyai.utils import TestUtils
1211

1312

trustyai/local/counterfactual.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

trustyai/model/__init__.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# pylint: disable = import-error, too-few-public-methods, invalid-name, duplicate-code
22
"""General model classes"""
3-
from typing import List
3+
from typing import List, Optional, Tuple
4+
import uuid as _uuid
45

6+
from java.lang import Long
57
from java.util.concurrent import CompletableFuture, ForkJoinPool
68
from jpype import JImplements, JOverride, _jcustomizer, _jclass
79
from org.kie.kogito.explainability.model import (
810
CounterfactualPrediction as _CounterfactualPrediction,
11+
DataDistribution,
912
DataDomain as _DataDomain,
1013
Feature,
1114
FeatureFactory as _FeatureFactory,
@@ -22,6 +25,8 @@
2225
CounterfactualEntity,
2326
)
2427

28+
from trustyai.model.domain import feature_domain
29+
2530
CounterfactualPrediction = _CounterfactualPrediction
2631
DataDomain = _DataDomain
2732
FeatureFactory = _FeatureFactory
@@ -272,3 +277,45 @@ def feature(name: str, dtype: str, value=None) -> Feature:
272277
else:
273278
_feature = FeatureFactory.newObjectFeature(name, value)
274279
return _feature
280+
281+
282+
def simple_prediction(
283+
input_features: List[Feature],
284+
outputs: List[Output],
285+
) -> SimplePrediction:
286+
"""Helper to build SimplePrediction"""
287+
return SimplePrediction(PredictionInput(input_features), PredictionOutput(outputs))
288+
289+
290+
# pylint: disable=too-many-arguments
291+
def counterfactual_prediction(
292+
input_features: List[Feature],
293+
outputs: List[Output],
294+
domains: List[Optional[Tuple]],
295+
constraints: Optional[List[bool]] = None,
296+
data_distribution: Optional[DataDistribution] = None,
297+
uuid: Optional[_uuid.UUID] = None,
298+
timeout: Optional[float] = None,
299+
) -> CounterfactualPrediction:
300+
"""Helper to build CounterfactualPrediction"""
301+
if not uuid:
302+
uuid = _uuid.uuid4()
303+
if timeout:
304+
timeout = Long(timeout)
305+
if not constraints:
306+
constraints = [False] * len(input_features)
307+
308+
# build the feature domains from the Python tuples
309+
java_domains = _jclass.JClass("java.util.Arrays").asList(
310+
[feature_domain(domain) for domain in domains]
311+
)
312+
313+
return CounterfactualPrediction(
314+
PredictionInput(input_features),
315+
PredictionOutput(outputs),
316+
PredictionFeatureDomain(java_domains),
317+
constraints,
318+
data_distribution,
319+
uuid,
320+
timeout,
321+
)

0 commit comments

Comments
 (0)