11"""Explainers.countefactual module"""
22# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33# pylint: disable = unused-argument
4- from typing import Optional , Union
4+ from typing import Optional , Union , List
55import matplotlib .pyplot as plt
66import matplotlib as mpl
77import pandas as pd
2020 Model ,
2121)
2222
23+
2324from trustyai .utils .data_conversions import (
2425 prediction_object_to_numpy ,
2526 prediction_object_to_pandas ,
2627 OneInputUnionType ,
2728 OneOutputUnionType ,
2829 data_conversion_docstring ,
30+ one_input_convert ,
2931)
3032
3133from org .kie .trustyai .explainability .local .counterfactual import (
3840 DataDistribution ,
3941 PredictionProvider ,
4042)
43+
44+ from org .kie .trustyai .explainability .model .domain import FeatureDomain
45+
4146from org .optaplanner .core .config .solver .termination import TerminationConfig
4247from java .lang import Long
4348
@@ -181,11 +186,12 @@ def explain(
181186 inputs : OneInputUnionType ,
182187 goal : OneOutputUnionType ,
183188 model : Union [PredictionProvider , Model ],
189+ feature_domains : List [FeatureDomain ] = None ,
184190 data_distribution : Optional [DataDistribution ] = None ,
185191 uuid : Optional [_uuid .UUID ] = None ,
186192 timeout : Optional [float ] = None ,
187193 ) -> CounterfactualResult :
188- """Request for a counterfactual explanation given a list of features, goals and a
194+ r """Request for a counterfactual explanation given a list of features, goals and a
189195 :class:`~PredictionProvider`
190196
191197 Parameters
@@ -197,6 +203,15 @@ def explain(
197203 These can take the form of a: {}
198204 model : :obj:`~trustyai.model.PredictionProvider`
199205 The TrustyAI model as generated by :class:`~trustyai.model.Model` or a Java :class:`PredictionProvider`
206+ feature_domains : List[:class:`FeatureDomain`]
207+ A list of feature domains (each created by :func:`~trustyai.model.feature_domain()`)
208+ that define the valid domain of the input features. The ith element of the list defines
209+ the domain of the ith input feature. If the ith element of this list is ``None``, the
210+ no domain information will be added to the ith feature. If the ith feature had no
211+ previously-supplied domain information, it will be taken to be constrained and
212+ non-variable. If ``feature_domains=None``, no domain information will be added to any
213+ of the features, thus preserving existing domains if they've been manually added
214+ previously or holding undomained features constrained.
200215 data_distribution : Optional[:class:`DataDistribution`]
201216 The :class:`DataDistribution` to use when sampling the inputs.
202217 uuid : Optional[:class:`_uuid.UUID`]
@@ -210,7 +225,7 @@ def explain(
210225 Object containing the results of the counterfactual explanation.
211226 """
212227 _prediction = counterfactual_prediction (
213- input_features = inputs ,
228+ input_features = one_input_convert ( inputs , feature_domains = feature_domains ) ,
214229 outputs = goal ,
215230 data_distribution = data_distribution ,
216231 uuid = uuid ,
0 commit comments