|
1 | 1 | # pylint: disable = import-error, too-few-public-methods, invalid-name, duplicate-code |
2 | 2 | """General model classes""" |
3 | | -from typing import List |
| 3 | +from typing import List, Optional, Tuple |
| 4 | +import uuid as _uuid |
4 | 5 |
|
| 6 | +from java.lang import Long |
5 | 7 | from java.util.concurrent import CompletableFuture, ForkJoinPool |
6 | 8 | from jpype import JImplements, JOverride, _jcustomizer, _jclass |
7 | 9 | from org.kie.kogito.explainability.model import ( |
8 | 10 | CounterfactualPrediction as _CounterfactualPrediction, |
| 11 | + DataDistribution, |
9 | 12 | DataDomain as _DataDomain, |
10 | 13 | Feature, |
11 | 14 | FeatureFactory as _FeatureFactory, |
|
22 | 25 | CounterfactualEntity, |
23 | 26 | ) |
24 | 27 |
|
| 28 | +from trustyai.model.domain import feature_domain |
| 29 | + |
25 | 30 | CounterfactualPrediction = _CounterfactualPrediction |
26 | 31 | DataDomain = _DataDomain |
27 | 32 | FeatureFactory = _FeatureFactory |
@@ -272,3 +277,45 @@ def feature(name: str, dtype: str, value=None) -> Feature: |
272 | 277 | else: |
273 | 278 | _feature = FeatureFactory.newObjectFeature(name, value) |
274 | 279 | return _feature |
| 280 | + |
| 281 | + |
| 282 | +def simple_prediction( |
| 283 | + input_features: List[Feature], |
| 284 | + outputs: List[Output], |
| 285 | +) -> SimplePrediction: |
| 286 | + """Helper to build SimplePrediction""" |
| 287 | + return SimplePrediction(PredictionInput(input_features), PredictionOutput(outputs)) |
| 288 | + |
| 289 | + |
| 290 | +# pylint: disable=too-many-arguments |
| 291 | +def counterfactual_prediction( |
| 292 | + input_features: List[Feature], |
| 293 | + outputs: List[Output], |
| 294 | + domains: List[Optional[Tuple]], |
| 295 | + constraints: Optional[List[bool]] = None, |
| 296 | + data_distribution: Optional[DataDistribution] = None, |
| 297 | + uuid: Optional[_uuid.UUID] = None, |
| 298 | + timeout: Optional[float] = None, |
| 299 | +) -> CounterfactualPrediction: |
| 300 | + """Helper to build CounterfactualPrediction""" |
| 301 | + if not uuid: |
| 302 | + uuid = _uuid.uuid4() |
| 303 | + if timeout: |
| 304 | + timeout = Long(timeout) |
| 305 | + if not constraints: |
| 306 | + constraints = [False] * len(input_features) |
| 307 | + |
| 308 | + # build the feature domains from the Python tuples |
| 309 | + java_domains = _jclass.JClass("java.util.Arrays").asList( |
| 310 | + [feature_domain(domain) for domain in domains] |
| 311 | + ) |
| 312 | + |
| 313 | + return CounterfactualPrediction( |
| 314 | + PredictionInput(input_features), |
| 315 | + PredictionOutput(outputs), |
| 316 | + PredictionFeatureDomain(java_domains), |
| 317 | + constraints, |
| 318 | + data_distribution, |
| 319 | + uuid, |
| 320 | + timeout, |
| 321 | + ) |
0 commit comments