Skip to content

Commit b6f8349

Browse files
committed
add auto_param
1 parent 6e3ef61 commit b6f8349

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

hyperparameter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .hp import *
1+
from .hp import *

hyperparameter/hp.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from argparse import Namespace
21
import inspect
32
import json
43
import threading
5-
import sys
4+
import inspect
65

76
from typing import Any
87

@@ -441,6 +440,49 @@ def result_func(*arg, **kws):
441440
return wrapper
442441

443442

443+
def auto_param(func):
444+
"""
445+
Convert keyword arguments into hyperparameters
446+
447+
Examples:
448+
449+
>>> @auto_param
450+
... def foo(a, b=2, c='c'):
451+
... print(a, b, c)
452+
453+
>>> foo(1)
454+
1 2 c
455+
456+
>>> with param_scope('foo.b=3'):
457+
... foo(2)
458+
2 3 c
459+
"""
460+
predef_kws = {}
461+
462+
namespace = func.__module__
463+
if namespace == '__main__':
464+
namespace = None
465+
if namespace is not None:
466+
namespace += '.{}'.format(func.__name__)
467+
else:
468+
namespace = func.__name__
469+
470+
signature = inspect.signature(func)
471+
for k, v in signature.parameters.items():
472+
if v.default != v.empty:
473+
name = '{}.{}'.format(namespace, k)
474+
predef_kws[k] = name
475+
476+
def wrapper(*arg, **kws):
477+
with param_scope() as hp:
478+
for k, v in predef_kws.items():
479+
if hp.get(v) is not None:
480+
kws[k] = hp.get(v)
481+
return func(*arg, **kws)
482+
483+
return wrapper
484+
485+
444486
def safe_numeric(value):
445487
if isinstance(value, str):
446488
try:

0 commit comments

Comments
 (0)