Skip to content

Commit 00e1fa5

Browse files
committed
new build system
1 parent 7b6d015 commit 00e1fa5

File tree

11 files changed

+125
-96
lines changed

11 files changed

+125
-96
lines changed

hyperparameter/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from .api import auto_param, param_scope
2+
import importlib.metadata
23

34
__all__ = ["param_scope", "auto_param"]
45

5-
VERSION = "0.5.0"
6+
VERSION = importlib.metadata.version("hyperparameter")
7+
8+
import os
9+
include = os.path.dirname(__file__)
10+
try:
11+
import hyperparameter.rbackend
12+
lib = hyperparameter.rbackend.__file__
13+
except:
14+
lib = ""

hyperparameter/api.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Callable, Dict
44

55
from hyperparameter.storage import TLSKVStorage
6+
from .tune import Suggester
67

78

89
def _repr_dict(d):
@@ -16,7 +17,7 @@ class _DynamicDispatch:
1617

1718
__slots__ = ("_func", "_name")
1819

19-
def __get__(self): # a trick that let doctest descover this class
20+
def __get__(self): # a trick that let doctest discover this class
2021
pass
2122

2223
def __init__(self, func: Callable, name=None):
@@ -91,7 +92,7 @@ def _dynamic_dispatch(func, name=None):
9192

9293

9394
class _ParamAccessor:
94-
"""Accessor that handles missing parameters and default values
95+
"""Missing parameter and default value handler for hyperparameters.
9596
9697
Examples
9798
---------
@@ -122,6 +123,8 @@ def get_or_else(self, default: Any = None):
122123
value = self._root.get(self._name)
123124
if isinstance(value, _ParamAccessor):
124125
return default
126+
if isinstance(value, Suggester):
127+
return value()
125128
if type(default) is bool and isinstance(value, str):
126129
if value is None:
127130
return False
@@ -152,11 +155,14 @@ def get_or_else(self, default: Any = None):
152155
return float(value)
153156
except Exception as exc:
154157
return value
155-
try:
156-
return type(default)(value)
157-
except Exception as exc:
158-
# raise exc
159-
return value
158+
if type(default) is float:
159+
try:
160+
return float(value)
161+
except:
162+
return value
163+
if type(default) is str:
164+
return str(value)
165+
return value
160166

161167
def __getitem__(self, index: str) -> Any:
162168
return self.__getattr__(index)
@@ -234,7 +240,7 @@ def storage(self):
234240
return self._storage
235241

236242
def __getitem__(self, key: str) -> Any:
237-
"""get parameter with dict-style api
243+
"""dict-style api for parameter reading.
238244
Examples
239245
----------
240246
>>> hp = _HyperParameter(param1=1, obj1={"prop1": "a"})
@@ -249,7 +255,7 @@ def __getitem__(self, key: str) -> Any:
249255
return _ParamAccessor(self._storage, key)
250256

251257
def __setitem__(self, key: str, value: Any) -> None:
252-
"""set parameter with dict-style api
258+
"""dict-style api for parameter writing.
253259
Examples
254260
----------
255261
>>> hp = _HyperParameter(param1=1, obj1={"prop1": "a"})
@@ -271,7 +277,7 @@ def __setitem__(self, key: str, value: Any) -> None:
271277
return self._storage.put(key, value)
272278

273279
def __getattr__(self, name: str) -> Any:
274-
"""get parameter with object-style api
280+
"""object-style api for parameter reading.
275281
Examples
276282
--------
277283
>>> hp = _HyperParameter(param1=1, obj1={"prop1": "a"})
@@ -285,7 +291,7 @@ def __getattr__(self, name: str) -> Any:
285291
return _ParamAccessor(self, name)
286292

