From e0bf1c76b730b51daecec228f2eab695bd40c732 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Wed, 1 Feb 2023 11:42:29 +0000 Subject: [PATCH 1/3] Fixed Categorical feature casting issues --- src/trustyai/explainers/counterfactuals.py | 20 +++++++--- src/trustyai/model/domain.py | 14 +++++-- src/trustyai/utils/data_conversions.py | 44 ++++++++++++++++++---- 3 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/trustyai/explainers/counterfactuals.py b/src/trustyai/explainers/counterfactuals.py index 727abdf..e5270d2 100644 --- a/src/trustyai/explainers/counterfactuals.py +++ b/src/trustyai/explainers/counterfactuals.py @@ -27,7 +27,7 @@ OneInputUnionType, OneOutputUnionType, data_conversion_docstring, - one_input_convert, + one_input_convert, java_string_capture, ) from org.kie.trustyai.explainability.local.counterfactual import ( @@ -79,6 +79,14 @@ def proposed_features_dataframe(self): [PredictionInput([entity.as_feature() for entity in self._result.entities])] ) + + def _get_feature_difference(self, value_pair): + proposed, original = value_pair + try: + return proposed - original + except: + return "{} -> {}".format(original, proposed) + def as_dataframe(self) -> pd.DataFrame: """ Return the counterfactual result as a dataframe @@ -99,15 +107,15 @@ def as_dataframe(self) -> pd.DataFrame: features = self._result.getFeatures() data = {} - data["features"] = [f"{entity.as_feature().getName()}" for entity in entities] - data["proposed"] = [entity.as_feature().value.as_obj() for entity in entities] - data["original"] = [ - feature.getValue().getUnderlyingObject() for feature in features + data["Features"] = [f"{entity.as_feature().getName()}" for entity in entities] + data["Proposed"] = [java_string_capture(entity.as_feature().value.as_obj()) for entity in entities] + data["Original"] = [ + java_string_capture(feature.getValue().getUnderlyingObject()) for feature in features ] data["constrained"] = [feature.is_constrained for feature in features] dfr = pd.DataFrame.from_dict(data) - dfr["difference"] = dfr.proposed - dfr.original + dfr["Difference"] = dfr[["Proposed", "Original"]].apply(self._get_feature_difference, 1) return dfr def as_html(self) -> pd.io.formats.style.Styler: diff --git a/src/trustyai/model/domain.py b/src/trustyai/model/domain.py index 1727dbd..9cfe66e 100644 --- a/src/trustyai/model/domain.py +++ b/src/trustyai/model/domain.py @@ -2,7 +2,8 @@ """Conversion method between Python and TrustyAI Java types""" from typing import Optional, Tuple, List, Union -from jpype import _jclass +import jpype +from jpype import _jclass, JArray from org.kie.trustyai.explainability.model.domain import ( FeatureDomain, @@ -60,16 +61,21 @@ def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDoma domain = NumericalFeatureDomain.create(values[0], values[1]) elif isinstance(values, list): - java_array = _jclass.JClass("java.util.Arrays").asList(values) if isinstance(values[0], bool) and isinstance(values[1], bool): + java_values = [jpype.JBoolean(v) for v in values] + java_array = _jclass.JClass("java.util.Arrays").asList(java_values) domain = ObjectFeatureDomain.create(java_array) elif isinstance(values[0], (float, int)) and isinstance( values[1], (float, int) ): + if isinstance(values[0], float): + java_values = [jpype.JDouble(v) for v in values] + else: + java_values = [jpype.JInt(v) for v in values] + java_array = _jclass.JClass("java.util.Arrays").asList(java_values) domain = CategoricalNumericalFeatureDomain.create(java_array) - elif isinstance(values[0], str): - domain = CategoricalFeatureDomain.create(java_array) else: + java_array = _jclass.JClass("java.util.Arrays").asList(values) domain = ObjectFeatureDomain.create(java_array) else: diff --git a/src/trustyai/utils/data_conversions.py b/src/trustyai/utils/data_conversions.py index 1f69c62..05a53a3 100644 --- a/src/trustyai/utils/data_conversions.py +++ b/src/trustyai/utils/data_conversions.py @@ -12,12 +12,16 @@ Output, PredictionInput, PredictionOutput, + Type ) from org.kie.trustyai.explainability.model.domain import ( FeatureDomain, EmptyFeatureDomain, + NumericalFeatureDomain, + CategoricalFeatureDomain, + CategoricalNumericalFeatureDomain, + ObjectFeatureDomain, ) - import pandas as pd import numpy as np @@ -51,7 +55,20 @@ ManyOutputsUnionType = Union[np.ndarray, pd.DataFrame, List[PredictionOutput]] # trusty type names -trusty_type_map = {"i": "number", "O": "categorical", "f": "number", "b": "bool"} +trusty_type_map = { + "i": "categorical", + "U": "categorical", + "O": "object", + "f": "number", + "b": "bool" +} + +feature_domain_map = { + "NumericalFeatureDomain": Type.NUMBER, + "CategoricalFeatureDomain": Type.CATEGORICAL, + "CategoricalNumericalFeatureDomain": Type.CATEGORICAL, + "ObjectFeatureDomain": Type.CATEGORICAL +} # universal docstrings for functions that use these data conversions =============================== @@ -173,14 +190,15 @@ def domain_insertion( "previous domain with the new one.".format(i, f.toString()) ) warnings.warn(warning_msg) + domain_class_name = feature_domains[i].getClass().getSimpleName() + new_type = feature_domain_map.get(domain_class_name, f.getType()) domained_features.append( Feature( - f.getName(), f.getType(), f.getValue(), False, feature_domains[i] + f.getName(), new_type, f.getValue(), False, feature_domains[i] ) ) return PredictionInput(domained_features) - # === input functions ============================================================================== def one_input_convert( python_inputs: OneInputUnionType, @@ -363,6 +381,7 @@ def numpy_to_prediction_object( wrapper = PredictionOutput if names is None: names = [f"{prefix}-{i}" for i in range(shape[1])] + types = [trusty_type_map[array[:, i].dtype.kind] for i in range(shape[1])] predictions = [] for row_index in range(shape[0]): @@ -393,14 +412,14 @@ def prediction_object_to_numpy( if isinstance(objects[0], PredictionInput): arr = np.array( [ - [f.getValue().getUnderlyingObject() for f in pi.getFeatures()] + [java_string_capture(f.getValue().getUnderlyingObject()) for f in pi.getFeatures()] for pi in objects ] ) else: arr = np.array( [ - [o.getValue().getUnderlyingObject() for o in po.getOutputs()] + [java_string_capture(o.getValue().getUnderlyingObject()) for o in po.getOutputs()] for po in objects ] ) @@ -423,7 +442,8 @@ def prediction_object_to_pandas( df = pd.DataFrame( [ { - in_feature.getName(): in_feature.getValue().getUnderlyingObject() + str(in_feature.getName()): + java_string_capture(in_feature.getValue().getUnderlyingObject()) for in_feature in pi.getFeatures() } for pi in objects @@ -433,7 +453,8 @@ def prediction_object_to_pandas( df = pd.DataFrame( [ { - output.getName(): output.getValue().getUnderlyingObject() + str(output.getName()): + java_string_capture(output.getValue().getUnderlyingObject()) for output in po.getOutputs() } for po in objects @@ -569,3 +590,10 @@ def numpy_to_trusty_dataframe( pi = many_inputs_convert(arr) return Dataframe.createFromInputs(pi) + + +def java_string_capture(obj): + """Given some arbitrary object, convert it to a Python string if Java string, else + pass through unmodified. This prevents incorrect parsing of Java strings to Python + char tuples""" + return str(obj) if obj.getClass().getName() == "java.lang.String" else obj \ No newline at end of file From bd36d8ff8e99dd4876bd6e7f6c29fc926e0b8602 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Wed, 1 Feb 2023 11:48:51 +0000 Subject: [PATCH 2/3] linting and black --- src/trustyai/explainers/counterfactuals.py | 20 ++++++++----- src/trustyai/model/domain.py | 3 +- src/trustyai/utils/data_conversions.py | 33 +++++++++++++--------- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/trustyai/explainers/counterfactuals.py b/src/trustyai/explainers/counterfactuals.py index e5270d2..d1a6a6c 100644 --- a/src/trustyai/explainers/counterfactuals.py +++ b/src/trustyai/explainers/counterfactuals.py @@ -27,7 +27,8 @@ OneInputUnionType, OneOutputUnionType, data_conversion_docstring, - one_input_convert, java_string_capture, + one_input_convert, + java_string_capture, ) from org.kie.trustyai.explainability.local.counterfactual import ( @@ -79,13 +80,12 @@ def proposed_features_dataframe(self): [PredictionInput([entity.as_feature() for entity in self._result.entities])] ) - def _get_feature_difference(self, value_pair): proposed, original = value_pair try: return proposed - original - except: - return "{} -> {}".format(original, proposed) + except TypeError: + return f"{original} -> {proposed}" def as_dataframe(self) -> pd.DataFrame: """ @@ -108,14 +108,20 @@ def as_dataframe(self) -> pd.DataFrame: data = {} data["Features"] = [f"{entity.as_feature().getName()}" for entity in entities] - data["Proposed"] = [java_string_capture(entity.as_feature().value.as_obj()) for entity in entities] + data["Proposed"] = [ + java_string_capture(entity.as_feature().value.as_obj()) + for entity in entities + ] data["Original"] = [ - java_string_capture(feature.getValue().getUnderlyingObject()) for feature in features + java_string_capture(feature.getValue().getUnderlyingObject()) + for feature in features ] data["constrained"] = [feature.is_constrained for feature in features] dfr = pd.DataFrame.from_dict(data) - dfr["Difference"] = dfr[["Proposed", "Original"]].apply(self._get_feature_difference, 1) + dfr["Difference"] = dfr[["Proposed", "Original"]].apply( + self._get_feature_difference, 1 + ) return dfr def as_html(self) -> pd.io.formats.style.Styler: diff --git a/src/trustyai/model/domain.py b/src/trustyai/model/domain.py index 9cfe66e..73f80d0 100644 --- a/src/trustyai/model/domain.py +++ b/src/trustyai/model/domain.py @@ -3,12 +3,11 @@ from typing import Optional, Tuple, List, Union import jpype -from jpype import _jclass, JArray +from jpype import _jclass from org.kie.trustyai.explainability.model.domain import ( FeatureDomain, NumericalFeatureDomain, - CategoricalFeatureDomain, CategoricalNumericalFeatureDomain, ObjectFeatureDomain, EmptyFeatureDomain, diff --git a/src/trustyai/utils/data_conversions.py b/src/trustyai/utils/data_conversions.py index 05a53a3..764c613 100644 --- a/src/trustyai/utils/data_conversions.py +++ b/src/trustyai/utils/data_conversions.py @@ -12,7 +12,7 @@ Output, PredictionInput, PredictionOutput, - Type + Type, ) from org.kie.trustyai.explainability.model.domain import ( FeatureDomain, @@ -60,14 +60,14 @@ "U": "categorical", "O": "object", "f": "number", - "b": "bool" + "b": "bool", } feature_domain_map = { "NumericalFeatureDomain": Type.NUMBER, "CategoricalFeatureDomain": Type.CATEGORICAL, "CategoricalNumericalFeatureDomain": Type.CATEGORICAL, - "ObjectFeatureDomain": Type.CATEGORICAL + "ObjectFeatureDomain": Type.CATEGORICAL, } @@ -193,12 +193,11 @@ def domain_insertion( domain_class_name = feature_domains[i].getClass().getSimpleName() new_type = feature_domain_map.get(domain_class_name, f.getType()) domained_features.append( - Feature( - f.getName(), new_type, f.getValue(), False, feature_domains[i] - ) + Feature(f.getName(), new_type, f.getValue(), False, feature_domains[i]) ) return PredictionInput(domained_features) + # === input functions ============================================================================== def one_input_convert( python_inputs: OneInputUnionType, @@ -412,14 +411,20 @@ def prediction_object_to_numpy( if isinstance(objects[0], PredictionInput): arr = np.array( [ - [java_string_capture(f.getValue().getUnderlyingObject()) for f in pi.getFeatures()] + [ + java_string_capture(f.getValue().getUnderlyingObject()) + for f in pi.getFeatures() + ] for pi in objects ] ) else: arr = np.array( [ - [java_string_capture(o.getValue().getUnderlyingObject()) for o in po.getOutputs()] + [ + java_string_capture(o.getValue().getUnderlyingObject()) + for o in po.getOutputs() + ] for po in objects ] ) @@ -442,8 +447,9 @@ def prediction_object_to_pandas( df = pd.DataFrame( [ { - str(in_feature.getName()): - java_string_capture(in_feature.getValue().getUnderlyingObject()) + str(in_feature.getName()): java_string_capture( + in_feature.getValue().getUnderlyingObject() + ) for in_feature in pi.getFeatures() } for pi in objects @@ -453,8 +459,9 @@ def prediction_object_to_pandas( df = pd.DataFrame( [ { - str(output.getName()): - java_string_capture(output.getValue().getUnderlyingObject()) + str(output.getName()): java_string_capture( + output.getValue().getUnderlyingObject() + ) for output in po.getOutputs() } for po in objects @@ -596,4 +603,4 @@ def java_string_capture(obj): """Given some arbitrary object, convert it to a Python string if Java string, else pass through unmodified. This prevents incorrect parsing of Java strings to Python char tuples""" - return str(obj) if obj.getClass().getName() == "java.lang.String" else obj \ No newline at end of file + return str(obj) if obj.getClass().getName() == "java.lang.String" else obj From 587220de28d713bf8310d493c7338f9021445702 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Wed, 1 Feb 2023 13:43:57 +0000 Subject: [PATCH 3/3] fixed jlong/jint casting when passing ints directly to jvm --- src/trustyai/explainers/counterfactuals.py | 8 ++-- src/trustyai/metrics/fairness/group.py | 22 +++++----- src/trustyai/model/__init__.py | 12 ++++-- src/trustyai/model/domain.py | 15 +++---- src/trustyai/utils/data_conversions.py | 20 ++++++--- tests/general/test_conversions.py | 18 ++------ tests/general/test_counterfactualexplainer.py | 41 ++++++++++++++++++- 7 files changed, 92 insertions(+), 44 deletions(-) diff --git a/src/trustyai/explainers/counterfactuals.py b/src/trustyai/explainers/counterfactuals.py index d1a6a6c..e7a20af 100644 --- a/src/trustyai/explainers/counterfactuals.py +++ b/src/trustyai/explainers/counterfactuals.py @@ -142,7 +142,7 @@ def plot(self, block=True) -> None: Plot the counterfactual result. """ _df = self.as_dataframe().copy() - _df = _df[_df["difference"] != 0.0] + _df = _df[_df["Difference"] != 0.0] def change_colour(value): if value == 0.0: @@ -154,9 +154,9 @@ def change_colour(value): return colour with mpl.rc_context(drcp): - colour = _df["difference"].transform(change_colour) - plot = _df[["features", "proposed", "original"]].plot.barh( - x="features", color={"proposed": colour, "original": "black"} + colour = _df["Difference"].transform(change_colour) + plot = _df[["Features", "Proposed", "Original"]].plot.barh( + x="Features", color={"Proposed": colour, "Original": "black"} ) plot.set_title("Counterfactual") plt.show(block=block) diff --git a/src/trustyai/metrics/fairness/group.py b/src/trustyai/metrics/fairness/group.py index 2d39277..b1a929e 100644 --- a/src/trustyai/metrics/fairness/group.py +++ b/src/trustyai/metrics/fairness/group.py @@ -12,6 +12,7 @@ OneOutputUnionType, one_output_convert, to_trusty_dataframe, + python_int_capture, ) ColumSelector = Union[List[int], List[str]] @@ -59,7 +60,7 @@ def statistical_parity_difference_model( ) -> float: """Calculate Statistical Parity Difference using a samples dataframe and a model""" favorable_prediction_object = one_output_convert(favorable) - _privilege_values = [Value(v) for v in privilege_values] + _privilege_values = [Value(python_int_capture(v)) for v in privilege_values] _jsamples = to_trusty_dataframe( data=samples, no_outputs=True, feature_names=feature_names ) @@ -103,7 +104,7 @@ def disparate_impact_ratio_model( ) -> float: """Calculate Disparate Impact Ration using a samples dataframe and a model""" favorable_prediction_object = one_output_convert(favorable) - _privilege_values = [Value(v) for v in privilege_values] + _privilege_values = [Value(python_int_capture(v)) for v in privilege_values] _jsamples = to_trusty_dataframe( data=samples, no_outputs=True, feature_names=feature_names ) @@ -131,8 +132,8 @@ def average_odds_difference( raise ValueError( f"Dataframes have different shapes ({test.shape} and {truth.shape})" ) - _privilege_values = [Value(v) for v in privilege_values] - _positive_class = [Value(v) for v in positive_class] + _privilege_values = [Value(python_int_capture(v)) for v in privilege_values] + _positive_class = [Value(python_int_capture(v)) for v in positive_class] # determine privileged columns _privilege_columns = _column_selector_to_index(privilege_columns, test) return FairnessMetrics.groupAverageOddsDifference( @@ -156,8 +157,8 @@ def average_odds_difference_model( _jsamples = to_trusty_dataframe( data=samples, no_outputs=True, feature_names=feature_names ) - _privilege_values = [Value(v) for v in privilege_values] - _positive_class = [Value(v) for v in positive_class] + _privilege_values = [Value(python_int_capture(v)) for v in privilege_values] + _positive_class = [Value(python_int_capture(v)) for v in positive_class] # determine privileged columns _privilege_columns = _column_selector_to_index(privilege_columns, samples) return FairnessMetrics.groupAverageOddsDifference( @@ -179,9 +180,10 @@ def average_predictive_value_difference( raise ValueError( f"Dataframes have different shapes ({test.shape} and {truth.shape})" ) - _privilege_values = [Value(v) for v in privilege_values] - _positive_class = [Value(v) for v in positive_class] + _privilege_values = [Value(python_int_capture(v)) for v in privilege_values] + _positive_class = [Value(python_int_capture(v)) for v in positive_class] _privilege_columns = _column_selector_to_index(privilege_columns, test) + return FairnessMetrics.groupAveragePredictiveValueDifference( to_trusty_dataframe(data=test, outputs=outputs, feature_names=feature_names), to_trusty_dataframe(data=truth, outputs=outputs, feature_names=feature_names), @@ -201,8 +203,8 @@ def average_predictive_value_difference_model( ) -> float: """Calculate Average Predictive Value Difference for a sample dataframe using the provided model""" _jsamples = to_trusty_dataframe(samples, no_outputs=True) - _privilege_values = [Value(v) for v in privilege_values] - _positive_class = [Value(v) for v in positive_class] + _privilege_values = [Value(python_int_capture(v)) for v in privilege_values] + _positive_class = [Value(python_int_capture(v)) for v in positive_class] # determine privileged columns _privilege_columns = _column_selector_to_index(privilege_columns, samples) return FairnessMetrics.groupAveragePredictiveValueDifference( diff --git a/src/trustyai/model/__init__.py b/src/trustyai/model/__init__.py index 748d0fd..d2a5cbe 100644 --- a/src/trustyai/model/__init__.py +++ b/src/trustyai/model/__init__.py @@ -7,6 +7,8 @@ import uuid as _uuid from abc import ABC from typing import List, Optional, Union, Callable, Tuple + +import jpype import pandas as pd import pyarrow as pa import numpy as np @@ -23,6 +25,7 @@ prediction_object_to_numpy, prediction_object_to_pandas, data_conversion_docstring, + python_int_capture, ) from trustyai.model.domain import feature_domain @@ -810,7 +813,8 @@ def output(name, dtype, value=None, score=1.0) -> _Output: _type = Type.CATEGORICAL else: _type = Type.UNDEFINED - return _Output(name, _type, Value(value), score) + + return _Output(name, _type, Value(python_int_capture(value)), score) def full_text_feature( @@ -859,12 +863,14 @@ def feature( """ if dtype == "categorical": - if isinstance(value, int): + if isinstance(value, (np.int64, int)): _factory = FeatureFactory.newCategoricalNumericalFeature value = JInt(value) - else: + elif isinstance(value, str): _factory = FeatureFactory.newCategoricalFeature value = JString(value) + else: + _factory = FeatureFactory.newObjectFeature elif dtype == "number": _factory = FeatureFactory.newNumericalFeature elif dtype == "bool": diff --git a/src/trustyai/model/domain.py b/src/trustyai/model/domain.py index 73f80d0..34be4bc 100644 --- a/src/trustyai/model/domain.py +++ b/src/trustyai/model/domain.py @@ -1,8 +1,9 @@ -# pylint: disable = import-error +# pylint: disable = import-error, unidiomatic-typecheck """Conversion method between Python and TrustyAI Java types""" from typing import Optional, Tuple, List, Union import jpype +import numpy as np from jpype import _jclass from org.kie.trustyai.explainability.model.domain import ( @@ -60,17 +61,17 @@ def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDoma domain = NumericalFeatureDomain.create(values[0], values[1]) elif isinstance(values, list): - if isinstance(values[0], bool) and isinstance(values[1], bool): + if type(values[0]) == bool and type(values[1]) == bool: java_values = [jpype.JBoolean(v) for v in values] java_array = _jclass.JClass("java.util.Arrays").asList(java_values) domain = ObjectFeatureDomain.create(java_array) - elif isinstance(values[0], (float, int)) and isinstance( - values[1], (float, int) + elif isinstance(values[0], (float, int, np.number)) and isinstance( + values[1], (float, int, np.number) ): - if isinstance(values[0], float): - java_values = [jpype.JDouble(v) for v in values] - else: + if isinstance(values[0], (int, np.int64)): java_values = [jpype.JInt(v) for v in values] + else: + java_values = [jpype.JDouble(v) for v in values] java_array = _jclass.JClass("java.util.Arrays").asList(java_values) domain = CategoricalNumericalFeatureDomain.create(java_array) else: diff --git a/src/trustyai/utils/data_conversions.py b/src/trustyai/utils/data_conversions.py index 764c613..16ed8ea 100644 --- a/src/trustyai/utils/data_conversions.py +++ b/src/trustyai/utils/data_conversions.py @@ -5,6 +5,7 @@ from typing import Union, List, Optional, Tuple from itertools import filterfalse +import jpype import trustyai.model from org.kie.trustyai.explainability.model import ( Dataframe, @@ -56,9 +57,9 @@ # trusty type names trusty_type_map = { - "i": "categorical", + "i": "number", "U": "categorical", - "O": "object", + "O": "categorical", "f": "number", "b": "bool", } @@ -345,7 +346,7 @@ def df_to_prediction_object( values = list(row) collection = [] for fv in values: - f = func(name=fv[2], dtype=fv[1], value=fv[0]) + f = func(name=fv[2], dtype=fv[1], value=python_int_capture(fv[0])) collection.append(f) predictions.append(wrapper(collection)) return predictions @@ -389,7 +390,7 @@ def numpy_to_prediction_object( f = func( name=names[col_index], dtype=types[col_index], - value=array[row_index, col_index], + value=python_int_capture(array[row_index, col_index]), ) collection.append(f) predictions.append(wrapper(collection)) @@ -546,10 +547,10 @@ def df_to_trusty_dataframe( pi = many_inputs_convert( python_inputs=data.iloc[:, input_indices], feature_names=input_names ) + po = many_outputs_convert( python_outputs=data.iloc[:, output_indices], names=output_names ) - return Dataframe.createFrom(pi, po) pi = many_inputs_convert(data) @@ -589,6 +590,7 @@ def numpy_to_trusty_dataframe( pi = many_inputs_convert( python_inputs=np.take(arr, input_indices, axis), feature_names=input_names ) + po = many_outputs_convert( python_outputs=np.take(arr, output_indices, axis), names=output_names ) @@ -604,3 +606,11 @@ def java_string_capture(obj): pass through unmodified. This prevents incorrect parsing of Java strings to Python char tuples""" return str(obj) if obj.getClass().getName() == "java.lang.String" else obj + + +def python_int_capture(obj): + """Given some arbitrary object, convert it to a Java int if Python int, else + pass through unmodified. This prevents incorrect parsing of Python ints to Java longs""" + if not isinstance(obj, bool) and isinstance(obj, (int, np.int64)): + return jpype.JInt(obj) + return obj diff --git a/tests/general/test_conversions.py b/tests/general/test_conversions.py index 7893da3..dcdead2 100644 --- a/tests/general/test_conversions.py +++ b/tests/general/test_conversions.py @@ -51,12 +51,7 @@ def test_categorical_numeric_domain_list(): domain = [0, 1000] jdomain = feature_domain(domain) assert jdomain.getCategories().size() == 2 - assert jdomain.getCategories().containsAll(domain) - - domain = [0.0, 1000.0] - jdomain = feature_domain(domain) - assert jdomain.getCategories().size() == 2 - assert jdomain.getCategories().containsAll(domain) + assert [x in list(jdomain.getCategories()) for x in domain] def test_categorical_object_domain_list(): @@ -65,7 +60,7 @@ def test_categorical_object_domain_list(): jdomain = feature_domain(domain) assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain" assert jdomain.getCategories().size() == 2 - assert jdomain.getCategories().containsAll(domain) + assert [x in list(jdomain.getCategories()) for x in domain] def test_categorical_object_domain_list_2(): @@ -74,7 +69,7 @@ def test_categorical_object_domain_list_2(): jdomain = feature_domain(domain) assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain" assert jdomain.getCategories().size() == 2 - assert jdomain.getCategories().containsAll(domain) + assert [x in list(jdomain.getCategories()) for x in domain] def test_empty_domain(): @@ -88,12 +83,7 @@ def test_categorical_domain_tuple(): domain = ["foo", "bar", "baz"] jdomain = feature_domain(domain) assert jdomain.getCategories().size() == 3 - assert jdomain.getCategories().containsAll(list(domain)) - - domain = ["foo", "bar", "baz"] - jdomain = feature_domain(domain) - assert jdomain.getCategories().size() == 3 - assert jdomain.getCategories().containsAll(domain) + assert [x in list(jdomain.getCategories()) for x in domain] def test_feature_function(): diff --git a/tests/general/test_counterfactualexplainer.py b/tests/general/test_counterfactualexplainer.py index ef82230..44d2447 100644 --- a/tests/general/test_counterfactualexplainer.py +++ b/tests/general/test_counterfactualexplainer.py @@ -11,7 +11,7 @@ from java.util import Random from pytest import approx -from trustyai.explainers import CounterfactualExplainer +from trustyai.explainers import CounterfactualExplainer, LimeExplainer from trustyai.explainers.counterfactuals import GoalCriteria from org.kie.trustyai.explainability.local.counterfactual.goal import GoalScore from trustyai.model import ( @@ -272,3 +272,42 @@ def test_counterfactual_with_domain_argument_overwrite(): feature_domains=[feature_domain((-10, 10)) for _ in range(5)], model=model ) + + +def test_counterfactual_with_object_counterfactual(): + """Test categorical objects work with as_dataframe""" + np.random.seed(0) + + # will output 5 * 1 * 2 * 2 == 20 + # goal is 5 + 1 + 2 + 2 == 10 + data = pd.DataFrame([{"a": 5., "b": 1, "c": "alpha", "d": 2., "e": 2}]) + feature_domains = [ + feature_domain((0., 10.)), + feature_domain([0, 1, 2]), + feature_domain(["alpha", "beta", "gamma"]), + feature_domain((0., 200.)), + feature_domain([1, 2, 3]) + ] + + def pred_func(x): + out = np.zeros(len(x)) + for i, row in x.iterrows(): + if row["c"] == "alpha": + out[i] = row["a"] * row["b"] * row["d"] * row["e"] + elif row["c"] == "beta": + out[i] = row["a"] + row["b"] + row["d"] + row["e"] + else: + out[i] = -10. + return out/1. + + model = Model(pred_func, dataframe_input=True) + explainer = CounterfactualExplainer(steps=10_000) + + cf_result = explainer.explain( + inputs=data, + goal=np.array([10.]), + feature_domains=feature_domains, + model=model, + ) + + assert cf_result.as_dataframe().iloc[2]['Difference'] == "alpha -> beta" \ No newline at end of file