Skip to content

Commit 07cb38e

Browse files
committed
add tracker callback for auto_param
1 parent 971cdf1 commit 07cb38e

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

examples/sparse_lr/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from sklearn.linear_model import LogisticRegression
33
import matplotlib.pyplot as plt
44

5-
from hyperparameter import auto_param, param_scope
5+
from hyperparameter import auto_param, param_scope, set_tracker
66

7+
set_tracker(lambda x: print(x))
78

89
MyLogisticRegression = auto_param(LogisticRegression)
910

11+
1012
@auto_param
1113
def sparse_lr_plot(X, y, learning_rate=0.01, penalty='l1', C=0.01, tol=0.01):
1214
LR = MyLogisticRegression(C=C, penalty=penalty, tol=tol, solver='saga')

hyperparameter/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
from .hp import HyperParameter, param_scope, auto_param
2+
from .hp import Tracker
3+
4+
set_tracker = Tracker.set_tracker

hyperparameter/hp.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
class Tracker:
1010
rlist = set()
1111
wlist = set()
12+
callback = None
1213

1314
@staticmethod
1415
def reads():
@@ -36,6 +37,9 @@ def report():
3637
]
3738
return '\n'.join(retvals)
3839

40+
@staticmethod
41+
def set_tracker(func):
42+
Tracker.callback = func
3943

4044
class Accessor(dict):
4145
"""
@@ -338,17 +342,18 @@ def auto_param(func):
338342
Examples:
339343
340344
>>> @auto_param
341-
... def foo(a, b=2, c='c'):
342-
... print(a, b, c)
345+
... def foo(a, b=2, c='c', d=None):
346+
... print(a, b, c, d)
343347
344348
>>> foo(1)
345-
1 2 c
349+
1 2 c None
346350
347351
>>> with param_scope('foo.b=3'):
348352
... foo(2)
349-
2 3 c
353+
2 3 c None
350354
"""
351355
predef_kws = {}
356+
predef_val = {}
352357

353358
namespace = func.__module__
354359
if namespace == '__main__':
@@ -364,12 +369,18 @@ def auto_param(func):
364369
name = '{}.{}'.format(namespace, k)
365370
predef_kws[k] = name
366371
Tracker.rlist.add(name)
372+
predef_val[name] = v.default
367373

368374
def wrapper(*arg, **kws):
369375
with param_scope() as hp:
376+
local_params = {}
370377
for k, v in predef_kws.items():
378+
local_params[v] = predef_val[v]
371379
if hp.get(v) is not None:
372380
kws[k] = hp.get(v)
381+
local_params[v] = hp.get(v)
382+
if Tracker.callback is not None:
383+
Tracker.callback(local_params)
373384
return func(*arg, **kws)
374385

375386
return wrapper

0 commit comments

Comments
 (0)