1818 counterfactual_prediction ,
1919 PredictionInput ,
2020 Model ,
21+ GoalCriteria ,
2122)
2223
23-
2424from trustyai .utils .data_conversions import (
2525 prediction_object_to_numpy ,
2626 prediction_object_to_pandas ,
@@ -184,12 +184,13 @@ def __init__(self, steps=10_000):
184184 def explain (
185185 self ,
186186 inputs : OneInputUnionType ,
187- goal : OneOutputUnionType ,
188187 model : Union [PredictionProvider , Model ],
188+ goal : Optional [OneOutputUnionType ] = None ,
189189 feature_domains : List [FeatureDomain ] = None ,
190190 data_distribution : Optional [DataDistribution ] = None ,
191191 uuid : Optional [_uuid .UUID ] = None ,
192192 timeout : Optional [float ] = None ,
193+ criteria : Optional [GoalCriteria ] = None ,
193194 ) -> CounterfactualResult :
194195 r"""Request for a counterfactual explanation given a list of features, goals and a
195196 :class:`~PredictionProvider`
@@ -217,7 +218,9 @@ def explain(
217218 uuid : Optional[:class:`_uuid.UUID`]
218219 The UUID to use during search.
219220 timeout : Optional[float]
220- The timeout time in seconds of the counterfactual explanation.
221+ The timeout time in seconds of the counterfactual explanation.
222+ criteria : Optional[:class:`GoalCriteria`]
223+ An optional custom scoring function, wrapped as a :class:`GoalCriteria`.
221224
222225 Returns
223226 -------
@@ -226,6 +229,9 @@ def explain(
226229 """
227230 feature_names = model .feature_names if isinstance (model , Model ) else None
228231 output_names = model .output_names if isinstance (model , Model ) else None
232+ if goal is None and criteria is None :
233+ raise ValueError ("Either a goal or criteria must be provided." )
234+
229235 _prediction = counterfactual_prediction (
230236 input_features = one_input_convert (
231237 inputs , feature_names = feature_names , feature_domains = feature_domains
@@ -236,6 +242,7 @@ def explain(
236242 data_distribution = data_distribution ,
237243 uuid = uuid ,
238244 timeout = timeout ,
245+ criteria = criteria ,
239246 )
240247
241248 with Model .NonArrowTransmission (model ):
0 commit comments