1818    counterfactual_prediction ,
1919    PredictionInput ,
2020    Model ,
21+     GoalCriteria ,
2122)
2223
23- 
2424from  trustyai .utils .data_conversions  import  (
2525    prediction_object_to_numpy ,
2626    prediction_object_to_pandas ,
@@ -184,12 +184,13 @@ def __init__(self, steps=10_000):
184184    def  explain (
185185        self ,
186186        inputs : OneInputUnionType ,
187-         goal : OneOutputUnionType ,
188187        model : Union [PredictionProvider , Model ],
188+         goal : Optional [OneOutputUnionType ] =  None ,
189189        feature_domains : List [FeatureDomain ] =  None ,
190190        data_distribution : Optional [DataDistribution ] =  None ,
191191        uuid : Optional [_uuid .UUID ] =  None ,
192192        timeout : Optional [float ] =  None ,
193+         criteria : Optional [GoalCriteria ] =  None ,
193194    ) ->  CounterfactualResult :
194195        r"""Request for a counterfactual explanation given a list of features, goals and a 
195196        :class:`~PredictionProvider` 
@@ -217,7 +218,9 @@ def explain(
217218        uuid : Optional[:class:`_uuid.UUID`] 
218219            The UUID to use during search. 
219220        timeout : Optional[float] 
220-                 The timeout time in seconds of the counterfactual explanation. 
221+             The timeout time in seconds of the counterfactual explanation. 
222+         criteria : Optional[:class:`GoalCriteria`] 
223+             An optional custom scoring function, wrapped as a :class:`GoalCriteria`. 
221224
222225        Returns 
223226        ------- 
@@ -226,6 +229,9 @@ def explain(
226229        """ 
227230        feature_names  =  model .feature_names  if  isinstance (model , Model ) else  None 
228231        output_names  =  model .output_names  if  isinstance (model , Model ) else  None 
232+         if  goal  is  None  and  criteria  is  None :
233+             raise  ValueError ("Either a goal or criteria must be provided." )
234+ 
229235        _prediction  =  counterfactual_prediction (
230236            input_features = one_input_convert (
231237                inputs , feature_names = feature_names , feature_domains = feature_domains 
@@ -236,6 +242,7 @@ def explain(
236242            data_distribution = data_distribution ,
237243            uuid = uuid ,
238244            timeout = timeout ,
245+             criteria = criteria ,
239246        )
240247
241248        with  Model .NonArrowTransmission (model ):
0 commit comments