Skip to content

Commit c4d8549

Browse files
committed
support pase bool from string
1 parent 83e1811 commit c4d8549

File tree

5 files changed

+86
-33
lines changed

5 files changed

+86
-33
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pyo3 = { version = "0.18.1", features = [
1717
"abi3-py37",
1818
] }
1919
lazy_static = "1.4.0"
20+
phf = { version = "0.11", features = ["macros"] }
2021

2122
[dev_dependencies]
2223
rspec = "1.0"

examples/cpp/cxx_test.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ int main()
4646
<< "expected: abc" << std::endl
4747
<< "returned: " << abc << std::endl
4848
<< "expected: 0" << std::endl
49-
<< "returned: " << GETPARAM(a.b, 1) << std::endl;
49+
<< "returned: " << GETPARAM(a.b, 1) << std::endl
50+
<< "expected: false" << std::endl
51+
<< "returned: " << GETPARAM(a.b, "true") << std::endl;
5052

5153
std::cout << "\n:: (opt api) test undefined" << std::endl
5254
<< "expected: 100" << std::endl
@@ -56,5 +58,15 @@ int main()
5658

5759
std::cout << "test1.test2: " << GETPARAM(test1.test2, 100) << std::endl;
5860

61+
// ===== bool test ====
62+
std::cout << "\n:: test bool parameter" << std::endl
63+
<< "expected: true" << std::endl
64+
<< "returned: " << GETPARAM(test1.bool1, false) << std::endl
65+
<< "expected: true" << std::endl
66+
<< "returned: " << GETPARAM(test1.bool2, false) << std::endl
67+
<< "expected: false" << std::endl
68+
<< "returned: " << GETPARAM(test1.bool3, true) << std::endl
69+
<< "expected: false" << std::endl
70+
<< "returned: " << GETPARAM(test1.bool4, true) << std::endl;
5971
return 0;
6072
}

examples/cpp/cxx_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,8 @@
1010

1111
with param_scope():
1212
param_scope.test1.test2 = 2
13+
param_scope.test1.bool1 = "true"
14+
param_scope.test1.bool2 = "YES"
15+
param_scope.test1.bool3 = "FALSE"
16+
param_scope.test1.bool4 = "NO"
1317
a.main()

hyperparameter/hyperparameter.h

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace hyperparameter
3232
{
3333
return finalize((len >= 32 ? h32bytes(p, len, seed) : seed + PRIME5) + len, p + (len & ~0x1F), len & 0x1F);
3434
}
35-
35+
3636
private:
3737
static constexpr uint64_t PRIME1 = 11400714785074694791ULL;
3838
static constexpr uint64_t PRIME2 = 14029467366897019727ULL;
@@ -98,114 +98,116 @@ namespace hyperparameter
9898
Hyperparameter() : _storage(hyper_create_storage()) {}
9999
~Hyperparameter() { hyper_destory_storage(_storage); }
100100

101-
void enter() { storage_enter(_storage); }
102-
void exit() { storage_exit(_storage); }
101+
inline void enter() { storage_enter(_storage); }
102+
inline void exit() { storage_exit(_storage); }
103103

104104
template <typename T>
105-
T get(uint64_t key, T def);
105+
inline T get(uint64_t key, T def);
106106

107107
template <typename T>
108-
T get(const std::string &key, T def) { return get(key.c_str(), key.size(), def); }
108+
inline T get(const std::string &key, T def) { return get(key.c_str(), key.size(), def); }
109109

110110
template <typename T>
111-
T get(const char *key, int keylen, T def) { return get(xxhash(key, keylen), def); }
111+
inline T get(const char *key, int keylen, T def) { return get(xxhash(key, keylen), def); }
112112

113113
template <typename T>
114-
void put(const std::string &key, T val) { put(key.c_str(), val); }
114+
inline void put(const std::string &key, T val) { put(key.c_str(), val); }
115115

116116
template <typename T>
117-
void put(const char *key, T val);
117+
inline void put(const char *key, T val);
118118
};
119119

120-
Hyperparameter *create() { return new Hyperparameter(); }
121-
std::shared_ptr<Hyperparameter> create_shared() { return std::make_shared<Hyperparameter>(); }
120+
inline Hyperparameter *create() { return new Hyperparameter(); }
121+
inline std::shared_ptr<Hyperparameter> create_shared() { return std::make_shared<Hyperparameter>(); }
122122

123123
template <>
124-
int64_t Hyperparameter::get<int64_t>(uint64_t key, int64_t def)
124+
inline int64_t Hyperparameter::get<int64_t>(uint64_t key, int64_t def)
125125
{
126126
return storage_hget_or_i64(_storage, key, def);
127127
}
128128

129129
template <>
130-
int32_t Hyperparameter::get<int32_t>(uint64_t key, int32_t def)
130+
inline int32_t Hyperparameter::get<int32_t>(uint64_t key, int32_t def)
131131
{
132132
return storage_hget_or_i64(_storage, key, def);
133133
}
134134

135135
template <>
136-
double Hyperparameter::get<double>(uint64_t key, double def)
136+
inline double Hyperparameter::get<double>(uint64_t key, double def)
137137
{
138138
return storage_hget_or_f64(_storage, key, def);
139139
}
140140

141141
template <>
142-
bool Hyperparameter::get<bool>(uint64_t key, bool def)
142+
inline bool Hyperparameter::get<bool>(uint64_t key, bool def)
143143
{
144144
return storage_hget_or_bool(_storage, key, def);
145145
}
146146

