Skip to content

Commit 102229e

Browse files
committed
Update evaluate w/ global and local args and mapping
1 parent b42b048 commit 102229e

File tree

1 file changed

+37
-11
lines changed

1 file changed

+37
-11
lines changed

zephyr_ml/core.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -545,34 +545,60 @@ def predict(self, X=None, visual=False, **kwargs):
545545
return prediction, dict(zip(visual_names, outputs[-len(visual_names) :]))
546546

547547
return outputs
548+
549+
548550

549551
@guide
550-
def evaluate(self, X=None, y=None,metrics=None, additional_args = None, context_mapping = None, metric_args_mapping = None):
552+
def evaluate(self, X=None, y=None,metrics=None, global_args = None, local_args = None, global_mapping = None, local_mapping = None):
551553
if X is None:
552554
X = self.X_test
553555
if y is None:
554556
y = self.y_test
555557

556-
# may have multiple proba_steps and multiple produce args
557-
558-
# context_0 = self.pipeline.predict(X, output_=0)
559-
# y_proba = context_0["y_pred"][::, 1]
560558
final_context = self.pipeline.predict(X, output_=-1)
561559

560+
# remap items, if any
561+
if global_mapping is not None:
562+
for cur, new in global_mapping.items():
563+
if cur in final_context:
564+
cur_item = final_context.pop(cur)
565+
final_context[new] = cur_item
566+
567+
562568
if metrics is None:
563569
metrics = DEFAULT_METRICS
564-
if metric_args is None:
565-
metric_args = {}
570+
571+
if global_args is None:
572+
global_args = {}
573+
574+
if local_args is None:
575+
local_args = {}
576+
577+
if local_mapping is None:
578+
local_mapping = {}
579+
566580

567581
results = {}
568582
for metric in metrics:
569583
try:
570584
metric_primitive = self._get_ml_primitive(metric)
571-
additional_kwargs = {}
572-
if metric_primitive.name in metric_args:
573-
additional_kwargs = metric_args[metric_primitive.name]
585+
586+
if metric in local_mapping:
587+
metric_context = {}
588+
metric_mapping = local_mapping[metric]
589+
for cur, item in final_context.items():
590+
new = metric_mapping.get(cur, cur)
591+
metric_context[new] = item
592+
else:
593+
metric_context = final_context
594+
595+
596+
if metric in local_args:
597+
metric_args = local_args[metric]
598+
else:
599+
metric_args = {}
574600

575-
res = metric_primitive.produce(y_true = self.y_test, **final_context, **additional_kwargs)
601+
res = metric_primitive.produce(y_true = self.y_test, **metric_context, **metric_args)
576602
results[metric_primitive.name] = res
577603
except Exception as e:
578604
LOGGER.error(f"Unable to run evaluation metric: {metric_primitive.name}", exc_info = e)

0 commit comments

Comments
 (0)