@@ -545,34 +545,60 @@ def predict(self, X=None, visual=False, **kwargs):
545545 return prediction , dict (zip (visual_names , outputs [- len (visual_names ) :]))
546546
547547 return outputs
548+
549+
548550
549551 @guide
550- def evaluate (self , X = None , y = None ,metrics = None , additional_args = None , context_mapping = None , metric_args_mapping = None ):
552+ def evaluate (self , X = None , y = None ,metrics = None , global_args = None , local_args = None , global_mapping = None , local_mapping = None ):
551553 if X is None :
552554 X = self .X_test
553555 if y is None :
554556 y = self .y_test
555557
556- # may have multiple proba_steps and multiple produce args
557-
558- # context_0 = self.pipeline.predict(X, output_=0)
559- # y_proba = context_0["y_pred"][::, 1]
560558 final_context = self .pipeline .predict (X , output_ = - 1 )
561559
560+ # remap items, if any
561+ if global_mapping is not None :
562+ for cur , new in global_mapping .items ():
563+ if cur in final_context :
564+ cur_item = final_context .pop (cur )
565+ final_context [new ] = cur_item
566+
567+
562568 if metrics is None :
563569 metrics = DEFAULT_METRICS
564- if metric_args is None :
565- metric_args = {}
570+
571+ if global_args is None :
572+ global_args = {}
573+
574+ if local_args is None :
575+ local_args = {}
576+
577+ if local_mapping is None :
578+ local_mapping = {}
579+
566580
567581 results = {}
568582 for metric in metrics :
569583 try :
570584 metric_primitive = self ._get_ml_primitive (metric )
571- additional_kwargs = {}
572- if metric_primitive .name in metric_args :
573- additional_kwargs = metric_args [metric_primitive .name ]
585+
586+ if metric in local_mapping :
587+ metric_context = {}
588+ metric_mapping = local_mapping [metric ]
589+ for cur , item in final_context .items ():
590+ new = metric_mapping .get (cur , cur )
591+ metric_context [new ] = item
592+ else :
593+ metric_context = final_context
594+
595+
596+ if metric in local_args :
597+ metric_args = local_args [metric ]
598+ else :
599+ metric_args = {}
574600
575- res = metric_primitive .produce (y_true = self .y_test , ** final_context , ** additional_kwargs )
601+ res = metric_primitive .produce (y_true = self .y_test , ** metric_context , ** metric_args )
576602 results [metric_primitive .name ] = res
577603 except Exception as e :
578604 LOGGER .error (f"Unable to run evaluation metric: { metric_primitive .name } " , exc_info = e )
0 commit comments