@@ -28,23 +28,20 @@ impl InfillCriterion for ExpectedImprovement {
2828 _scale : Option < f64 > ,
2929 ) -> f64 {
3030 let pt = ArrayView :: from_shape ( ( 1 , x. len ( ) ) , x) . unwrap ( ) ;
31- match obj_model. predict ( & pt) {
32- Ok ( p) => match obj_model. predict_var ( & pt) {
33- Ok ( s) => {
34- if s[ 0 ] < f64:: EPSILON {
35- 0.0
36- } else {
37- let pred = p[ 0 ] ;
38- let k = sigma_weight. unwrap_or ( 1.0 ) ;
39- let sigma = k * s[ 0 ] . sqrt ( ) ;
40- let args0 = ( fmin - pred) / sigma;
41- let args1 = args0 * norm_cdf ( args0) ;
42- let args2 = norm_pdf ( args0) ;
43- sigma * ( args1 + args2)
44- }
31+ match obj_model. predict_valvar ( & pt) {
32+ Ok ( ( p, s) ) => {
33+ if s[ 0 ] < f64:: EPSILON {
34+ 0.0
35+ } else {
36+ let pred = p[ 0 ] ;
37+ let k = sigma_weight. unwrap_or ( 1.0 ) ;
38+ let sigma = k * s[ 0 ] . sqrt ( ) ;
39+ let args0 = ( fmin - pred) / sigma;
40+ let args1 = args0 * norm_cdf ( args0) ;
41+ let args2 = norm_pdf ( args0) ;
42+ sigma * ( args1 + args2)
4543 }
46- _ => 0.0 ,
47- } ,
44+ }
4845 _ => 0.0 ,
4946 }
5047 }
@@ -60,36 +57,32 @@ impl InfillCriterion for ExpectedImprovement {
6057 _scale : Option < f64 > ,
6158 ) -> Array1 < f64 > {
6259 let pt = ArrayView :: from_shape ( ( 1 , x. len ( ) ) , x) . unwrap ( ) ;
63- match obj_model. predict ( & pt) {
64- Ok ( p) => match obj_model. predict_var ( & pt) {
65- Ok ( s) => {
66- if s[ 0 ] < f64:: EPSILON {
67- Array1 :: zeros ( pt. len ( ) )
68- } else {
69- let pred = p[ 0 ] ;
70- let diff_y = fmin - pred;
71- let k = sigma_weight. unwrap_or ( 1.0 ) ;
72- let sigma = s[ 0 ] . sqrt ( ) ;
73- let arg = ( fmin - pred) / ( k * sigma) ;
74- let y_prime = obj_model. predict_gradients ( & pt) . unwrap ( ) ;
75- let y_prime = y_prime. row ( 0 ) ;
76- let sig_2_prime = obj_model. predict_var_gradients ( & pt) . unwrap ( ) ;
77-
78- let sig_2_prime = sig_2_prime. row ( 0 ) ;
79- let sig_prime = sig_2_prime. mapv ( |v| k * v / ( 2. * sigma) ) ;
80- let arg_prime = y_prime. mapv ( |v| v / ( -k * sigma) )
81- - diff_y. to_owned ( ) * sig_prime. mapv ( |v| v / ( k * sigma * k * sigma) ) ;
82- let factor = k * sigma * ( -arg / SQRT_2PI ) * ( -( arg * arg) / 2. ) . exp ( ) ;
83-
84- let arg1 = y_prime. mapv ( |v| v * ( -norm_cdf ( arg) ) ) ;
85- let arg2 = diff_y * norm_pdf ( arg) * arg_prime. to_owned ( ) ;
86- let arg3 = sig_prime. to_owned ( ) * norm_pdf ( arg) ;
87- let arg4 = factor * arg_prime;
88- arg1 + arg2 + arg3 + arg4
89- }
60+ match obj_model. predict_valvar ( & pt) {
61+ Ok ( ( p, s) ) => {
62+ if s[ 0 ] < f64:: EPSILON {
63+ Array1 :: zeros ( pt. len ( ) )
64+ } else {
65+ let pred = p[ 0 ] ;
66+ let diff_y = fmin - pred;
67+ let k = sigma_weight. unwrap_or ( 1.0 ) ;
68+ let sigma = s[ 0 ] . sqrt ( ) ;
69+ let arg = ( fmin - pred) / ( k * sigma) ;
70+
71+ let ( y_prime, var_prime) = obj_model. predict_valvar_gradients ( & pt) . unwrap ( ) ;
72+ let y_prime = y_prime. row ( 0 ) ;
73+ let sig_2_prime = var_prime. row ( 0 ) ;
74+ let sig_prime = sig_2_prime. mapv ( |v| k * v / ( 2. * sigma) ) ;
75+ let arg_prime = y_prime. mapv ( |v| v / ( -k * sigma) )
76+ - diff_y. to_owned ( ) * sig_prime. mapv ( |v| v / ( k * sigma * k * sigma) ) ;
77+ let factor = k * sigma * ( -arg / SQRT_2PI ) * ( -( arg * arg) / 2. ) . exp ( ) ;
78+
79+ let arg1 = y_prime. mapv ( |v| v * ( -norm_cdf ( arg) ) ) ;
80+ let arg2 = diff_y * norm_pdf ( arg) * arg_prime. to_owned ( ) ;
81+ let arg3 = sig_prime. to_owned ( ) * norm_pdf ( arg) ;
82+ let arg4 = factor * arg_prime;
83+ arg1 + arg2 + arg3 + arg4
9084 }
91- _ => Array1 :: zeros ( pt. len ( ) ) ,
92- } ,
85+ }
9386 _ => Array1 :: zeros ( pt. len ( ) ) ,
9487 }
9588 }
@@ -120,20 +113,17 @@ impl InfillCriterion for LogExpectedImprovement {
120113 ) -> f64 {
121114 let pt = ArrayView :: from_shape ( ( 1 , x. len ( ) ) , x) . unwrap ( ) ;
122115
123- match obj_model. predict ( & pt) {
124- Ok ( p) => match obj_model. predict_var ( & pt) {
125- Ok ( s) => {
126- if s[ 0 ] < f64:: EPSILON {
127- f64:: MIN
128- } else {
129- let pred = p[ 0 ] ;
130- let sigma = s[ 0 ] . sqrt ( ) ;
131- let u = ( fmin - pred) / sigma;
132- log_ei_helper ( u) + sigma. ln ( )
133- }
116+ match obj_model. predict_valvar ( & pt) {
117+ Ok ( ( p, s) ) => {
118+ if s[ 0 ] < f64:: EPSILON {
119+ f64:: MIN
120+ } else {
121+ let pred = p[ 0 ] ;
122+ let sigma = s[ 0 ] . sqrt ( ) ;
123+ let u = ( fmin - pred) / sigma;
124+ log_ei_helper ( u) + sigma. ln ( )
134125 }
135- _ => f64:: MIN ,
136- } ,
126+ }
137127 _ => f64:: MIN ,
138128 }
139129 }
@@ -150,35 +140,31 @@ impl InfillCriterion for LogExpectedImprovement {
150140 ) -> Array1 < f64 > {
151141 let pt = ArrayView :: from_shape ( ( 1 , x. len ( ) ) , x) . unwrap ( ) ;
152142
153- match obj_model. predict ( & pt) {
154- Ok ( p) => match obj_model. predict_var ( & pt) {
155- Ok ( s) => {
156- if s[ 0 ] < f64:: EPSILON {
157- Array1 :: from_elem ( pt. len ( ) , f64:: MIN )
158- } else {
159- let pred = p[ 0 ] ;
160- let diff_y = fmin - pred;
161- let sigma = s[ 0 ] . sqrt ( ) ;
162- let arg = diff_y / sigma;
163-
164- let y_prime = obj_model. predict_gradients ( & pt) . unwrap ( ) ;
165- let y_prime = y_prime. row ( 0 ) ;
166- let sig_2_prime = obj_model. predict_var_gradients ( & pt) . unwrap ( ) ;
167- let sig_2_prime = sig_2_prime. row ( 0 ) ;
168- let sig_prime = sig_2_prime. mapv ( |v| v / ( 2. * sigma) ) ;
169-
170- let arg_prime = y_prime. mapv ( |v| v / ( -sigma) )
171- - diff_y. to_owned ( ) * sig_prime. mapv ( |v| v / ( sigma * sigma) ) ;
172-
173- let dhelper = d_log_ei_helper ( arg) ;
174- let arg1 = arg_prime. mapv ( |v| dhelper * v) ;
175-
176- let arg2 = sig_prime / sigma;
177- arg1 + arg2
178- }
143+ match obj_model. predict_valvar ( & pt) {
144+ Ok ( ( p, s) ) => {
145+ if s[ 0 ] < f64:: EPSILON {
146+ Array1 :: from_elem ( pt. len ( ) , f64:: MIN )
147+ } else {
148+ let pred = p[ 0 ] ;
149+ let diff_y = fmin - pred;
150+ let sigma = s[ 0 ] . sqrt ( ) ;
151+ let arg = diff_y / sigma;
152+
153+ let ( y_prime, var_prime) = obj_model. predict_valvar_gradients ( & pt) . unwrap ( ) ;
154+ let y_prime = y_prime. row ( 0 ) ;
155+ let sig_2_prime = var_prime. row ( 0 ) ;
156+ let sig_prime = sig_2_prime. mapv ( |v| v / ( 2. * sigma) ) ;
157+
158+ let arg_prime = y_prime. mapv ( |v| v / ( -sigma) )
159+ - diff_y. to_owned ( ) * sig_prime. mapv ( |v| v / ( sigma * sigma) ) ;
160+
161+ let dhelper = d_log_ei_helper ( arg) ;
162+ let arg1 = arg_prime. mapv ( |v| dhelper * v) ;
163+
164+ let arg2 = sig_prime / sigma;
165+ arg1 + arg2
179166 }
180- _ => Array1 :: from_elem ( pt. len ( ) , f64:: MIN ) ,
181- } ,
167+ }
182168 _ => Array1 :: from_elem ( pt. len ( ) , f64:: MIN ) ,
183169 }
184170 }
0 commit comments