|
| 1 | +# Copyright 2021 The TensorFlow Quantum Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""The SPSA minimization algorithm""" |
| 16 | +import collections |
| 17 | +import tensorflow as tf |
| 18 | +import numpy as np |
| 19 | + |
| 20 | + |
| 21 | +def prefer_static_shape(x): |
| 22 | + """Return static shape of tensor `x` if available, |
| 23 | +
|
| 24 | + else `tf.shape(x)`. |
| 25 | +
|
| 26 | + Args: |
| 27 | + x: `tf.Tensor` (already converted). |
| 28 | + Returns: |
| 29 | + Numpy array (if static shape is obtainable), else `tf.Tensor`. |
| 30 | + """ |
| 31 | + return prefer_static_value(tf.shape(x)) |
| 32 | + |
| 33 | + |
| 34 | +def prefer_static_value(x): |
| 35 | + """Return static value of tensor `x` if available, else `x`. |
| 36 | +
|
| 37 | + Args: |
| 38 | + x: `tf.Tensor` (already converted). |
| 39 | + Returns: |
| 40 | + Numpy array (if static value is obtainable), else `tf.Tensor`. |
| 41 | + """ |
| 42 | + static_x = tf.get_static_value(x) |
| 43 | + if static_x is not None: |
| 44 | + return static_x |
| 45 | + return x |
| 46 | + |
| 47 | + |
| 48 | +SPSAOptimizerResults = collections.namedtuple( |
| 49 | + 'SPSAOptimizerResults', |
| 50 | + [ |
| 51 | + 'converged', |
| 52 | + # Scalar boolean tensor indicating whether the minimum |
| 53 | + # was found within tolerance. |
| 54 | + 'num_iterations', |
| 55 | + # The number of iterations of the SPSA update. |
| 56 | + 'num_objective_evaluations', |
| 57 | + # The total number of objective |
| 58 | + # evaluations performed. |
| 59 | + 'position', |
| 60 | + # A tensor containing the last argument value found |
| 61 | + # during the search. If the search converged, then |
| 62 | + # this value is the argmin of the objective function. |
| 63 | + # A tensor containing the value of the objective from |
| 64 | + # previous iteration |
| 65 | + 'objective_value_previous_iteration', |
| 66 | + # Save the evaluated value of the objective function |
| 67 | + # from the previous iteration |
| 68 | + 'objective_value', |
| 69 | + # A tensor containing the value of the objective |
| 70 | + # function at the `position`. If the search |
| 71 | + # converged, then this is the (local) minimum of |
| 72 | + # the objective function. |
| 73 | + 'tolerance', |
| 74 | + # Define the stop criteria. Iteration will stop when the |
| 75 | + # objective value difference between two iterations is |
| 76 | + # smaller than tolerance |
| 77 | + 'lr', |
| 78 | + # Specifies the learning rate |
| 79 | + 'alpha', |
| 80 | + # Specifies scaling of the learning rate |
| 81 | + 'perturb', |
| 82 | + # Specifies the size of the perturbations |
| 83 | + 'gamma', |
| 84 | + # Specifies scaling of the size of the perturbations |
| 85 | + 'blocking', |
| 86 | + # If true, then the optimizer will only accept updates that improve |
| 87 | + # the objective function. |
| 88 | + 'allowed_increase' |
| 89 | + # Specifies maximum allowable increase in objective function |
| 90 | + # (only applies if blocking is true). |
| 91 | + ]) |
| 92 | + |
| 93 | + |
| 94 | +def _get_initial_state(initial_position, tolerance, expectation_value_function, |
| 95 | + lr, alpha, perturb, gamma, blocking, allowed_increase): |
| 96 | + """Create SPSAOptimizerResults with initial state of search.""" |
| 97 | + init_args = { |
| 98 | + "converged": tf.Variable(False), |
| 99 | + "num_iterations": tf.Variable(0), |
| 100 | + "num_objective_evaluations": tf.Variable(0), |
| 101 | + "position": tf.Variable(initial_position), |
| 102 | + "objective_value": tf.Variable(0.), |
| 103 | + "objective_value_previous_iteration": tf.Variable(np.inf), |
| 104 | + "tolerance": tolerance, |
| 105 | + "lr": tf.Variable(lr), |
| 106 | + "alpha": tf.Variable(alpha), |
| 107 | + "perturb": tf.Variable(perturb), |
| 108 | + "gamma": tf.Variable(gamma), |
| 109 | + "blocking": tf.Variable(blocking), |
| 110 | + "allowed_increase": tf.Variable(allowed_increase) |
| 111 | + } |
| 112 | + return SPSAOptimizerResults(**init_args) |
| 113 | + |
| 114 | + |
| 115 | +def minimize(expectation_value_function, |
| 116 | + initial_position, |
| 117 | + tolerance=1e-5, |
| 118 | + max_iterations=200, |
| 119 | + alpha=0.602, |
| 120 | + lr=1.0, |
| 121 | + perturb=1.0, |
| 122 | + gamma=0.101, |
| 123 | + blocking=False, |
| 124 | + allowed_increase=0.5, |
| 125 | + seed=None, |
| 126 | + name=None): |
| 127 | + """Applies the SPSA algorithm. |
| 128 | +
|
| 129 | + The SPSA algorithm can be used to minimize a noisy function. See: |
| 130 | +
|
| 131 | + [SPSA website](https://www.jhuapl.edu/SPSA/) |
| 132 | +
|
| 133 | + Usage: |
| 134 | +
|
| 135 | + Here is an example of optimize a function which consists the |
| 136 | + summation of a few quadratics. |
| 137 | +
|
| 138 | + >>> n = 5 # Number of quadractics |
| 139 | + >>> coefficient = tf.random.uniform(minval=0, maxval=1, shape=[n]) |
| 140 | + >>> min_value = 0 |
| 141 | + >>> func = func = lambda x : tf.math.reduce_sum(np.power(x, 2) * \ |
| 142 | + coefficient) |
| 143 | + >>> # Optimize the function with SPSA, start with random parameters |
| 144 | + >>> result = tfq.optimizers.spsa_minimize(func, np.random.random(n)) |
| 145 | + >>> result.converged |
| 146 | + tf.Tensor(True, shape=(), dtype=bool) |
| 147 | + >>> result.objective_value |
| 148 | + tf.Tensor(0.0013349084, shape=(), dtype=float32) |
| 149 | +
|
| 150 | + Args: |
| 151 | + expectation_value_function: Python callable that accepts a real |
| 152 | + valued tf.Tensor with shape [n] where n is the number of function |
| 153 | + parameters. The return value is a real `tf.Tensor` Scalar |
| 154 | + (matching shape `[1]`). |
| 155 | + initial_position: Real `tf.Tensor` of shape `[n]`. The starting |
| 156 | + point, or points when using batching dimensions, of the search |
| 157 | + procedure. At these points the function value and the gradient |
| 158 | + norm should be finite. |
| 159 | + tolerance: Scalar `tf.Tensor` of real dtype. Specifies the tolerance |
| 160 | + for the procedure. If the supremum norm between two iteration |
| 161 | + vector is below this number, the algorithm is stopped. |
| 162 | + lr: Scalar `tf.Tensor` of real dtype. Specifies the learning rate |
| 163 | + alpha: Scalar `tf.Tensor` of real dtype. Specifies scaling of the |
| 164 | + learning rate. |
| 165 | + perturb: Scalar `tf.Tensor` of real dtype. Specifies the size of the |
| 166 | + perturbations. |
| 167 | + gamma: Scalar `tf.Tensor` of real dtype. Specifies scaling of the |
| 168 | + size of the perturbations. |
| 169 | + blocking: Boolean. If true, then the optimizer will only accept |
| 170 | + updates that improve the objective function. |
| 171 | + allowed_increase: Scalar `tf.Tensor` of real dtype. Specifies maximum |
| 172 | + allowable increase in objective function (only applies if blocking |
| 173 | + is true). |
| 174 | + seed: (Optional) Python integer. Used to create a random seed for the |
| 175 | + perturbations. |
| 176 | + name: (Optional) Python `str`. The name prefixed to the ops created |
| 177 | + by this function. If not supplied, the default name 'minimize' |
| 178 | + is used. |
| 179 | +
|
| 180 | + Returns: |
| 181 | + optimizer_results: A SPSAOptimizerResults object contains the |
| 182 | + result of the optimization process. |
| 183 | + """ |
| 184 | + |
| 185 | + with tf.name_scope(name or 'minimize'): |
| 186 | + if seed is not None: |
| 187 | + generator = tf.random.Generator.from_seed(seed) |
| 188 | + else: |
| 189 | + generator = tf.random |
| 190 | + |
| 191 | + initial_position = tf.convert_to_tensor(initial_position, |
| 192 | + name='initial_position', |
| 193 | + dtype='float32') |
| 194 | + dtype = initial_position.dtype.base_dtype |
| 195 | + tolerance = tf.convert_to_tensor(tolerance, |
| 196 | + dtype=dtype, |
| 197 | + name='grad_tolerance') |
| 198 | + max_iterations = tf.convert_to_tensor(max_iterations, |
| 199 | + name='max_iterations') |
| 200 | + |
| 201 | + lr_init = tf.convert_to_tensor(lr, name='initial_a', dtype='float32') |
| 202 | + perturb_init = tf.convert_to_tensor(perturb, |
| 203 | + name='initial_c', |
| 204 | + dtype='float32') |
| 205 | + |
| 206 | + def _spsa_once(state): |
| 207 | + """Caclulate single SPSA gradient estimation |
| 208 | +
|
| 209 | + Args: |
| 210 | + state: A SPSAOptimizerResults object stores the |
| 211 | + current state of the minimizer. |
| 212 | +
|
| 213 | + Returns: |
| 214 | + states: A list which the first element is the new state |
| 215 | + """ |
| 216 | + delta_shift = tf.cast( |
| 217 | + 2 * generator.uniform(shape=state.position.shape, |
| 218 | + minval=0, |
| 219 | + maxval=2, |
| 220 | + dtype=tf.int32) - 1, tf.float32) |
| 221 | + v_m = expectation_value_function(state.position - |
| 222 | + state.perturb * delta_shift) |
| 223 | + v_p = expectation_value_function(state.position + |
| 224 | + state.perturb * delta_shift) |
| 225 | + |
| 226 | + gradient_estimate = (v_p - v_m) / (2 * state.perturb) * delta_shift |
| 227 | + update = state.lr * gradient_estimate |
| 228 | + |
| 229 | + state.num_objective_evaluations.assign_add(2) |
| 230 | + |
| 231 | + current_obj = expectation_value_function(state.position - update) |
| 232 | + if state.objective_value_previous_iteration + \ |
| 233 | + state.allowed_increase >= current_obj or not state.blocking: |
| 234 | + state.position.assign(state.position - update) |
| 235 | + state.objective_value_previous_iteration.assign( |
| 236 | + state.objective_value) |
| 237 | + state.objective_value.assign(current_obj) |
| 238 | + |
| 239 | + return [state] |
| 240 | + |
| 241 | + # The `state` here is a `SPSAOptimizerResults` tuple with |
| 242 | + # values for the current state of the algorithm computation. |
| 243 | + def _cond(state): |
| 244 | + """Continue if iterations remain and stopping condition |
| 245 | + is not met.""" |
| 246 | + return (state.num_iterations < max_iterations) \ |
| 247 | + and (not state.converged) |
| 248 | + |
| 249 | + def _body(state): |
| 250 | + """Main optimization loop.""" |
| 251 | + new_lr = lr_init / ( |
| 252 | + (tf.cast(state.num_iterations + 1, tf.float32) + |
| 253 | + 0.01 * tf.cast(max_iterations, tf.float32))**state.alpha) |
| 254 | + new_perturb = perturb_init / (tf.cast(state.num_iterations + 1, |
| 255 | + tf.float32)**state.gamma) |
| 256 | + |
| 257 | + state.lr.assign(new_lr) |
| 258 | + state.perturb.assign(new_perturb) |
| 259 | + |
| 260 | + _spsa_once(state) |
| 261 | + state.num_iterations.assign_add(1) |
| 262 | + state.converged.assign( |
| 263 | + tf.abs(state.objective_value - |
| 264 | + state.objective_value_previous_iteration) < |
| 265 | + state.tolerance) |
| 266 | + return [state] |
| 267 | + |
| 268 | + initial_state = _get_initial_state(initial_position, tolerance, |
| 269 | + expectation_value_function, lr, |
| 270 | + alpha, perturb, gamma, blocking, |
| 271 | + allowed_increase) |
| 272 | + |
| 273 | + initial_state.objective_value.assign( |
| 274 | + tf.cast(expectation_value_function(initial_state.position), |
| 275 | + tf.float32)) |
| 276 | + |
| 277 | + return tf.while_loop(cond=_cond, |
| 278 | + body=_body, |
| 279 | + loop_vars=[initial_state], |
| 280 | + parallel_iterations=1)[0] |
0 commit comments