287293
def __setattr__(self, name: str, value: Any) -> None:
288-
"""set parameter with object-style api
294+
"""object-style api for parameter writing.
289295
Examples
290296
--------
291297
>>> hp = _HyperParameter()
@@ -383,16 +389,17 @@ def __enter__(self):
383389
384390
Examples
385391
--------
386-
>>> param_scope.p = "origin"
387-
>>> with param_scope(**{"p": "origin"}) as ps:
388-
... ps.storage().storage() # outer scope
389-
... with param_scope() as ps: # unmodified scope
390-
... ps.storage().storage() # inner scope
391-
... with param_scope(**{"p": "modified"}) as ps: # modified scope
392-
... ps.storage().storage() # inner scope with modified params
393-
... _ = param_scope(**{"p": "modified"}) # not used in with ctx
394-
... with param_scope() as ps: # unmodified scope
395-
... ps.storage().storage() # inner scope
392+
>>> with param_scope():
393+
... param_scope.p = "origin"
394+
... with param_scope(**{"p": "origin"}) as ps:
395+
... ps.storage().storage() # outer scope
396+
... with param_scope() as ps: # unmodified scope
397+
... ps.storage().storage() # inner scope
398+
... with param_scope(**{"p": "modified"}) as ps: # modified scope
399+
... ps.storage().storage() # inner scope with modified params
400+
... _ = param_scope(**{"p": "modified"}) # not used in with ctx
401+
... with param_scope() as ps: # unmodified scope
402+
... ps.storage().storage() # inner scope
396403
{'p': 'origin'}
397404
{'p': 'origin'}
398405
{'p': 'modified'}
File renamed without changes.

hyperparameter/tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def suggest_from(callback: Callable) -> Suggester:
2323
... index, self._offset = self._offset % len(self._lst), self._offset + 1
2424
... return self._lst[index]
2525
26-
>>> from hyperparameter import param_scope, suggest_from
26+
>>> from hyperparameter import param_scope
2727
>>> with param_scope(suggested = suggest_from(ValueWrapper([1,2,3]))) as ps:
2828
... ps().suggested()
2929
... ps().suggested()

pyproject.toml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
[build-system]
2-
requires = ["setuptools", "setuptools-scm", "wheel", "setuptools-rust"]
3-
build-backend = "setuptools.build_meta"
2+
requires = ["maturin>=0.14,<0.15"]
3+
build-backend = "maturin"
44

55
[project]
66
name = "hyperparameter"
7+
version = "0.5.0"
78
authors = [{ name = "Reiase", email = "[email protected]" }]
89
description = "A hyper-parameter library for researchers, data scientists and machine learning engineers."
910
requires-python = ">=3.7"
1011
readme = "README.md"
1112
license = { text = "Apache License Version 2.0" }
12-
dynamic = ["version"]
1313

14-
[tool.setuptools]
15-
packages = ["hyperparameter", "hparam"]
16-
17-
[tool.setuptools.dynamic]
18-
version = { attr = "hyperparameter.VERSION" }
14+
[tool.maturin]
15+
module-name = "hyperparameter.rbackend"
16+
features = ["pyo3/extension-module"]
17+
include = ["hyperparameter/hyperparameter.h"]
1918

2019
[tool.black]
2120
line-length = 88

setup.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/entry.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub enum Value {
99
Text(CacheString),
1010
Boolen(bool),
1111
UserDefined(*mut c_void),
12-
PyObject(*mut c_void),
12+
PyObject(*mut c_void, unsafe fn(*mut c_void)),
1313
}
1414

