Skip to content

Commit 30b0cde

Browse files
authored
Optimize GP predictions for Egor optimizer usage (#332)
* Add correlation models bench
* Simplify diff matrix computation
* Implement predict_valvar for GP to compute value and var in one go
* Use predict_valvar where needed
* Implement predict_valvar_gradients
* Add sanity check test for valvar variant prediction
* Use predict_valvar_gradients
1 parent 08ce040 commit 30b0cde

File tree

15 files changed

+783
-268
lines changed

15 files changed

+783
-268
lines changed

crates/ego/benches/ego.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use criterion::{Criterion, criterion_group, criterion_main};
2-
use egobox_ego::{EgorBuilder, InfillStrategy};
2+
use egobox_ego::{EGOBOX_LOG, EgorBuilder, InfillStrategy};
33
use egobox_moe::{CorrelationSpec, RegressionSpec};
4+
use env_logger::{Builder, Env};
45
use ndarray::{Array2, ArrayView2, Zip, array};
56

67
/// Ackley test function: min f(x)=0 at x=(0, 0, 0)
@@ -15,15 +16,21 @@ fn ackley(x: &ArrayView2<f64>) -> Array2<f64> {
1516
fn criterion_ego(c: &mut Criterion) {
1617
let xlimits = array![[-32.768, 32.768], [-32.768, 32.768], [-32.768, 32.768]];
1718
let mut group = c.benchmark_group("ego");
19+
group.sample_size(20);
1820
group.bench_function("ego ackley", |b| {
21+
let env = Env::new().filter_or(EGOBOX_LOG, "error");
22+
let mut builder = Builder::from_env(env);
23+
let builder = builder.target(env_logger::Target::Stdout);
24+
builder.try_init().ok();
25+
1926
b.iter(|| {
2027
std::hint::black_box(
2128
EgorBuilder::optimize(ackley)
2229
.configure(|config| {
2330
config
2431
.configure_gp(|conf| {
2532
conf.regression_spec(RegressionSpec::CONSTANT)
26-
.correlation_spec(CorrelationSpec::ABSOLUTEEXPONENTIAL)
33+
.correlation_spec(CorrelationSpec::MATERN52)
2734
})
2835
.infill_strategy(InfillStrategy::WB2S)
2936
.max_iters(10)

crates/ego/src/criteria/ei.rs

Lines changed: 72 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,20 @@ impl InfillCriterion for ExpectedImprovement {
2828
_scale: Option<f64>,
2929
) -> f64 {
3030
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
31-
match obj_model.predict(&pt) {
32-
Ok(p) => match obj_model.predict_var(&pt) {
33-
Ok(s) => {
34-
if s[0] < f64::EPSILON {
35-
0.0
36-
} else {
37-
let pred = p[0];
38-
let k = sigma_weight.unwrap_or(1.0);
39-
let sigma = k * s[0].sqrt();
40-
let args0 = (fmin - pred) / sigma;
41-
let args1 = args0 * norm_cdf(args0);
42-
let args2 = norm_pdf(args0);
43-
sigma * (args1 + args2)
44-
}
31+
match obj_model.predict_valvar(&pt) {
32+
Ok((p, s)) => {
33+
if s[0] < f64::EPSILON {
34+
0.0
35+
} else {
36+
let pred = p[0];
37+
let k = sigma_weight.unwrap_or(1.0);
38+
let sigma = k * s[0].sqrt();
39+
let args0 = (fmin - pred) / sigma;
40+
let args1 = args0 * norm_cdf(args0);
41+
let args2 = norm_pdf(args0);
42+
sigma * (args1 + args2)
4543
}
46-
_ => 0.0,
47-
},
44+
}
4845
_ => 0.0,
4946
}
5047
}
@@ -60,36 +57,32 @@ impl InfillCriterion for ExpectedImprovement {
6057
_scale: Option<f64>,
6158
) -> Array1<f64> {
6259
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
63-
match obj_model.predict(&pt) {
64-
Ok(p) => match obj_model.predict_var(&pt) {
65-
Ok(s) => {
66-
if s[0] < f64::EPSILON {
67-
Array1::zeros(pt.len())
68-
} else {
69-
let pred = p[0];
70-
let diff_y = fmin - pred;
71-
let k = sigma_weight.unwrap_or(1.0);
72-
let sigma = s[0].sqrt();
73-
let arg = (fmin - pred) / (k * sigma);
74-
let y_prime = obj_model.predict_gradients(&pt).unwrap();
75-
let y_prime = y_prime.row(0);
76-
let sig_2_prime = obj_model.predict_var_gradients(&pt).unwrap();
77-
78-
let sig_2_prime = sig_2_prime.row(0);
79-
let sig_prime = sig_2_prime.mapv(|v| k * v / (2. * sigma));
80-
let arg_prime = y_prime.mapv(|v| v / (-k * sigma))
81-
- diff_y.to_owned() * sig_prime.mapv(|v| v / (k * sigma * k * sigma));
82-
let factor = k * sigma * (-arg / SQRT_2PI) * (-(arg * arg) / 2.).exp();
83-
84-
let arg1 = y_prime.mapv(|v| v * (-norm_cdf(arg)));
85-
let arg2 = diff_y * norm_pdf(arg) * arg_prime.to_owned();
86-
let arg3 = sig_prime.to_owned() * norm_pdf(arg);
87-
let arg4 = factor * arg_prime;
88-
arg1 + arg2 + arg3 + arg4
89-
}
60+
match obj_model.predict_valvar(&pt) {
61+
Ok((p, s)) => {
62+
if s[0] < f64::EPSILON {
63+
Array1::zeros(pt.len())
64+
} else {
65+
let pred = p[0];
66+
let diff_y = fmin - pred;
67+
let k = sigma_weight.unwrap_or(1.0);
68+
let sigma = s[0].sqrt();
69+
let arg = (fmin - pred) / (k * sigma);
70+
71+
let (y_prime, var_prime) = obj_model.predict_valvar_gradients(&pt).unwrap();
72+
let y_prime = y_prime.row(0);
73+
let sig_2_prime = var_prime.row(0);
74+
let sig_prime = sig_2_prime.mapv(|v| k * v / (2. * sigma));
75+
let arg_prime = y_prime.mapv(|v| v / (-k * sigma))
76+
- diff_y.to_owned() * sig_prime.mapv(|v| v / (k * sigma * k * sigma));
77+
let factor = k * sigma * (-arg / SQRT_2PI) * (-(arg * arg) / 2.).exp();
78+
79+
let arg1 = y_prime.mapv(|v| v * (-norm_cdf(arg)));
80+
let arg2 = diff_y * norm_pdf(arg) * arg_prime.to_owned();
81+
let arg3 = sig_prime.to_owned() * norm_pdf(arg);
82+
let arg4 = factor * arg_prime;
83+
arg1 + arg2 + arg3 + arg4
9084
}
91-
_ => Array1::zeros(pt.len()),
92-
},
85+
}
9386
_ => Array1::zeros(pt.len()),
9487
}
9588
}
@@ -120,20 +113,17 @@ impl InfillCriterion for LogExpectedImprovement {
120113
) -> f64 {
121114
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
122115

123-
match obj_model.predict(&pt) {
124-
Ok(p) => match obj_model.predict_var(&pt) {
125-
Ok(s) => {
126-
if s[0] < f64::EPSILON {
127-
f64::MIN
128-
} else {
129-
let pred = p[0];
130-
let sigma = s[0].sqrt();
131-
let u = (fmin - pred) / sigma;
132-
log_ei_helper(u) + sigma.ln()
133-
}
116+
match obj_model.predict_valvar(&pt) {
117+
Ok((p, s)) => {
118+
if s[0] < f64::EPSILON {
119+
f64::MIN
120+
} else {
121+
let pred = p[0];
122+
let sigma = s[0].sqrt();
123+
let u = (fmin - pred) / sigma;
124+
log_ei_helper(u) + sigma.ln()
134125
}
135-
_ => f64::MIN,
136-
},
126+
}
137127
_ => f64::MIN,
138128
}
139129
}
@@ -150,35 +140,31 @@ impl InfillCriterion for LogExpectedImprovement {
150140
) -> Array1<f64> {
151141
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
152142

153-
match obj_model.predict(&pt) {
154-
Ok(p) => match obj_model.predict_var(&pt) {
155-
Ok(s) => {
156-
if s[0] < f64::EPSILON {
157-
Array1::from_elem(pt.len(), f64::MIN)
158-
} else {
159-
let pred = p[0];
160-
let diff_y = fmin - pred;
161-
let sigma = s[0].sqrt();
162-
let arg = diff_y / sigma;
163-
164-
let y_prime = obj_model.predict_gradients(&pt).unwrap();
165-
let y_prime = y_prime.row(0);
166-
let sig_2_prime = obj_model.predict_var_gradients(&pt).unwrap();
167-
let sig_2_prime = sig_2_prime.row(0);
168-
let sig_prime = sig_2_prime.mapv(|v| v / (2. * sigma));
169-
170-
let arg_prime = y_prime.mapv(|v| v / (-sigma))
171-
- diff_y.to_owned() * sig_prime.mapv(|v| v / (sigma * sigma));
172-
173-
let dhelper = d_log_ei_helper(arg);
174-
let arg1 = arg_prime.mapv(|v| dhelper * v);
175-
176-
let arg2 = sig_prime / sigma;
177-
arg1 + arg2
178-
}
143+
match obj_model.predict_valvar(&pt) {
144+
Ok((p, s)) => {
145+
if s[0] < f64::EPSILON {
146+
Array1::from_elem(pt.len(), f64::MIN)
147+
} else {
148+
let pred = p[0];
149+
let diff_y = fmin - pred;
150+
let sigma = s[0].sqrt();
151+
let arg = diff_y / sigma;
152+
153+
let (y_prime, var_prime) = obj_model.predict_valvar_gradients(&pt).unwrap();
154+
let y_prime = y_prime.row(0);
155+
let sig_2_prime = var_prime.row(0);
156+
let sig_prime = sig_2_prime.mapv(|v| v / (2. * sigma));
157+
158+
let arg_prime = y_prime.mapv(|v| v / (-sigma))
159+
- diff_y.to_owned() * sig_prime.mapv(|v| v / (sigma * sigma));
160+
161+
let dhelper = d_log_ei_helper(arg);
162+
let arg1 = arg_prime.mapv(|v| dhelper * v);
163+
164+
let arg2 = sig_prime / sigma;
165+
arg1 + arg2
179166
}
180-
_ => Array1::from_elem(pt.len(), f64::MIN),
181-
},
167+
}
182168
_ => Array1::from_elem(pt.len(), f64::MIN),
183169
}
184170
}

crates/ego/src/gpmix/mixint.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,19 @@ impl GpSurrogate for MixintGpMixture {
606606
self.moe.predict_var(&xcast)
607607
}
608608

609+
fn predict_valvar(
610+
&self,
611+
x: &ArrayView2<f64>,
612+
) -> egobox_moe::Result<(Array1<f64>, Array1<f64>)> {
613+
let mut xcast = if self.work_in_folded_space {
614+
unfold_with_enum_mask(&self.xtypes, x)
615+
} else {
616+
x.to_owned()
617+
};
618+
cast_to_discrete_values_mut(&self.xtypes, &mut xcast);
619+
self.moe.predict_valvar(&xcast)
620+
}
621+
609622
/// Save Moe model in given file.
610623
#[cfg(feature = "persistent")]
611624
fn save(&self, path: &str, format: GpFileFormat) -> egobox_moe::Result<()> {
@@ -660,6 +673,19 @@ impl GpSurrogateExt for MixintGpMixture {
660673
self.moe.predict_var_gradients(&xcast)
661674
}
662675

676+
fn predict_valvar_gradients(
677+
&self,
678+
x: &ArrayView2<f64>,
679+
) -> egobox_moe::Result<(Array2<f64>, Array2<f64>)> {
680+
let mut xcast = if self.work_in_folded_space {
681+
unfold_with_enum_mask(&self.xtypes, x)
682+
} else {
683+
x.to_owned()
684+
};
685+
cast_to_discrete_values_mut(&self.xtypes, &mut xcast);
686+
self.moe.predict_valvar_gradients(&xcast)
687+
}
688+
663689
fn sample(&self, x: &ArrayView2<f64>, n_traj: usize) -> egobox_moe::Result<Array2<f64>> {
664690
let mut xcast = if self.work_in_folded_space {
665691
unfold_with_enum_mask(&self.xtypes, x)

crates/ego/src/solver/coego.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,17 @@ where
198198
) -> Result<Array1<f64>> {
199199
let mut res: Vec<f64> = vec![];
200200
let x = &x.view().insert_axis(Axis(0));
201-
let sigma = obj_model.predict_var(&x.view()).unwrap()[0].sqrt();
201+
202+
let (pred, var) = obj_model.predict_valvar(x)?;
203+
let sigma = var[0].sqrt();
202204
// Use lower trust bound for a minimization
203-
let pred = obj_model.predict(x)?[0] - CSTR_DOUBT * sigma;
205+
let pred = pred[0] - CSTR_DOUBT * sigma;
204206
res.push(pred);
205207
for cstr_model in cstr_models {
206-
let sigma = cstr_model.predict_var(&x.view()).unwrap()[0].sqrt();
208+
let (pred, var) = cstr_model.predict_valvar(x)?;
209+
let sigma = var[0].sqrt();
207210
// Use upper trust bound
208-
res.push(cstr_model.predict(x)?[0] + CSTR_DOUBT * sigma);
211+
res.push(pred[0] + CSTR_DOUBT * sigma);
209212
}
210213
Ok(Array1::from_vec(res))
211214
}

crates/ego/src/solver/solver_computations.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,18 +223,19 @@ where
223223
active: &[usize],
224224
) -> f64 {
225225
let x = Array::from_shape_vec((1, x.len()), x.to_vec()).unwrap();
226-
let sigma = cstr_model.predict_var(&x.view()).unwrap()[0].sqrt();
227-
let cstr_val = cstr_model.predict(&x.view()).unwrap()[0];
226+
227+
let (pred, var) = cstr_model.predict_valvar(&x.view()).unwrap();
228+
let sigma = var[0].sqrt();
229+
let cstr_val = pred[0];
228230

229231
if let Some(grad) = gradient {
232+
let (pred_grad, var_grad) = cstr_model.predict_valvar_gradients(&x.view()).unwrap();
230233
let sigma_prime = if sigma < f64::EPSILON {
231234
0.
232235
} else {
233-
cstr_model.predict_var_gradients(&x.view()).unwrap()[[0, 0]] / (2. * sigma)
236+
var_grad[[0, 0]] / (2. * sigma)
234237
};
235-
let grd = cstr_model
236-
.predict_gradients(&x.view())
237-
.unwrap()
238+
let grd = pred_grad
238239
.row(0)
239240
.mapv(|v| (v + CSTR_DOUBT * sigma_prime) / scale_cstr)
240241
.to_vec();

crates/ego/src/utils/cstr_pof.rs

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,17 @@ use ndarray::{Array1, ArrayView};
88
/// a constraint function wrt the tolerance (ie cstr <= cstr_tol)
99
fn pof(x: &[f64], cstr_model: &dyn MixtureGpSurrogate, cstr_tol: f64) -> f64 {
1010
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
11-
match cstr_model.predict(&pt) {
12-
Ok(p) => match cstr_model.predict_var(&pt) {
13-
Ok(s) => {
14-
if s[0] < f64::EPSILON {
15-
0.0
16-
} else {
17-
let pred = p[0];
18-
let sigma = s[0].sqrt();
19-
let args0 = (cstr_tol - pred) / sigma;
20-
norm_cdf(args0)
21-
}
11+
match cstr_model.predict_valvar(&pt) {
12+
Ok((p, s)) => {
13+
if s[0] < f64::EPSILON {
14+
0.0
15+
} else {
16+
let pred = p[0];
17+
let sigma = s[0].sqrt();
18+
let args0 = (cstr_tol - pred) / sigma;
19+
norm_cdf(args0)
2220
}
23-
_ => 0.0,
24-
},
21+
}
2522
_ => 0.0,
2623
}
2724
}
@@ -30,27 +27,23 @@ fn pof(x: &[f64], cstr_model: &dyn MixtureGpSurrogate, cstr_tol: f64) -> f64 {
3027
/// constraint surrogate model.
3128
fn pof_grad(x: &[f64], cstr_model: &dyn MixtureGpSurrogate, cstr_tol: f64) -> Array1<f64> {
3229
let pt = ArrayView::from_shape((1, x.len()), x).unwrap();
33-
match cstr_model.predict(&pt) {
34-
Ok(p) => match cstr_model.predict_var(&pt) {
35-
Ok(s) => {
36-
if s[0] < f64::EPSILON {
37-
Array1::zeros(pt.len())
38-
} else {
39-
let pred = p[0];
40-
let sigma = s[0].sqrt();
41-
let arg = (cstr_tol - pred) / sigma;
42-
let y_prime = cstr_model.predict_gradients(&pt).unwrap();
43-
let y_prime = y_prime.row(0);
44-
let sig_2_prime = cstr_model.predict_var_gradients(&pt).unwrap();
45-
let sig_2_prime = sig_2_prime.row(0);
46-
let sig_prime = sig_2_prime.mapv(|v| v / (2. * sigma));
47-
let arg_prime = y_prime.mapv(|v| v / (-sigma))
48-
+ sig_prime.mapv(|v| v * pred / (sigma * sigma));
49-
norm_pdf(arg) * arg_prime.to_owned()
50-
}
30+
match cstr_model.predict_valvar(&pt) {
31+
Ok((p, s)) => {
32+
if s[0] < f64::EPSILON {
33+
Array1::zeros(pt.len())
34+
} else {
35+
let pred = p[0];
36+
let sigma = s[0].sqrt();
37+
let arg = (cstr_tol - pred) / sigma;
38+
let (y_prime, var_prime) = cstr_model.predict_valvar_gradients(&pt).unwrap();
39+
let y_prime = y_prime.row(0);
40+
let sig_2_prime = var_prime.row(0);
41+
let sig_prime = sig_2_prime.mapv(|v| v / (2. * sigma));
42+
let arg_prime =
43+
y_prime.mapv(|v| v / (-sigma)) + sig_prime.mapv(|v| v * pred / (sigma * sigma));
44+
norm_pdf(arg) * arg_prime.to_owned()
5145
}
52-
_ => Array1::zeros(pt.len()),
53-
},
46+
}
5447
_ => Array1::zeros(pt.len()),
5548
}
5649
}

crates/gp/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,7 @@ argmin_testfunctions.workspace = true
6161
[[bench]]
6262
name = "gp"
6363
harness = false
64+
65+
[[bench]]
66+
name = "corr"
67+
harness = false

0 commit comments

Comments
 (0)