Skip to content

Commit 9efa6ab

Browse files
authored
support get_by_hash in auto_param (#4)
1 parent fa329d7 commit 9efa6ab

File tree

4 files changed

+106
-2
lines changed

4 files changed

+106
-2
lines changed

hyperparameter/api.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import inspect
33
from typing import Any, Callable, Dict
44

5-
from hyperparameter.storage import TLSKVStorage
5+
from hyperparameter.storage import TLSKVStorage, has_rust_backend
6+
from hyperparameter.storage import xxh64
67
from .tune import Suggester
78

89

@@ -461,7 +462,7 @@ def current():
461462
def init(params=None):
462463
"""init param_scope for a new thread."""
463464
param_scope(**params).__enter__()
464-
465+
465466
@staticmethod
466467
def frozen():
467468
with param_scope():
@@ -520,6 +521,38 @@ def auto_param(name_or_func):
520521
if callable(name_or_func):
521522
return auto_param(None)(name_or_func)
522523

524+
if has_rust_backend:
525+
526+
def hashed_wrapper(func):
527+
predef_kws = {}
528+
529+
if name_or_func is None:
530+
namespace = func.__name__
531+
else:
532+
namespace = name_or_func
533+
534+
signature = inspect.signature(func)
535+
for k, v in signature.parameters.items():
536+
if v.default != v.empty:
537+
name = "{}.{}".format(namespace, k)
538+
predef_kws[k] = xxh64(name)
539+
540+
@functools.wraps(func)
541+
def inner(*arg, **kws):
542+
with param_scope() as hp:
543+
for k, v in predef_kws.items():
544+
if k not in kws:
545+
try:
546+
val = hp._storage.get_by_hash(v)
547+
kws[k] = val
548+
except ValueError:
549+
pass
550+
return func(*arg, **kws)
551+
552+
return inner
553+
554+
return hashed_wrapper
555+
523556
def wrapper(func):
524557
predef_kws = {}
525558
predef_val = {}

hyperparameter/storage.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def _update(values={}, prefix=None):
102102

103103
def clear(self):
104104
self._storage.clear()
105+
106+
def get_by_hash(self, *args, **kwargs):
107+
raise RuntimeError("hyperparameter is not build with rust backend")
105108

106109
def get(self, name: str, accessor: Callable = None) -> Any:
107110
if name in self.__slots__:
@@ -139,10 +142,21 @@ def frozen():
139142
GLOBAL_STORAGE.update(TLSKVStorage.tls.his[-1].storage())
140143

141144

145+
has_rust_backend = False
146+
147+
148+
def xxh64(*args, **kwargs):
149+
raise RuntimeError("hyperparameter is not build with rust backend")
150+
151+
142152
try:
143153
if os.environ.get("HYPERPARAMETER_BACKEND", "RUST") == "RUST":
144154
from hyperparameter.rbackend import KVStorage
155+
from hyperparameter.rbackend import xxh64
156+
145157
TLSKVStorage = KVStorage
158+
has_rust_backend = True
146159
except:
147160
import traceback
161+
148162
traceback.print_exc()

src/ext.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::storage::frozen_as_global_storage;
1818
use crate::storage::Storage;
1919
use crate::storage::StorageManager;
2020
use crate::storage::MGR;
21+
use crate::xxh::xxhstr;
2122

2223
#[repr(C)]
2324
enum UserDefinedType {
@@ -127,6 +128,28 @@ impl KVStorage {
127128
}
128129
}
129130

131+
pub unsafe fn get_by_hash(&mut self, py: Python<'_>, hkey: u64) -> PyResult<Option<PyObject>> {
132+
match self.storage.get_by_hash(hkey) {
133+
Some(val) => match val {
134+
Value::Empty => Err(PyValueError::new_err("not found")),
135+
Value::Int(v) => Ok(Some(v.into_py(py))),
136+
Value::Float(v) => Ok(Some(v.into_py(py))),
137+
Value::Text(v) => Ok(Some(v.into_py(py))),
138+
Value::Boolen(v) => Ok(Some(v.into_py(py))),
139+
Value::UserDefined(v, k, _) => {
140+
if k == UserDefinedType::PyObjectType as i32 {
141+
Ok(Some(
142+
PyAny::from_borrowed_ptr(py, v as *mut pyo3::ffi::PyObject).into(),
143+
))
144+
} else {
145+
Ok(Some((v as u64).into_py(py)))
146+
}
147+
}
148+
},
149+
None => Err(PyValueError::new_err("not found")),
150+
}
151+
}
152+
130153
pub unsafe fn put(&mut self, key: String, val: &PyAny) -> PyResult<()> {
131154
if self.isview {
132155
MGR.with(|mgr: &RefCell<StorageManager>| mgr.borrow_mut().put_key(key.clone()));
@@ -172,8 +195,14 @@ impl KVStorage {
172195
}
173196
}
174197

198+
#[pyfunction]
199+
pub fn xxh64(s: &str) -> u64 {
200+
xxhstr(s)
201+
}
202+
175203
#[pymodule]
176204
fn rbackend(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
177205
m.add_class::<KVStorage>()?;
206+
m.add_function(wrap_pyfunction!(xxh64, m)?)?;
178207
Ok(())
179208
}

tests/test_auto_param.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from unittest import TestCase
2+
3+
from hyperparameter import auto_param, param_scope
4+
5+
6+
class TestAutoParam(TestCase):
7+
def test_auto_param_func(self):
8+
@auto_param("foo")
9+
def foo(a, b=1, c=2.0, d=False, e="str"):
10+
return a, b, c, d, e
11+
12+
with param_scope(**{"foo.b": 2}):
13+
self.assertEqual(foo(1), (1, 2, 2.0, False, "str"))
14+
15+
with param_scope(**{"foo.c": 3.0}):
16+
self.assertEqual(foo(1), (1, 1, 3.0, False, "str"))
17+
18+
def test_auto_param_func2(self):
19+
@auto_param("foo")
20+
def foo(a, b=1, c=2.0, d=False, e="str"):
21+
return a, b, c, d, e
22+
23+
with param_scope():
24+
param_scope.foo.b = 2
25+
self.assertEqual(foo(1), (1, 2, 2.0, False, "str"))
26+
param_scope.foo.c = 3.0
27+
self.assertEqual(foo(1), (1, 2, 3.0, False, "str"))
28+
self.assertEqual(foo(1), (1, 1, 2.0, False, "str"))

0 commit comments

Comments
 (0)