1515
impl From<i64> for Value {
@@ -67,7 +67,7 @@ impl TryFrom<Value> for i64 {
6767
.or_else(|_| Err(format!("error convert {} into i64", v))),
6868
Value::Boolen(v) => Ok(v.into()),
6969
Value::UserDefined(_) => Err("data type not matched, `Userdefined` and i64".into()),
70-
Value::PyObject(_) => Err("data type not matched, `PyObject` and i64".into()),
70+
Value::PyObject(_, _) => Err("data type not matched, `PyObject` and i64".into()),
7171
}
7272
}
7373
}
@@ -85,7 +85,7 @@ impl TryFrom<Value> for f64 {
8585
.or_else(|_| Err(format!("error convert {} into i64", v))),
8686
Value::Boolen(_) => Err("data type not matched, `Boolen` and i64".into()),
8787
Value::UserDefined(_) => Err("data type not matched, `Userdefined` and f64".into()),
88-
Value::PyObject(_) => Err("data type not matched, `PyObject` and f64".into()),
88+
Value::PyObject(_, _) => Err("data type not matched, `PyObject` and f64".into()),
8989
}
9090
}
9191
}
@@ -101,7 +101,7 @@ impl TryFrom<Value> for String {
101101
Value::Text(v) => Ok(v.to_string()),
102102
Value::Boolen(v) => Ok(format!("{}", v)),
103103
Value::UserDefined(_) => Err("data type not matched, `Userdefined` and str".into()),
104-
Value::PyObject(_) => Err("data type not matched, `PyObject` and str".into()),
104+
Value::PyObject(_, _) => Err("data type not matched, `PyObject` and str".into()),
105105
}
106106
}
107107
}
@@ -117,12 +117,12 @@ impl TryFrom<Value> for bool {
117117
Value::Text(_) => Err("data type not matched, `Text` and bool".into()),
118118
Value::Boolen(v) => Ok(v),
119119
Value::UserDefined(_) => Err("data type not matched, `Userdefined` and str".into()),
120-
Value::PyObject(_) => Err("data type not matched, `PyObject` and str".into()),
120+
Value::PyObject(_, _) => Err("data type not matched, `PyObject` and str".into()),
121121
}
122122
}
123123
}
124124

125-
#[derive(Clone)]
125+
#[derive(Debug, Clone)]
126126
pub enum EntryValue {
127127
Single(Value),
128128
Versioned(Value, Box<EntryValue>),
@@ -144,7 +144,7 @@ impl EntryValue {
144144
}
145145
}
146146