147147
template <>
148-
std::string Hyperparameter::get<std::string>(uint64_t key, std::string def)
148+
inline std::string Hyperparameter::get<std::string>(uint64_t key, std::string def)
149149
{
150150
return std::string(storage_hget_or_str(_storage, key, def.c_str()));
151151
}
152152

153153
template <>
154-
const char *Hyperparameter::get<const char *>(uint64_t key, const char *def)
154+
inline const char *Hyperparameter::get<const char *>(uint64_t key, const char *def)
155155
{
156156
return storage_hget_or_str(_storage, key, def);
157157
}
158158

159159
template <>
160-
void Hyperparameter::put<int64_t>(const char *key, int64_t val)
160+
inline void Hyperparameter::put<int64_t>(const char *key, int64_t val)
161161
{
162162
return storage_put_i64(_storage, key, val);
163163
}
164164

165165
template <>
166-
void Hyperparameter::put<int32_t>(const char *key, int32_t val)
166+
inline void Hyperparameter::put<int32_t>(const char *key, int32_t val)
167167
{
168168
return storage_put_i64(_storage, key, val);
169169
}
170170

171171
template <>
172-
void Hyperparameter::put<double>(const char *key, double val)
172+
inline void Hyperparameter::put<double>(const char *key, double val)
173173
{
174174
return storage_put_f64(_storage, key, val);
175175
}
176176

177177
template <>
178-
void Hyperparameter::put<bool>(const char *key, bool val)
178+
inline void Hyperparameter::put<bool>(const char *key, bool val)
179179
{
180180
return storage_put_bool(_storage, key, val);
181181
}
182182

183183
template <>
184-
void Hyperparameter::put<const std::string &>(const char *key, const std::string &val)
184+
inline void Hyperparameter::put<const std::string &>(const char *key, const std::string &val)
185185
{
186186
return storage_put_str(_storage, key, val.c_str());
187187
}
188188

189189
template <>
190-
void Hyperparameter::put<const char *>(const char *key, const char *val)
190+
inline void Hyperparameter::put<const char *>(const char *key, const char *val)
191191
{
192192
return storage_put_str(_storage, key, val);
193193
}
194194

195-
std::shared_ptr<hyperparameter::Hyperparameter> get_hp() {
196-
static std::shared_ptr<Hyperparameter> hp;
197-
if (!hp) {
198-
hp = hyperparameter::create_shared();
199-
}
200-
return hp;
195+
inline std::shared_ptr<hyperparameter::Hyperparameter> get_hp()
196+
{
197+
static std::shared_ptr<Hyperparameter> hp;
198+
if (!hp)
199+
{
200+
hp = hyperparameter::create_shared();
201+
}
202+
return hp;
201203
}
202204
}
203205

204-
#define GETHP hyperparameter::get_hp()
206+
#define GETHP hyperparameter::get_hp()
205207

206208
// Implicit create hyperparameter object
207-
#define GETPARAM(p, default_val) \
208-
(GETHP->get(([](){ constexpr uint64_t x = hyperparameter::xxhash(#p,sizeof(#p)-1); return x;})(), default_val))
209+
#define GETPARAM(p, default_val) \
210+
(GETHP->get(([]() { constexpr uint64_t x = hyperparameter::xxhash(#p,sizeof(#p)-1); return x; })(), default_val))
209211
#define PUTPARAM(p, default_val) (GETHP->put(#p, default_val))
210212

211213
#endif

src/entry.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::{ffi::c_void, sync::Arc};
2+
use phf::phf_map;
23

34
#[derive(Debug, Clone, PartialEq)]
45
pub struct DeferUnsafe(pub *mut c_void, pub unsafe fn(*mut c_void));
@@ -126,6 +127,34 @@ impl TryFrom<Value> for String {
126127
}
127128
}
128129

130+
static STR2BOOL: phf::Map<&'static str, bool> = phf_map! {
131+
"true" => true,
132+
"True" => true,
133+
"TRUE" => true,
134+
"T" => true,
135+
"yes" => true,
136+
"y" => true,
137+
"Yes" => true,
138+
"YES" => true,
139+
"Y" => true,
140+
"on" => true,
141+
"On" => true,
142+
"ON" => true,
143+
144+
"false" => false,
145+
"False" => false,
146+
"FALSE" => false,
147+
"F" => false,
148+
"no" => false,
149+
"n" => false,
150+
"No" => false,
151+
"NO" => false,
152+
"N" => false,
153+
"off" => false,
154+
"Off" => false,
155+
"OFF" => false,
156+
};
157+
129158
impl TryFrom<Value> for bool {
130159
type Error = String;
131160

@@ -134,7 +163,12 @@ impl TryFrom<Value> for bool {
134163
Value::Empty => Err("empty value error".into()),
135164
Value::Int(v) => Ok(v != 0),
136165
Value::Float(_) => Err("data type not matched, `Float` and bool".into()),
137-
Value::Text(_) => Err("data type not matched, `Text` and bool".into()),
166+
Value::Text(s) => {
167+
match STR2BOOL.get(&s) {
168+
Some(v) => Ok(v.clone()),
169+
None => Err("data type not matched, `Text` and bool".into()),
170+
}
171+
},
138172
Value::Boolen(v) => Ok(v),
139173
Value::UserDefined(_, _, _) => {
140174
Err("data type not matched, `UserDefined` and str".into())

0 commit comments

Comments
 (0)