File tree Expand file tree Collapse file tree 2 files changed +45
-3
lines changed Expand file tree Collapse file tree 2 files changed +45
-3
lines changed Original file line number Diff line number Diff line change 1- from .hp import *
1+ from .hp import *
Original file line number Diff line number Diff line change 1- from argparse import Namespace
21import inspect
32import json
43import threading
5- import sys
4+ import inspect
65
76from typing import Any
87
@@ -441,6 +440,49 @@ def result_func(*arg, **kws):
441440 return wrapper
442441
443442
443+ def auto_param (func ):
444+ """
445+ Convert keyword arguments into hyperparameters
446+
447+ Examples:
448+
449+ >>> @auto_param
450+ ... def foo(a, b=2, c='c'):
451+ ... print(a, b, c)
452+
453+ >>> foo(1)
454+ 1 2 c
455+
456+ >>> with param_scope('foo.b=3'):
457+ ... foo(2)
458+ 2 3 c
459+ """
460+ predef_kws = {}
461+
462+ namespace = func .__module__
463+ if namespace == '__main__' :
464+ namespace = None
465+ if namespace is not None :
466+ namespace += '.{}' .format (func .__name__ )
467+ else :
468+ namespace = func .__name__
469+
470+ signature = inspect .signature (func )
471+ for k , v in signature .parameters .items ():
472+ if v .default != v .empty :
473+ name = '{}.{}' .format (namespace , k )
474+ predef_kws [k ] = name
475+
476+ def wrapper (* arg , ** kws ):
477+ with param_scope () as hp :
478+ for k , v in predef_kws .items ():
479+ if hp .get (v ) is not None :
480+ kws [k ] = hp .get (v )
481+ return func (* arg , ** kws )
482+
483+ return wrapper
484+
485+
444486def safe_numeric (value ):
445487 if isinstance (value , str ):
446488 try :
You can’t perform that action at this time.
0 commit comments