Skip to content

Commit 7436931

Browse files
committed
support suffix in param_scope
Signed-off-by: reiase <[email protected]>
1 parent d81302c commit 7436931

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

hyperparameter/hyperparameter.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,33 @@ def __enter__(self):
452452
def __exit__(self, exc_type, exc_value, traceback):
453453
_param_scope.tls.history.pop()
454454

455+
def __call__(self, suffix=None) -> Any:
456+
"""
457+
>>> @auto_param('myns.foo.params')
458+
... def foo(a, b=2, c='c', d=None):
459+
... print(a, b, c, d)
460+
461+
>>> def test():
462+
... with param_scope["sec1"]():
463+
... with param_scope["sec2"]():
464+
... foo(1)
465+
466+
>>> test()
467+
1 2 c None
468+
469+
>>> with param_scope(**{"myns.foo.params.b": 1}):
470+
... test()
471+
1 1 c None
472+
473+
>>> with param_scope(**{"myns.foo.params.b@sec1#sec2": 3}) as ps:
474+
... print(f"ps = {ps}")
475+
... test()
476+
ps = {'myns': {'foo': {'params': {'b@sec1#sec2': 3}}}}
477+
1 3 c None
478+
"""
479+
suffix = dict.get(self, "_suffix", None) if suffix is None else suffix
480+
return _Accessor(self, suffix=suffix)
481+
455482
@staticmethod
456483
def current():
457484
if not hasattr(_param_scope.tls, "history"):
@@ -474,7 +501,10 @@ def __init__(self, index=None):
474501
def __call__(self, *args, **kwargs):
475502
retval = _param_scope(*args, **kwargs)
476503
if self._index is not None:
477-
retval._suffix = self._index
504+
if dict.get(retval, "_suffix", None) is not None:
505+
retval._suffix = f"{retval._suffix}#{self._index}"
506+
else:
507+
retval._suffix = self._index
478508
return retval
479509

480510
def __getitem__(self, index):
@@ -574,7 +604,7 @@ def inner(*arg, **kws):
574604
local_params = {}
575605
for k, v in predef_kws.items():
576606
if getattr(hp(), v).get_or_else(None) is not None and k not in kws:
577-
kws[k] = hp.get(v)
607+
kws[k] = getattr(hp(), v).get_or_else(None)
578608
local_params[v] = hp.get(v)
579609
else:
580610
local_params[v] = predef_val[v]

0 commit comments

Comments
 (0)