# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """The SPSA minimization algorithm"""
- import collections
+ """The SPSA minimization algorithm."""
import tensorflow as tf
import numpy as np
@@ -46,6 +45,7 @@ def prefer_static_value(x):
class SPSAOptimizerResults(tf.experimental.ExtensionType):
+    """ExtensionType of SPSA Optimizer tf.while_loop() inner state."""
    converged: tf.Tensor
    # Scalar boolean tensor indicating whether the minimum
    # was found within tolerance.
@@ -60,7 +60,7 @@ class SPSAOptimizerResults(tf.experimental.ExtensionType):
    # this value is the argmin of the objective function.
    # A tensor containing the value of the objective from
    # previous iteration
-    objective_value_previous_iteration: tf.Tensor
+    objective_value_prev: tf.Tensor
    # Save the evaluated value of the objective function
    # from the previous iteration
    objective_value: tf.Tensor
@@ -72,7 +72,7 @@ class SPSAOptimizerResults(tf.experimental.ExtensionType):
    # Define the stop criteria. Iteration will stop when the
    # objective value difference between two iterations is
    # smaller than tolerance
-    lr: tf.Tensor
+    learning_rate: tf.Tensor
    # Specifies the learning rate
    alpha: tf.Tensor
    # Specifies scaling of the learning rate
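Note on the pattern used throughout this file: `tf.experimental.ExtensionType` instances are immutable, which is why the class carries a `to_dict()` helper and the loop body rebuilds a fresh results object from a mutable dict. A minimal, self-contained sketch of that round-trip (toy field names, not part of this diff):

    import tensorflow as tf

    class _ToyResults(tf.experimental.ExtensionType):
        converged: tf.Tensor
        learning_rate: tf.Tensor

        def to_dict(self):
            return {"converged": self.converged,
                    "learning_rate": self.learning_rate}

    state = _ToyResults(converged=tf.constant(False),
                        learning_rate=tf.constant(1.0))
    params = state.to_dict()                        # immutable -> mutable dict
    params.update({"learning_rate": tf.constant(0.5)})
    state = _ToyResults(**params)                   # rebuild a new immutable state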
@@ -89,38 +89,27 @@ class SPSAOptimizerResults(tf.experimental.ExtensionType):
    # (only applies if blocking is true).

    def to_dict(self):
+        """Transforms immutable data to a mutable dictionary."""
        return {
-            "converged":
-                self.converged,
-            "num_iterations":
-                self.num_iterations,
-            "num_objective_evaluations":
-                self.num_objective_evaluations,
-            "position":
-                self.position,
-            "objective_value":
-                self.objective_value,
-            "objective_value_previous_iteration":
-                self.objective_value_previous_iteration,
-            "tolerance":
-                self.tolerance,
-            "lr":
-                self.lr,
-            "alpha":
-                self.alpha,
-            "perturb":
-                self.perturb,
-            "gamma":
-                self.gamma,
-            "blocking":
-                self.blocking,
-            "allowed_increase":
-                self.allowed_increase,
+            "converged": self.converged,
+            "num_iterations": self.num_iterations,
+            "num_objective_evaluations": self.num_objective_evaluations,
+            "position": self.position,
+            "objective_value": self.objective_value,
+            "objective_value_prev": self.objective_value_prev,
+            "tolerance": self.tolerance,
+            "learning_rate": self.learning_rate,
+            "alpha": self.alpha,
+            "perturb": self.perturb,
+            "gamma": self.gamma,
+            "blocking": self.blocking,
+            "allowed_increase": self.allowed_increase,
        }


def _get_initial_state(initial_position, tolerance, expectation_value_function,
-                       lr, alpha, perturb, gamma, blocking, allowed_increase):
+                       learning_rate, alpha, perturb, gamma, blocking,
+                       allowed_increase):
    """Create SPSAOptimizerResults with initial state of search."""
    init_args = {
        "converged": tf.Variable(False),
@@ -129,9 +118,9 @@ def _get_initial_state(initial_position, tolerance, expectation_value_function,
        "position": tf.Variable(initial_position),
        "objective_value":
            (tf.cast(expectation_value_function(initial_position), tf.float32)),
-        "objective_value_previous_iteration": tf.Variable(np.inf),
+        "objective_value_prev": tf.Variable(np.inf),
        "tolerance": tolerance,
-        "lr": tf.Variable(lr),
+        "learning_rate": tf.Variable(learning_rate),
        "alpha": tf.Variable(alpha),
        "perturb": tf.Variable(perturb),
        "gamma": tf.Variable(gamma),
@@ -146,7 +135,7 @@ def minimize(expectation_value_function,
             tolerance=1e-5,
             max_iterations=200,
             alpha=0.602,
-             lr=1.0,
+             learning_rate=1.0,
             perturb=1.0,
             gamma=0.101,
             blocking=False,
@@ -188,7 +177,8 @@ def minimize(expectation_value_function,
        tolerance: Scalar `tf.Tensor` of real dtype. Specifies the tolerance
            for the procedure. If the supremum norm between two iteration
            vectors is below this number, the algorithm is stopped.
-        lr: Scalar `tf.Tensor` of real dtype. Specifies the learning rate
+        learning_rate: Scalar `tf.Tensor` of real dtype.
+            Specifies the learning rate.
        alpha: Scalar `tf.Tensor` of real dtype. Specifies scaling of the
            learning rate.
        perturb: Scalar `tf.Tensor` of real dtype. Specifies the size of the
@@ -227,7 +217,9 @@ def minimize(expectation_value_function,
    max_iterations = tf.convert_to_tensor(max_iterations,
                                          name='max_iterations')

-    lr_init = tf.convert_to_tensor(lr, name='initial_a', dtype='float32')
+    learning_rate_init = tf.convert_to_tensor(learning_rate,
+                                              name='initial_a',
+                                              dtype='float32')
    perturb_init = tf.convert_to_tensor(perturb,
                                        name='initial_c',
                                        dtype='float32')
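The two tensors built here (`learning_rate_init` as the SPSA `a` gain, `perturb_init` as the `c` gain) feed the decay schedules applied in `_body` further down in this diff. A small standalone sketch of those schedules, assuming `k` is the current iteration count:

    import tensorflow as tf

    def spsa_gains(k, learning_rate_init, perturb_init, max_iterations,
                   alpha=0.602, gamma=0.101):
        # Mirrors the expressions in _body(): a_k decays with alpha, c_k with gamma.
        k = tf.cast(k, tf.float32)
        a_k = learning_rate_init / (
            (k + 1.0 + 0.01 * tf.cast(max_iterations, tf.float32))**alpha)
        c_k = perturb_init / ((k + 1.0)**gamma)
        return a_k, c_k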
@@ -253,7 +245,7 @@ def _spsa_once(state):
                                         state.perturb * delta_shift)

        gradient_estimate = (v_p - v_m) / (2 * state.perturb) * delta_shift
-        update = state.lr * gradient_estimate
+        update = state.learning_rate * gradient_estimate
        next_state_params = state.to_dict()
        next_state_params.update({
            "num_objective_evaluations":
@@ -263,11 +255,11 @@ def _spsa_once(state):
        current_obj = tf.cast(expectation_value_function(state.position -
                                                         update),
                              dtype=tf.float32)
-        if state.objective_value_previous_iteration + \
+        if state.objective_value_prev + \
           state.allowed_increase >= current_obj or not state.blocking:
            next_state_params.update({
                "position": state.position - update,
-                "objective_value_previous_iteration": state.objective_value,
+                "objective_value_prev": state.objective_value,
                "objective_value": current_obj
            })
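For reference, a minimal self-contained sketch of the simultaneous-perturbation gradient estimate computed across the two `_spsa_once` hunks above; the random ±1 `delta_shift` draw is assumed here, since its definition sits outside this diff:

    import tensorflow as tf

    def spsa_gradient(f, position, perturb):
        # Draw a random +/-1 perturbation for every coordinate at once.
        delta = 2.0 * tf.cast(
            tf.random.uniform(tf.shape(position), 0, 2, dtype=tf.int32),
            tf.float32) - 1.0
        v_p = f(position + perturb * delta)   # two objective evaluations per
        v_m = f(position - perturb * delta)   # step, regardless of dimension
        return (v_p - v_m) / (2.0 * perturb) * delta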
@@ -285,35 +277,35 @@ def _cond(state):
    def _body(state):
        """Main optimization loop."""
-        new_lr = lr_init / (
+        new_learning_rate = learning_rate_init / (
            (tf.cast(state.num_iterations + 1, tf.float32) +
             0.01 * tf.cast(max_iterations, tf.float32))**state.alpha)
        new_perturb = perturb_init / (tf.cast(state.num_iterations + 1,
                                              tf.float32)**state.gamma)

        pre_state_params = state.to_dict()
        pre_state_params.update({
-            "lr": new_lr,
+            "learning_rate": new_learning_rate,
            "perturb": new_perturb,
        })

        post_state = _spsa_once(SPSAOptimizerResults(**pre_state_params))[0]
        post_state_params = post_state.to_dict()
        tf.print("asdf", state.objective_value.dtype,
-                 state.objective_value_previous_iteration.dtype)
+                 state.objective_value_prev.dtype)
        post_state_params.update({
            "num_iterations":
                post_state.num_iterations + 1,
-            "converged": (tf.abs(state.objective_value -
-                                 state.objective_value_previous_iteration) <
-                          state.tolerance),
+            "converged":
+                (tf.abs(state.objective_value - state.objective_value_prev)
+                 < state.tolerance),
        })
        return [SPSAOptimizerResults(**post_state_params)]

    initial_state = _get_initial_state(initial_position, tolerance,
-                                       expectation_value_function, lr,
-                                       alpha, perturb, gamma, blocking,
-                                       allowed_increase)
+                                       expectation_value_function,
+                                       learning_rate, alpha, perturb, gamma,
+                                       blocking, allowed_increase)

    return tf.while_loop(cond=_cond,
                         body=_body,
311
body = _body ,
0 commit comments
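A short usage sketch of `minimize` after the rename (the toy objective and the assumption that `initial_position` is the second positional argument are illustrative; the import path depends on where this module lives in your tree):

    import tensorflow as tf
    from spsa_minimizer import minimize  # assumed module name

    def objective(position):
        # Toy quadratic; any scalar-valued callable of a 1-D tensor works.
        return tf.reduce_sum(tf.square(position - 0.5))

    result = minimize(objective,
                      tf.constant([0.0, 0.0]),
                      learning_rate=0.5,      # formerly `lr`
                      max_iterations=100)
    # `result` wraps the final SPSAOptimizerResults state produced by tf.while_loop.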