Skip to content

Commit 89a1834

Browse files
committed
drop and
1 parent b6f8349 commit 89a1834

File tree

5 files changed

+17
-133
lines changed

5 files changed

+17
-133
lines changed

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,8 @@ with param_scope(param1=2):
7676

7777
## Predefined Parameter
7878
```python
79-
@let( # predefine two parameter for `model_train`
80-
learning_rate = 1.0,
81-
penalty = 'l1'
82-
)
83-
def model_train(X, y):
79+
@auto_param #convert keyword arguments into hyper parameters
80+
def model_train(X, y, learning_rate = 1.0, penalty = 'l1'):
8481
LR = LogisticRegression(C=1.0,
8582
lr=local_param('learning_rate'),
8683
penalty=local_param('penalty'))

examples/sparse_lr/README.md

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@ Sparse LR Examples
44
This example is based on `scikit-learn` example: [l1 penalty and sparsity in logistic regression](https://scikit-learn.org/stable/auto_examples/linear_model/plot_logistic_l1_l2_sparsity.html#sphx-glr-auto-examples-linear-model-plot-logistic-l1-l2-sparsity-py), which classifies 8x8 images of digits into two classes: 0-4 against 5-9,
55
and visualize the coefficients of the model for different penalty methods(l1 or l2) and C.
66

7-
The algorithm is defined in function `sparse_lr_plot` from `model.py`. We use the decorator `let` to declare hyper-parameters for our function:
7+
The algorithm is defined in function `sparse_lr_plot` from `model.py`. We use the decorator `auto_param` to declare hyper-parameters for our function:
88
``` python
9-
@let(learning_rate=0.01, penalty='l1', C=0.01, tol=0.01)
10-
def sparse_lr_plot(X, y):
11-
C = local_param('C')
12-
penalty = local_param('penalty')
13-
tol = local_param('tol')
9+
@auto_param
10+
def sparse_lr_plot(X, y, learning_rate=0.01, penalty='l1', C=0.01, tol=0.01):
1411
print({'C': C, 'penalty': penalty, 'tol': tol})
1512
...
1613
```
1714

18-
Four hyper-parameter are defined for `sparse_lr_plot`: `learning_rate`, `penalty`, `C` and `tol`.
15+
Four keyword arguments are defined for `sparse_lr_plot`: `learning_rate`, `penalty`, `C` and `tol`. `auto_param` will convert these arguments into hyper-parameters.
16+
1917
There are two ways to control the hyper-parameters:
2018
1. parameter scope (see detail in `example_1.py`):
2119

examples/sparse_lr/model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
from sklearn.linear_model import LogisticRegression
33
import matplotlib.pyplot as plt
44

5-
from hyperparameter import let, local_param, param_scope
5+
from hyperparameter import auto_param, param_scope
66

77

8-
@let(learning_rate=0.01, penalty='l1', C=0.01, tol=0.01)
9-
def sparse_lr_plot(X, y):
10-
C = local_param('C')
11-
penalty = local_param('penalty')
12-
tol = local_param('tol')
8+
@auto_param
9+
def sparse_lr_plot(X, y, learning_rate=0.01, penalty='l1', C=0.01, tol=0.01):
1310
LR = LogisticRegression(C=C, penalty=penalty, tol=tol, solver='saga')
1411

1512
LR.fit(X, y)

hyperparameter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .hp import *
1+
from .hp import HyperParameter, param_scope, auto_param

hyperparameter/hp.py

Lines changed: 6 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,6 @@ def from_json_file(path):
274274
return HyperParameter(**obj)
275275

276276

277-
def hparam(*arg, **kws):
278-
return HyperParameter(*arg, **kws)
279-
280-
281277
class param_scope(HyperParameter):
282278
'''
283279
thread safe scoped hyper parameeter
@@ -327,119 +323,14 @@ def __exit__(self, exc_type, exc_value, traceback):
327323

328324
@staticmethod
329325
def init(params):
326+
"""
327+
init param_scope for a new thread.
328+
"""
330329
if not hasattr(param_scope.tls, '_cfg_'):
331330
param_scope.tls._cfg_ = []
332331
param_scope.tls._cfg_.append(params)
333332

334333

335-
def local_param(name: str):
336-
"""
337-
>>> @let(a=1)
338-
... def foo():
339-
... print(local_param('a'))
340-
341-
>>> foo()
342-
1
343-
344-
>>> with param_scope('foo.a=2'):
345-
... foo()
346-
2
347-
"""
348-
import sys
349-
namespace = inspect.getmodulename(sys._getframe(1).f_code.co_filename)
350-
if namespace == '__main__':
351-
namespace = None
352-
if namespace is not None:
353-
namespace += '.{}'.format(sys._getframe(1).f_code.co_name)
354-
else:
355-
namespace = sys._getframe(1).f_code.co_name
356-
with param_scope() as hp:
357-
value = hp.get(namespace + '.' + name)
358-
if value is not None:
359-
return value
360-
else:
361-
raise Exception(
362-
'name {} not defined for {}, available parameters: {}'.format(
363-
name, namespace, str(hp)))
364-
365-
366-
def let(*arg, **kws):
367-
"""
368-
wrap a function with parameters
369-
370-
example for pre-defined global parameter
371-
>>> @let(
372-
... 'a.b=1',
373-
... )
374-
... def foo():
375-
... with param_scope() as hp:
376-
... print(hp.a.b)
377-
378-
>>> foo()
379-
1
380-
381-
>>> with param_scope('a.b=2') as hp:
382-
... foo()
383-
2
384-
385-
example for pre-defined function local parameter
386-
387-
>>> @let(b=1)
388-
... def foo():
389-
... return local_param('b')
390-
391-
>>> foo()
392-
1
393-
394-
>>> with param_scope('foo.b=2'):
395-
... foo()
396-
2
397-
"""
398-
predef_arg = arg
399-
predef_kws = kws
400-
401-
def wrapper(func):
402-
namespace = func.__module__
403-
if namespace == '__main__':
404-
namespace = None
405-
if namespace is not None:
406-
namespace += '.{}'.format(func.__name__)
407-
else:
408-
namespace = func.__name__
409-
410-
global params_list
411-
for item in predef_arg:
412-
if '=' in item:
413-
k, v = item.split('=', 1)
414-
Tracker.wlist.add(k)
415-
for k, v in predef_kws.items():
416-
name = '{}.{}'.format(namespace, k)
417-
Tracker.wlist.add(name)
418-
419-
def result_func(*arg, **kws):
420-
accepted_arg = []
421-
accepted_kws = {}
422-
with param_scope() as hp:
423-
for item in predef_arg:
424-
if '=' in item:
425-
k, v = item.split('=', 1)
426-
if hp.get(k) is None:
427-
accepted_arg.append(item)
428-
for k, v in predef_kws.items():
429-
name = '{}.{}'.format(namespace, k)
430-
if hp.get(name) is None:
431-
accepted_kws[name] = safe_numeric(v)
432-
with param_scope(*accepted_arg) as hp:
433-
for k, v in accepted_kws.items():
434-
v = safe_numeric(v)
435-
hp.put(k, v)
436-
return func(*arg, **kws)
437-
438-
return result_func
439-
440-
return wrapper
441-
442-
443334
def auto_param(func):
444335
"""
445336
Convert keyword arguments into hyperparameters
@@ -472,6 +363,7 @@ def auto_param(func):
472363
if v.default != v.empty:
473364
name = '{}.{}'.format(namespace, k)
474365
predef_kws[k] = name
366+
Tracker.rlist.add(name)
475367

476368
def wrapper(*arg, **kws):
477369
with param_scope() as hp:
@@ -505,11 +397,11 @@ def safe_numeric(value):
505397
class TestHyperParameter(unittest.TestCase):
506398

507399
def test_parameter_create(self):
508-
param1 = hparam(a=1, b=2)
400+
param1 = HyperParameter(a=1, b=2)
509401
self.assertEqual(param1.a, 1)
510402
self.assertEqual(param1.b, 2)
511403

512-
param2 = hparam(**{'a': 1, 'b': 2})
404+
param2 = HyperParameter(**{'a': 1, 'b': 2})
513405
self.assertEqual(param2.a, 1)
514406
self.assertEqual(param2.b, 2)
515407

0 commit comments

Comments
 (0)