11# pylint: disable = import-error, too-few-public-methods, invalid-name, duplicate-code
22"""General model classes"""
3- from typing import List , Optional , Tuple
43import uuid as _uuid
4+ from typing import List , Optional
55
66from java .lang import Long
7- from java .util .concurrent import CompletableFuture , ForkJoinPool
7+ from java .util .concurrent import CompletableFuture
88from jpype import JImplements , JOverride , _jcustomizer , _jclass
9+ from org .kie .kogito .explainability .local .counterfactual .entities import (
10+ CounterfactualEntity ,
11+ )
912from org .kie .kogito .explainability .model import (
1013 CounterfactualPrediction as _CounterfactualPrediction ,
1114 DataDistribution ,
2124 Value as _Value ,
2225 Type as _Type ,
2326)
24- from org .kie .kogito .explainability .local .counterfactual .entities import (
25- CounterfactualEntity ,
27+
28+ from org .kie .kogito .explainability .model .domain import (
29+ EmptyFeatureDomain as _EmptyFeatureDomain ,
2630)
2731
2832from trustyai .model .domain import feature_domain
@@ -165,6 +169,19 @@ def value(self):
165169 """Return value"""
166170 return self .getValue ()
167171
172+ @property
173+ def domain (self ):
174+ """Return domain"""
175+ _domain = self .getDomain ()
176+ if isinstance (_domain , _EmptyFeatureDomain ):
177+ return None
178+ return _domain
179+
180+ @property
181+ def is_constrained (self ):
182+ """Return contraint"""
183+ return self .isConstrained ()
184+
168185
169186@_jcustomizer .JImplementationFor ("org.kie.kogito.explainability.model.Value" )
170187# pylint: disable=no-member
@@ -222,11 +239,6 @@ def output(self) -> PredictionOutput:
222239 """Return input"""
223240 return self .getOutput ()
224241
225- @property
226- def constraints (self ):
227- """Return constraints"""
228- return self .getConstraints ()
229-
230242 @property
231243 def data_distribution (self ):
232244 """Return data distribution"""
@@ -266,16 +278,22 @@ def output(name, dtype, value=None, score=1.0) -> _Output:
266278 return _Output (name , _type , Value (value ), score )
267279
268280
269- def feature (name : str , dtype : str , value = None ) -> Feature :
281+ def feature (name : str , dtype : str , value = None , domain = None ) -> Feature :
270282 """Helper method to build features"""
283+
271284 if dtype == "categorical" :
272- _feature = FeatureFactory .newCategoricalFeature ( name , value )
285+ _factory = FeatureFactory .newCategoricalFeature
273286 elif dtype == "number" :
274- _feature = FeatureFactory .newNumericalFeature ( name , value )
287+ _factory = FeatureFactory .newNumericalFeature
275288 elif dtype == "bool" :
276- _feature = FeatureFactory .newBooleanFeature ( name , value )
289+ _factory = FeatureFactory .newBooleanFeature
277290 else :
278- _feature = FeatureFactory .newObjectFeature (name , value )
291+ _factory = FeatureFactory .newObjectFeature
292+
293+ if domain :
294+ _feature = _factory (name , value , feature_domain (domain ))
295+ else :
296+ _feature = _factory (name , value )
279297 return _feature
280298
281299
@@ -291,8 +309,6 @@ def simple_prediction(
291309def counterfactual_prediction (
292310 input_features : List [Feature ],
293311 outputs : List [Output ],
294- domains : List [Optional [Tuple ]],
295- constraints : Optional [List [bool ]] = None ,
296312 data_distribution : Optional [DataDistribution ] = None ,
297313 uuid : Optional [_uuid .UUID ] = None ,
298314 timeout : Optional [float ] = None ,
@@ -302,19 +318,10 @@ def counterfactual_prediction(
302318 uuid = _uuid .uuid4 ()
303319 if timeout :
304320 timeout = Long (timeout )
305- if not constraints :
306- constraints = [False ] * len (input_features )
307-
308- # build the feature domains from the Python tuples
309- java_domains = _jclass .JClass ("java.util.Arrays" ).asList (
310- [feature_domain (domain ) for domain in domains ]
311- )
312321
313322 return CounterfactualPrediction (
314323 PredictionInput (input_features ),
315324 PredictionOutput (outputs ),
316- PredictionFeatureDomain (java_domains ),
317- constraints ,
318325 data_distribution ,
319326 uuid ,
320327 timeout ,
0 commit comments