147-
#[derive(Clone)]
147+
#[derive(Debug, Clone)]
148148
pub struct Entry {
149149
pub key: String,
150150
pub val: EntryValue,
@@ -181,10 +181,17 @@ impl Entry {
181181
}
182182

183183
pub fn rollback(&mut self) -> Result<(), ()> {
184+
let val = self.val.get();
184185
let his = self.val.history();
185186
match his {
186187
None => Err(()),
187188
Some(h) => {
189+
match val {
190+
Value::PyObject(obj, free) => unsafe {
191+
free(*obj);
192+
},
193+
_ => {}
194+
}
188195
self.val = *h.clone();
189196
Ok(())
190197
}

src/ext.rs

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::ffi::c_void;
22

33
use pyo3::exceptions::PyValueError;
4+
use pyo3::ffi::Py_DecRef;
5+
use pyo3::ffi::Py_IncRef;
46
use pyo3::prelude::*;
57
use pyo3::types::PyBool;
68
use pyo3::types::PyDict;
@@ -20,15 +22,6 @@ pub struct KVStorage {
2022
isview: bool,
2123
}
2224

23-
impl KVStorage {
24-
pub fn _storage(&mut self) -> *mut Storage {
25-
match self.isview {
26-
true => MGR.with(|mgr| mgr.borrow_mut().stack.last().unwrap().clone()),
27-
false => &mut self.storage,
28-
}
29-
}
30-
}
31-
3225
#[pymethods]
3326
impl KVStorage {
3427
#[new]
@@ -41,16 +34,15 @@ impl KVStorage {
4134

4235
pub unsafe fn storage(&mut self, py: Python<'_>) -> PyResult<PyObject> {
4336
let res = PyDict::new(py);
44-
let s = self._storage();
45-
for k in (*s).keys().iter() {
46-
match (*s).get(k).unwrap() {
37+
for k in self.storage.keys().iter() {
38+
match self.storage.get(k).unwrap() {
4739
Value::Empty => Ok(()),
4840
Value::Int(v) => res.set_item(k, v),
4941
Value::Float(v) => res.set_item(k, v),
5042
Value::Text(v) => res.set_item(k, v.as_str()),
5143
Value::Boolen(v) => res.set_item(k, v),
5244
Value::UserDefined(v) => res.set_item(k, v as u64),
53-
Value::PyObject(v) => {
45+
Value::PyObject(v, _) => {
5446
res.set_item(k, PyAny::from_owned_ptr(py, v as *mut pyo3::ffi::PyObject))
5547
}
5648
}
@@ -60,8 +52,7 @@ impl KVStorage {
6052
}
6153

6254
pub unsafe fn keys(&mut self, py: Python<'_>) -> PyResult<PyObject> {
63-
let s = self._storage();
64-
let res = PyList::new(py, (*s).keys());
55+
let res = PyList::new(py, self.storage.keys());
6556
Ok(res.into())
6657
}
6758

@@ -86,23 +77,21 @@ impl KVStorage {
8677
}
8778

8879
pub unsafe fn clear(&mut self) {
89-
let s = self._storage();
90-
for k in (*s).keys().iter() {
80+
for k in self.storage.keys().iter() {
9181
self.storage.put(k, Value::Empty);
9282
}
9383
}
9484

9585
pub unsafe fn get(&mut self, py: Python<'_>, key: String) -> PyResult<Option<PyObject>> {
96-
let s = self._storage();
97-
match (*s).get(key) {
86+
match self.storage.get(key) {
9887
Some(val) => match val {
9988
Value::Empty => Err(PyValueError::new_err("not found")),
10089
Value::Int(v) => Ok(Some(v.into_py(py))),
10190
Value::Float(v) => Ok(Some(v.into_py(py))),
10291
Value::Text(v) => Ok(Some(v.into_py(py))),
10392
Value::Boolen(v) => Ok(Some(v.into_py(py))),
10493
Value::UserDefined(v) => Ok(Some((v as u64).into_py(py))),
105-
Value::PyObject(v) => Ok(Some(
94+
Value::PyObject(v, _) => Ok(Some(
10695
PyAny::from_owned_ptr(py, v as *mut pyo3::ffi::PyObject).into(),
10796
)),
10897
},
@@ -111,21 +100,30 @@ impl KVStorage {
111100
}
112101

113102
pub unsafe fn put(&mut self, key: String, val: &PyAny) -> PyResult<()> {
114-
let s = self._storage();
103+
if self.isview {
104+
MGR.with_borrow_mut(|mgr| mgr.put_key(key.clone()));
105+
}
115106
if val.is_none() {
116-
(*s).put(key, Value::Empty);
107+
self.storage.put(key, Value::Empty);
117108
return Ok(());
118109
}
119110
if val.is_instance_of::<PyBool>().unwrap() {
120-
(*s).put(key, val.extract::<bool>().unwrap());
111+
self.storage.put(key, val.extract::<bool>().unwrap());
121112
} else if val.is_instance_of::<PyFloat>().unwrap() {
122-
(*s).put(key, val.extract::<f64>().unwrap());
113+
self.storage.put(key, val.extract::<f64>().unwrap());
123114
} else if val.is_instance_of::<PyString>().unwrap() {
124-
(*s).put(key, val.extract::<&str>().unwrap());
115+
self.storage.put(key, val.extract::<&str>().unwrap());
125116
} else if val.is_instance_of::<PyInt>().unwrap() {
126-
(*s).put(key, val.extract::<i64>().unwrap());
117+
self.storage.put(key, val.extract::<i64>().unwrap());
127118
} else {
128-
(*s).put(key, Value::PyObject(val.into_ptr() as *mut c_void));
119+
// TODO support release pyobj
120+
Py_IncRef(val.into_ptr());
121+
self.storage.put(
122+
key,
123+
Value::PyObject(val.into_ptr() as *mut c_void, |obj: *mut c_void| {
124+
Py_DecRef(obj as *mut pyo3::ffi::PyObject);
125+
}),
126+
);
129127
}
130128
Ok(())
131129
}
@@ -140,10 +138,12 @@ impl KVStorage {
140138

141139
#[staticmethod]
142140
pub fn current() -> KVStorage {
143-
KVStorage {
141+
let mut kv = KVStorage {
144142
storage: Storage::new(),
145143
isview: true,
146-
}
144+
};
145+
kv.storage.isview = 1;
146+
kv
147147
}
148148
}
149149

0 commit comments

Comments
 (0)