33from typing import Optional , Tuple , List , Union
44
55from jpype import _jclass
6+
67from org .kie .trustyai .explainability .model .domain import (
78 FeatureDomain ,
89 NumericalFeatureDomain ,
910 CategoricalFeatureDomain ,
11+ CategoricalNumericalFeatureDomain ,
12+ ObjectFeatureDomain ,
1013 EmptyFeatureDomain ,
1114)
1215
1316
14- def feature_domain (
15- values : Optional [Union [Tuple , List [str ]]]
16- ) -> Optional [FeatureDomain ]:
17+ def feature_domain (values : Optional [Union [Tuple , List ]]) -> Optional [FeatureDomain ]:
1718 r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a
1819 particular feature, which is useful when constraining a counterfactual explanation to ensure it
1920 only recovers valid inputs. For example, if we had a feature that described a person's age, we
@@ -22,13 +23,18 @@ def feature_domain(
2223
2324 Parameters
2425 ----------
25- values : Optional[Union[Tuple, List[str] ]]
26+ values : Optional[Union[Tuple, List]]
2627 The valid values of the feature. If `values` takes the form of:
2728
2829 * **A tuple of floats or integers:** The feature domain will be a continuous range from
2930 ``values[0]`` to ``values[1]``.
30- * **A list of strings:** The feature domain will be categorical, where `values` contains
31- all possible valid feature values.
31+ * **A list of floats or integers:**: The feature domain will be a *numeric* categorical,
32+ where `values` contains all possible valid feature values.
33+ * **A list of strings:** The feature domain will be a *string* categorical, where `values`
34+ contains all possible valid feature values.
35+ * **A list of objects:** The feature domain will be an *object* categorical, where `values`
36+ contains all possible valid feature values. These may present an issue if the objects
37+ are not natively Java serializable.
3238
3339 Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held
3440 fixed during the counterfactual explanation.
@@ -43,12 +49,29 @@ def feature_domain(
4349 if not values :
4450 domain = EmptyFeatureDomain .create ()
4551 else :
46- if isinstance (values [0 ], (float , int )):
47- domain = NumericalFeatureDomain .create (values [0 ], values [1 ])
48- elif isinstance (values [0 ], str ):
49- domain = CategoricalFeatureDomain .create (
50- _jclass .JClass ("java.util.Arrays" ).asList (values )
52+ if isinstance (values , tuple ):
53+ assert isinstance (values [0 ], (float , int )) and isinstance (
54+ values [1 ], (float , int )
55+ )
56+ assert len (values ) == 2 , (
57+ "Tuples passed as domain values must only contain"
58+ " two values that define the (minimum, maximum) of the domain"
5159 )
60+ domain = NumericalFeatureDomain .create (values [0 ], values [1 ])
61+
62+ elif isinstance (values , list ):
63+ java_array = _jclass .JClass ("java.util.Arrays" ).asList (values )
64+ if isinstance (values [0 ], bool ) and isinstance (values [1 ], bool ):
65+ domain = ObjectFeatureDomain .create (java_array )
66+ elif isinstance (values [0 ], (float , int )) and isinstance (
67+ values [1 ], (float , int )
68+ ):
69+ domain = CategoricalNumericalFeatureDomain .create (java_array )
70+ elif isinstance (values [0 ], str ):
71+ domain = CategoricalFeatureDomain .create (java_array )
72+ else :
73+ domain = ObjectFeatureDomain .create (java_array )
74+
5275 else :
5376 domain = EmptyFeatureDomain .create ()
5477 return domain
0 commit comments