Skip to content

Commit 725e4c1

Browse files
committed
fix frozen for python backend
1 parent 0008af8 commit 725e4c1

File tree

11 files changed

+124
-15
lines changed

11 files changed

+124
-15
lines changed

Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ pyo3 = { version = "0.18.1", features = [
1717
"abi3",
1818
"abi3-py37",
1919
] }
20+
lazy_static = "1.4.0"
2021

21-
[dev_dependencies]
22-
rspec = "1.0"
22+
# [dev_dependencies]
23+
# rspec = "1.0"
2324

2425
[profile.dev]
25-
overflow-checks=false
26+
overflow-checks=false
File renamed without changes.
File renamed without changes.
File renamed without changes.

hyperparameter/api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,11 @@ def current():
461461
def init(params=None):
462462
"""init param_scope for a new thread."""
463463
param_scope(**params).__enter__()
464+
465+
@staticmethod
466+
def frozen():
467+
with param_scope():
468+
TLSKVStorage.frozen()
464469

465470

466471
_param_scope = param_scope._func

hyperparameter/storage.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import os
12
import threading
23
from typing import Any, Callable, Dict, Iterable
34

45

6+
GLOBAL_STORAGE = {}
7+
8+
59
class Storage:
610
"""Base class for all storage implementations"""
711

@@ -55,7 +59,13 @@ def __init__(self, parent=None) -> None:
5559
self._parent = parent
5660
super().__init__()
5761

58-
if hasattr(TLSKVStorage.tls, "his") and len(TLSKVStorage.tls.his) > 0:
62+
if not hasattr(TLSKVStorage.tls, "his"):
63+
TLSKVStorage.tls.his = [TLSKVStorage.__new__(TLSKVStorage)]
64+
TLSKVStorage.tls.his[-1]._storage = GLOBAL_STORAGE
65+
TLSKVStorage.tls.his[-1]._parent = None
66+
self.update(GLOBAL_STORAGE)
67+
68+
elif hasattr(TLSKVStorage.tls, "his") and len(TLSKVStorage.tls.his) > 0:
5969
parent = TLSKVStorage.tls.his[-1]
6070
self.update(parent._storage)
6171

@@ -86,6 +96,7 @@ def _update(values={}, prefix=None):
8696
_update(v, prefix=key)
8797
else:
8898
storage[key] = v
99+
89100
if kws is not None:
90101
return _update(kws, prefix=None)
91102

@@ -123,10 +134,19 @@ def current():
123134
TLSKVStorage.tls.his = [TLSKVStorage()]
124135
return TLSKVStorage.tls.his[-1]
125136

137+
@staticmethod
138+
def frozen():
139+
GLOBAL_STORAGE.update(TLSKVStorage.tls.his[-1].storage())
140+
141+
126142
try:
127-
from hyperparameter.rbackend import KVStorage
128-
TLSKVStorage = KVStorage
143+
if os.environ.get("HYPERPARAMETER_BACKEND", "RUST") == "RUST":
144+
from hyperparameter.rbackend import KVStorage
145+
TLSKVStorage = KVStorage
146+
print("using native hyperparameter backend")
147+
else:
148+
print("using python hyperparameter backend")
129149
except:
130150
import traceback
131151
traceback.print_exc()
132-
pass
152+
print("using python hyperparameter backend")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ source = ["hyperparameter"]
3636
[tool.pytest.ini_options]
3737
minversion = "6.0"
3838
addopts = "-ra -q --durations=5 --doctest-modules --doctest-glob=*.md"
39-
testpaths = ["hyperparameter/", "docs"]
39+
testpaths = ["hyperparameter/", "docs", "tests/"]

src/ext.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ use pyo3::types::PyString;
1414
use pyo3::FromPyPointer;
1515

1616
use crate::entry::Value;
17+
use crate::storage::frozen_as_global_storage;
1718
use crate::storage::Storage;
18-
use crate::storage::MGR;
1919
use crate::storage::StorageManager;
20+
use crate::storage::MGR;
2021

2122
#[pyclass]
2223
pub struct KVStorage {
@@ -148,6 +149,11 @@ impl KVStorage {
148149
kv.storage.isview = 1;
149150
kv
150151
}
152+
153+
#[staticmethod]
154+
pub fn frozen() {
155+
frozen_as_global_storage();
156+
}
151157
}
152158

153159
#[pymodule]

src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// #![feature(local_key_cell_methods)]
2-
// #![feature(let_chains)]
3-
41
pub mod entry;
52
pub mod storage;
63

src/storage.rs

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ use std::cell::{Ref, RefCell, RefMut};
22
use std::collections::BTreeMap;
33
use std::collections::HashSet;
44
use std::rc::Rc;
5+
use std::sync::Mutex;
6+
7+
use lazy_static::lazy_static;
58

69
use crate::entry::{Entry, EntryValue, Value};
710
use crate::xxh::xxhstr;
@@ -48,15 +51,65 @@ thread_local! {
4851
}
4952

5053
pub fn init_storage_manager() -> RefCell<StorageManager> {
54+
let mut tree = Tree::new();
55+
global_storage_get(&mut tree);
5156
let sm = RefCell::new(StorageManager {
52-
tls: Rc::new(RefCell::new(Tree::new())),
57+
tls: Rc::new(RefCell::new(tree)),
5358
stack: Vec::new(),
5459
});
5560
sm.borrow_mut().stack.push(RefCell::new(HashSet::new()));
5661

5762
return sm;
5863
}
5964

65+
lazy_static! {
66+
static ref GLOBAL_STORAGE: Mutex<u64> = {
67+
let tree = Box::new(Tree::new());
68+
Mutex::new(Box::into_raw(tree) as u64)
69+
};
70+
}
71+
72+
fn global_storage_set(t: &Tree) {
73+
GLOBAL_STORAGE
74+
.lock()
75+
.and_then(|v| unsafe {
76+
let ptr = v.clone() as *mut Tree;
77+
match ptr.as_mut() {
78+
Some(tree) => {
79+
tree.clear();
80+
tree.clone_from(t);
81+
}
82+
None => todo!(),
83+
};
84+
Ok(())
85+
})
86+
.unwrap();
87+
}
88+
89+
pub fn frozen_as_global_storage() {
90+
MGR.with(|mgr| {
91+
let t = mgr.borrow().tls.borrow().clone();
92+
global_storage_set(&t);
93+
});
94+
}
95+
96+
fn global_storage_get(t: &mut Tree) {
97+
GLOBAL_STORAGE
98+
.lock()
99+
.and_then(|v| unsafe {
100+
let ptr = v.clone() as *mut Tree;
101+
match ptr.as_mut() {
102+
Some(tree) => {
103+
t.clear();
104+
t.clone_from(tree);
105+
}
106+
None => todo!(),
107+
};
108+
Ok(())
109+
})
110+
.unwrap();
111+
}
112+
60113
#[derive(Debug)]
61114
pub struct Storage {
62115
pub parent: Rc<RefCell<Tree>>,
@@ -114,10 +167,11 @@ impl Storage {
114167

115168
pub fn exit(&mut self) {
116169
MGR.with(|m| {
117-
if let Some(keys) = m.borrow_mut().stack.pop() {
170+
let mut m = m.borrow_mut();
171+
if let Some(keys) = m.stack.pop() {
118172
keys.borrow()
119173
.iter()
120-
.for_each(|k| tree_rollback(m.borrow_mut().tls.borrow_mut(), *k));
174+
.for_each(|k| tree_rollback(m.tls.borrow_mut(), *k));
121175
}
122176
});
123177
// MGR.with_borrow_mut(|m| {

0 commit comments

Comments
 (0)