Commit 9bb9fec

support parameter searching with optuna
Signed-off-by: reiase <[email protected]>
1 parent b94bc52 commit 9bb9fec

6 files changed: +236 −4 lines changed

examples/optuna/README.md

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
Hyper-Parameter Optimization
============================

This example is based on the `optuna` [quick start example](https://optuna.org/#code_quickstart). [Optuna](https://optuna.org/) is an easy-to-use, open-source hyperparameter optimization framework:

```python
import optuna

def objective(trial):
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2

study = optuna.create_study()
study.optimize(objective, n_trials=100)

study.best_params  # E.g. {'x': 2.002108042}
```

The above example creates a `study` object that searches for the value of the parameter `x` minimizing the objective function `(x - 2)^2`.
Parameter Searching with [`HyperParameter`](https://github.com/reiase/hyperparameter)
--------------------------------------------------------------------------------------

Parameter searching becomes much easier with [`HyperParameter`](https://github.com/reiase/hyperparameter):

```python
import optuna
from hyperparameter import param_scope, auto_param, lazy_dispatch

@auto_param
def objective(x=0.0):
    return (x - 2) ** 2

def wrapper(trial):
    trial = lazy_dispatch(trial)
    with param_scope(**{
        "objective.x": trial.suggest_float('objective.x', -10, 10)
    }):
        return objective()

study = optuna.create_study()
study.optimize(wrapper, n_trials=100)

study.best_params  # E.g. {'objective.x': 2.002108042}
```

We apply [the `auto_param` decorator](https://reiase.github.io/hyperparameter/quick_start/#auto_param) directly to the objective function so that it accepts parameters from [`param_scope`](https://reiase.github.io/hyperparameter/quick_start/#param_scope). We then define a wrapper function that adapts the `param_scope` API to `optuna`'s `trial` API and runs the parameter search as suggested in `optuna`'s example.
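What does `lazy_dispatch` buy us here? It defers the `trial.suggest_float` call: attribute access on the wrapped trial is recorded, and the call returns a `Suggester` that draws a value only when `objective.x` is actually read inside the scope. A minimal sketch of that behavior, using a hypothetical `FakeTrial` stand-in rather than a real `optuna` trial:

```python
from hyperparameter import param_scope, auto_param, lazy_dispatch

@auto_param
def objective(x=0.0):
    return (x - 2) ** 2

class FakeTrial:
    """Hypothetical stand-in for optuna's Trial, for illustration only."""
    def suggest_float(self, name, low, high):
        print(f"sampling {name} in [{low}, {high}]")
        return 1.5  # a fixed "sample" instead of a real sampler

trial = lazy_dispatch(FakeTrial())
with param_scope(**{
    "objective.x": trial.suggest_float('objective.x', -10, 10)
}):
    # "sampling ..." prints here, when objective() reads the parameter,
    # not when the param_scope was entered.
    print(objective())  # (1.5 - 2) ** 2 == 0.25
```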
Put the Best Parameters into Production
---------------------------------------

To put the best parameters into production, we can pass them directly to `param_scope`. This is very convenient when you want to put an ML model into production:

```python
with param_scope(**study.best_params):
    print(f"{study.best_params} => {objective()}")
```

This works because the wrapper registered the trial parameter under the same name (`objective.x`) that `auto_param` uses to look up `x`, so `study.best_params` can be passed to `param_scope` unchanged.
Hyper-Parameters for Nested Functions
-------------------------------------

For complex problems, hyper-parameters of nested functions can be searched in the same way; each parameter is namespaced by its function's name (the best parameters can again be put into production, see the note after the example below):
```python
@auto_param
def objective_x(x=0.0):
    return (x - 2) ** 2

@auto_param
def objective_y(y=0.0):
    return (y - 1) ** 4

def objective():
    return objective_x() + objective_y()

def wrapper(trial):
    trial = lazy_dispatch(trial)
    with param_scope(**{
        "objective_x.x": trial.suggest_float('objective_x.x', -10, 10),
        "objective_y.y": trial.suggest_float('objective_y.y', -10, 10)
    }):
        return objective()

study = optuna.create_study()
study.optimize(wrapper, n_trials=100)

study.best_params  # E.g. {'objective_x.x': ..., 'objective_y.y': ...}
```
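As in the single-parameter case, the best parameters found for the nested functions can be fed straight back into `param_scope`; this is exactly how the accompanying example script ends:

```python
with param_scope(**study.best_params):
    print(f"{study.best_params} => {objective()}")
```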

examples/optuna/example.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
```python
import optuna

def objective(trial):
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2

study = optuna.create_study()
study.optimize(objective, n_trials=100)

print(study.best_params)  # E.g. {'x': 2.002108042}
```

examples/optuna/example_hp.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
```python
import optuna
from hyperparameter import param_scope, auto_param, lazy_dispatch

@auto_param
def objective(x=0.0):
    return (x - 2) ** 2

def wrapper(trial):
    trial = lazy_dispatch(trial)
    with param_scope(**{
        "objective.x": trial.suggest_float('objective.x', -10, 10)
    }):
        return objective()

study = optuna.create_study()
study.optimize(wrapper, n_trials=100)

study.best_params  # E.g. {'objective.x': 2.002108042}

with param_scope(**study.best_params):
    print(f"{study.best_params} => {objective()}")
```

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
```python
import optuna
from hyperparameter import param_scope, auto_param, lazy_dispatch

@auto_param
def objective_x(x=0.0):
    return (x - 2) ** 2

@auto_param
def objective_y(y=0.0):
    return (y - 1) ** 4

def objective():
    return objective_x() + objective_y()

def wrapper(trial):
    trial = lazy_dispatch(trial)
    with param_scope(**{
        "objective_x.x": trial.suggest_float('objective_x.x', -10, 10),
        "objective_y.y": trial.suggest_float('objective_y.y', -10, 10)
    }):
        return objective()

study = optuna.create_study()
study.optimize(wrapper, n_trials=100)

study.best_params  # E.g. {'objective_x.x': ..., 'objective_y.y': ...}

with param_scope(**study.best_params):
    print(f"{study.best_params} => {objective()}")
```

hyperparameter/__init__.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1,11 +1,13 @@
 from .hyperparameter import HyperParameter
 from .hyperparameter import param_scope, reads, writes, all_params
 from .hyperparameter import auto_param, set_auto_param_callback
-from .hyperparameter import dynamic_dispatch
+from .hyperparameter import dynamic_dispatch, lazy_dispatch, suggest_from

 __all__ = [
     "HyperParameter",
     "dynamic_dispatch",
+    "suggest_from",
+    "lazy_dispatch",
     "param_scope",
     "reads",
     "writes",
```

hyperparameter/hyperparameter.py

Lines changed: 86 additions & 3 deletions
```diff
@@ -1,6 +1,7 @@
 import inspect
+from operator import getitem
 import threading
-from typing import Any, Callable, Dict, Set
+from typing import Any, Callable, Dict, List, Optional, Set


 def _sorted_set(s):
@@ -40,10 +41,10 @@ def clear(self):
         self._get.clear()
         self._put.clear()

-    def read(self, key: str = None) -> Set[str]:
+    def read(self, key: Optional[str] = None) -> Optional[List[str]]:
         return self._get.add(key) if key else _sorted_set(self._get)

-    def write(self, key: str = None) -> Set[str]:
+    def write(self, key: Optional[str] = None) -> Optional[List[str]]:
         return self._put.add(key) if key else _sorted_set(self._put)

     def all(self):
@@ -65,6 +66,47 @@ def all_params():
     return _tracker.all()


+class Suggester:
+    def __init__(self, callback: Callable) -> None:
+        self._callback = callback
+
+    def __call__(self):
+        return self._callback()
+
+
+def suggest_from(callback: Callable) -> Suggester:
+    """Suggest a parameter value from a callback function.
+
+    Examples
+    --------
+    >>> class ValueWrapper:
+    ...     def __init__(self, lst):
+    ...         self._lst = lst
+    ...         self._offset = 0
+    ...     def __call__(self):
+    ...         index, self._offset = self._offset % len(self._lst), self._offset + 1
+    ...         return self._lst[index]

+    >>> with param_scope(suggested=suggest_from(ValueWrapper([1, 2, 3]))) as ps:
+    ...     ps().suggested()
+    ...     ps().suggested()
+    ...     ps().suggested()
+    1
+    2
+    3
+    """
+    return Suggester(callback)
+
+
+def _unwrap_suggester(func):
+    def wrapper(*args, **kwargs):
+        retval = func(*args, **kwargs)
+        if isinstance(retval, Suggester):
+            return retval()
+        return retval
+    return wrapper
+
+
 class _Accessor(dict):
     """Helper for accessing hyper-parameters."""

@@ -74,6 +116,7 @@ def __init__(self, root, path=None, scope=None):
         self._path = path
         self._scope = scope

+    @_unwrap_suggester
     def get_or_else(self, default: Any = None):
         """Get value for the parameter, or get default value if the parameter is not defined."""
         if self._scope is not None:
@@ -173,6 +216,46 @@ def dynamic_dispatch(func, name=None, index=None):
     return new_class(func, name, index)


+class LazyDispatch:
+    """Dynamic call dispatcher.
+
+    Examples
+    --------
+    >>> class ExampleObj:
+    ...     def get_42(self, offset):
+    ...         return 42 + offset

+    >>> lazy_obj = lazy_dispatch(ExampleObj())
+    >>> lazy_obj.get_42(0)()
+    42
+    """
+
+    def __init__(self, obj: Any, name=None, index=None):
+        self._obj = obj
+        self._name = name
+        self._index = index
+
+    def __call__(self, *args, **kws) -> Any:
+        obj = self._obj
+        for n in self._name.split("."):
+            obj = getattr(obj, n)
+        if self._index is not None:
+            obj = getitem(obj, self._index)
+        return Suggester(lambda: obj(*args, **kws))
+
+    def __getattr__(self, name: str) -> Any:
+        if self._name is not None:
+            name = f"{self._name}.{name}"
+        return lazy_dispatch(self._obj, name, self._index)
+
+    def __getitem__(self, index):
+        return lazy_dispatch(self._obj, self._name, index)
+
+
+def lazy_dispatch(obj, name=None, index=None):
+    """Wraps an object for lazy dispatch."""
+    return LazyDispatch(obj, name, index)
+
+
 class HyperParameter(dict):
     """HyperParameter is an extended dict designed for parameter storage.
```
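One consequence of unwrapping at read time is that each read of a suggested parameter re-invokes the callback. Under that assumption, `suggest_from` can even drive a crude random search without `optuna`; a sketch (the helper names are from this commit, the search loop is illustrative only):

```python
import random
from hyperparameter import param_scope, auto_param, suggest_from

@auto_param
def objective(x=0.0):
    return (x - 2) ** 2

# Each call to objective() re-reads "objective.x", which re-invokes
# the callback and draws a fresh candidate value.
with param_scope(**{"objective.x": suggest_from(lambda: random.uniform(-10, 10))}):
    best = min(objective() for _ in range(100))
print(best)  # best objective value found by the random search
```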
