11# pylint: disable = import-error, too-few-public-methods, invalid-name, duplicate-code 
22"""General model classes""" 
3- from  typing  import  List , Optional , Tuple 
43import  uuid  as  _uuid 
4+ from  typing  import  List , Optional 
55
66from  java .lang  import  Long 
7- from  java .util .concurrent  import  CompletableFuture ,  ForkJoinPool 
7+ from  java .util .concurrent  import  CompletableFuture 
88from  jpype  import  JImplements , JOverride , _jcustomizer , _jclass 
9+ from  org .kie .kogito .explainability .local .counterfactual .entities  import  (
10+     CounterfactualEntity ,
11+ )
912from  org .kie .kogito .explainability .model  import  (
1013    CounterfactualPrediction  as  _CounterfactualPrediction ,
1114    DataDistribution ,
2124    Value  as  _Value ,
2225    Type  as  _Type ,
2326)
24- from  org .kie .kogito .explainability .local .counterfactual .entities  import  (
25-     CounterfactualEntity ,
27+ 
28+ from  org .kie .kogito .explainability .model .domain  import  (
29+     EmptyFeatureDomain  as  _EmptyFeatureDomain ,
2630)
2731
2832from  trustyai .model .domain  import  feature_domain 
@@ -165,6 +169,19 @@ def value(self):
165169        """Return value""" 
166170        return  self .getValue ()
167171
172+     @property  
173+     def  domain (self ):
174+         """Return domain""" 
175+         _domain  =  self .getDomain ()
176+         if  isinstance (_domain , _EmptyFeatureDomain ):
177+             return  None 
178+         return  _domain 
179+ 
180+     @property  
181+     def  is_constrained (self ):
182+         """Return contraint""" 
183+         return  self .isConstrained ()
184+ 
168185
169186@_jcustomizer .JImplementationFor ("org.kie.kogito.explainability.model.Value" ) 
170187# pylint: disable=no-member 
@@ -222,11 +239,6 @@ def output(self) -> PredictionOutput:
222239        """Return input""" 
223240        return  self .getOutput ()
224241
225-     @property  
226-     def  constraints (self ):
227-         """Return constraints""" 
228-         return  self .getConstraints ()
229- 
230242    @property  
231243    def  data_distribution (self ):
232244        """Return data distribution""" 
@@ -266,16 +278,22 @@ def output(name, dtype, value=None, score=1.0) -> _Output:
266278    return  _Output (name , _type , Value (value ), score )
267279
268280
269- def  feature (name : str , dtype : str , value = None ) ->  Feature :
281+ def  feature (name : str , dtype : str , value = None ,  domain = None ) ->  Feature :
270282    """Helper method to build features""" 
283+ 
271284    if  dtype  ==  "categorical" :
272-         _feature  =  FeatureFactory .newCategoricalFeature ( name ,  value ) 
285+         _factory  =  FeatureFactory .newCategoricalFeature 
273286    elif  dtype  ==  "number" :
274-         _feature  =  FeatureFactory .newNumericalFeature ( name ,  value ) 
287+         _factory  =  FeatureFactory .newNumericalFeature 
275288    elif  dtype  ==  "bool" :
276-         _feature  =  FeatureFactory .newBooleanFeature ( name ,  value ) 
289+         _factory  =  FeatureFactory .newBooleanFeature 
277290    else :
278-         _feature  =  FeatureFactory .newObjectFeature (name , value )
291+         _factory  =  FeatureFactory .newObjectFeature 
292+ 
293+     if  domain :
294+         _feature  =  _factory (name , value , feature_domain (domain ))
295+     else :
296+         _feature  =  _factory (name , value )
279297    return  _feature 
280298
281299
@@ -291,8 +309,6 @@ def simple_prediction(
291309def  counterfactual_prediction (
292310    input_features : List [Feature ],
293311    outputs : List [Output ],
294-     domains : List [Optional [Tuple ]],
295-     constraints : Optional [List [bool ]] =  None ,
296312    data_distribution : Optional [DataDistribution ] =  None ,
297313    uuid : Optional [_uuid .UUID ] =  None ,
298314    timeout : Optional [float ] =  None ,
@@ -302,19 +318,10 @@ def counterfactual_prediction(
302318        uuid  =  _uuid .uuid4 ()
303319    if  timeout :
304320        timeout  =  Long (timeout )
305-     if  not  constraints :
306-         constraints  =  [False ] *  len (input_features )
307- 
308-     # build the feature domains from the Python tuples 
309-     java_domains  =  _jclass .JClass ("java.util.Arrays" ).asList (
310-         [feature_domain (domain ) for  domain  in  domains ]
311-     )
312321
313322    return  CounterfactualPrediction (
314323        PredictionInput (input_features ),
315324        PredictionOutput (outputs ),
316-         PredictionFeatureDomain (java_domains ),
317-         constraints ,
318325        data_distribution ,
319326        uuid ,
320327        timeout ,
0 commit comments