|
| 1 | +# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved. |
| 2 | +# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ |
| 3 | +# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`. |
| 4 | +# DO NOT MODIFY DIRECTLY. |
| 5 | +# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ |
| 6 | +# pylint: disable=g-import-not-at-top |
| 7 | +# pylint: disable=g-direct-tensorflow-import |
| 8 | +# pylint: disable=g-bad-import-order |
| 9 | +# pylint: disable=unused-import |
| 10 | +# pylint: disable=line-too-long |
| 11 | +# pylint: disable=reimported |
| 12 | +# pylint: disable=g-bool-id-comparison |
| 13 | +# pylint: disable=g-statement-before-imports |
| 14 | +# pylint: disable=bad-continuation |
| 15 | +# pylint: disable=useless-import-alias |
| 16 | +# pylint: disable=property-with-parameters |
| 17 | +# pylint: disable=trailing-whitespace |
| 18 | +# pylint: disable=g-inconsistent-quotes |
| 19 | + |
| 20 | +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 21 | +# |
| 22 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 23 | +# you may not use this file except in compliance with the License. |
| 24 | +# You may obtain a copy of the License at |
| 25 | +# |
| 26 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 27 | +# |
| 28 | +# Unless required by applicable law or agreed to in writing, software |
| 29 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 30 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 31 | +# See the License for the specific language governing permissions and |
| 32 | +# limitations under the License. |
| 33 | +# ============================================================================== |
| 34 | +"""`LinearOperator` acting like a permutation matrix.""" |
| 35 | + |
| 36 | +import numpy as np |
| 37 | + |
| 38 | +from tensorflow_probability.python.internal.backend.numpy import dtype as dtypes |
| 39 | +from tensorflow_probability.python.internal.backend.numpy import ops |
| 40 | +from tensorflow_probability.python.internal.backend.numpy import ops |
| 41 | +from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops |
| 42 | +from tensorflow_probability.python.internal.backend.numpy import control_flow as control_flow_ops |
| 43 | +from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops |
| 44 | +from tensorflow_probability.python.internal.backend.numpy import misc as sort_ops |
| 45 | +from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg |
| 46 | +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator |
| 47 | +from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util |
| 48 | +# from tensorflow.python.util.tf_export import tf_export |
| 49 | + |
| 50 | +__all__ = ["LinearOperatorPermutation",] |
| 51 | + |
| 52 | + |
| 53 | +# @tf_export("linalg.LinearOperatorPermutation") |
| 54 | +# @linear_operator.make_composite_tensor |
| 55 | +class LinearOperatorPermutation(linear_operator.LinearOperator): |
| 56 | + """`LinearOperator` acting like a [batch] of permutation matrices. |
| 57 | +
|
| 58 | + This operator acts like a [batch] of permutations with shape |
| 59 | + `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a |
| 60 | + batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is |
| 61 | + an `N x N` matrix. This matrix `A` is not materialized, but for |
| 62 | + purposes of broadcasting this shape will be relevant. |
| 63 | +
|
| 64 | + `LinearOperatorPermutation` is initialized with a (batch) vector. |
| 65 | +
|
| 66 | + A permutation, is defined by an integer vector `v` whose values are unique |
| 67 | + and are in the range `[0, ... n]`. Applying the permutation on an input |
| 68 | + matrix has the folllowing meaning: the value of `v` at index `i` |
| 69 | + says to move the `v[i]`-th row of the input matrix to the `i`-th row. |
| 70 | + Because all values are unique, this will result in a permutation of the |
| 71 | + rows the input matrix. Note, that the permutation vector `v` has the same |
| 72 | + semantics as `tf.transpose`. |
| 73 | +
|
| 74 | + ```python |
| 75 | + # Create a 3 x 3 permutation matrix that swaps the last two columns. |
| 76 | + vec = [0, 2, 1] |
| 77 | + operator = LinearOperatorPermutation(vec) |
| 78 | +
|
| 79 | + operator.to_dense() |
| 80 | + ==> [[1., 0., 0.] |
| 81 | + [0., 0., 1.] |
| 82 | + [0., 1., 0.]] |
| 83 | +
|
| 84 | + tensor_shape.TensorShape(operator.shape) |
| 85 | + ==> [3, 3] |
| 86 | +
|
| 87 | + # This will be zero. |
| 88 | + operator.log_abs_determinant() |
| 89 | + ==> scalar Tensor |
| 90 | +
|
| 91 | + x = ... Shape [3, 4] Tensor |
| 92 | + operator.matmul(x) |
| 93 | + ==> Shape [3, 4] Tensor |
| 94 | + ``` |
| 95 | +
|
| 96 | + #### Shape compatibility |
| 97 | +
|
| 98 | + This operator acts on [batch] matrix with compatible shape. |
| 99 | + `x` is a batch matrix with compatible shape for `matmul` and `solve` if |
| 100 | +
|
| 101 | + ``` |
| 102 | + tensor_shape.TensorShape(operator.shape) = [B1,...,Bb] + [N, N], with b >= 0 |
| 103 | + tensor_shape.TensorShape(x.shape) = [C1,...,Cc] + [N, R], |
| 104 | + and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] |
| 105 | + ``` |
| 106 | +
|
| 107 | + #### Matrix property hints |
| 108 | +
|
| 109 | + This `LinearOperator` is initialized with boolean flags of the form `is_X`, |
| 110 | + for `X = non_singular, self_adjoint, positive_definite, square`. |
| 111 | + These have the following meaning: |
| 112 | +
|
| 113 | + * If `is_X == True`, callers should expect the operator to have the |
| 114 | + property `X`. This is a promise that should be fulfilled, but is *not* a |
| 115 | + runtime assert. For example, finite floating point precision may result |
| 116 | + in these promises being violated. |
| 117 | + * If `is_X == False`, callers should expect the operator to not have `X`. |
| 118 | + * If `is_X == None` (the default), callers should have no expectation either |
| 119 | + way. |
| 120 | + """ |
| 121 | + |
| 122 | + def __init__(self, |
| 123 | + perm, |
| 124 | + dtype=dtypes.float32, |
| 125 | + is_non_singular=None, |
| 126 | + is_self_adjoint=None, |
| 127 | + is_positive_definite=None, |
| 128 | + is_square=None, |
| 129 | + name="LinearOperatorPermutation"): |
| 130 | + r"""Initialize a `LinearOperatorPermutation`. |
| 131 | +
|
| 132 | + Args: |
| 133 | + perm: Shape `[B1,...,Bb, N]` Integer `Tensor` with `b >= 0` |
| 134 | + `N >= 0`. An integer vector that represents the permutation to apply. |
| 135 | + Note that this argument is same as `tf.transpose`. However, this |
| 136 | + permutation is applied on the rows, while the permutation in |
| 137 | + `tf.transpose` is applied on the dimensions of the `Tensor`. `perm` |
| 138 | + is required to have unique entries from `{0, 1, ... N-1}`. |
| 139 | + dtype: The `dtype` of arguments to this operator. Default: `float32`. |
| 140 | + Allowed dtypes: `float16`, `float32`, `float64`, `complex64`, |
| 141 | + `complex128`. |
| 142 | + is_non_singular: Expect that this operator is non-singular. |
| 143 | + is_self_adjoint: Expect that this operator is equal to its hermitian |
| 144 | + transpose. This is autoset to true |
| 145 | + is_positive_definite: Expect that this operator is positive definite, |
| 146 | + meaning the quadratic form `x^H A x` has positive real part for all |
| 147 | + nonzero `x`. Note that we do not require the operator to be |
| 148 | + self-adjoint to be positive-definite. See: |
| 149 | + https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices |
| 150 | + This is autoset to false. |
| 151 | + is_square: Expect that this operator acts like square [batch] matrices. |
| 152 | + This is autoset to true. |
| 153 | + name: A name for this `LinearOperator`. |
| 154 | +
|
| 155 | + Raises: |
| 156 | + ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is |
| 157 | + not `False` or `is_square` is not `True`. |
| 158 | + """ |
| 159 | + parameters = dict( |
| 160 | + perm=perm, |
| 161 | + dtype=dtype, |
| 162 | + is_non_singular=is_non_singular, |
| 163 | + is_self_adjoint=is_self_adjoint, |
| 164 | + is_positive_definite=is_positive_definite, |
| 165 | + is_square=is_square, |
| 166 | + name=name |
| 167 | + ) |
| 168 | + |
| 169 | + with ops.name_scope(name, values=[perm]): |
| 170 | + self._perm = linear_operator_util.convert_nonref_to_tensor( |
| 171 | + perm, name="perm") |
| 172 | + self._check_perm(self._perm) |
| 173 | + |
| 174 | + # Check and auto-set hints. |
| 175 | + if is_non_singular is False: # pylint:disable=g-bool-id-comparison |
| 176 | + raise ValueError(f"A Permutation operator is always non-singular. " |
| 177 | + f"Expected argument `is_non_singular` to be True. " |
| 178 | + f"Received: {is_non_singular}.") |
| 179 | + |
| 180 | + if is_square is False: # pylint:disable=g-bool-id-comparison |
| 181 | + raise ValueError(f"A Permutation operator is always square. " |
| 182 | + f"Expected argument `is_square` to be True. " |
| 183 | + f"Received: {is_square}.") |
| 184 | + is_square = True |
| 185 | + |
| 186 | + super(LinearOperatorPermutation, self).__init__( |
| 187 | + dtype=dtype, |
| 188 | + is_non_singular=is_non_singular, |
| 189 | + is_self_adjoint=is_self_adjoint, |
| 190 | + is_positive_definite=is_positive_definite, |
| 191 | + is_square=is_square, |
| 192 | + parameters=parameters, |
| 193 | + name=name) |
| 194 | + |
| 195 | + def _check_perm(self, perm): |
| 196 | + """Static check of perm.""" |
| 197 | + if (tensor_shape.TensorShape(perm.shape).ndims is not None and tensor_shape.TensorShape(perm.shape).ndims < 1): |
| 198 | + raise ValueError(f"Argument `perm` must have at least 1 dimension. " |
| 199 | + f"Received: {perm}.") |
| 200 | + if not np.issubdtype(perm.dtype, np.integer): |
| 201 | + raise TypeError(f"Argument `perm` must be integer dtype. " |
| 202 | + f"Received: {perm}.") |
| 203 | + # Check that the permutation satisfies the uniqueness constraint. |
| 204 | + static_perm = ops.get_static_value(perm) |
| 205 | + if static_perm is not None: |
| 206 | + sorted_perm = np.sort(static_perm, axis=-1) |
| 207 | + if np.any(sorted_perm != np.arange(0, tensor_shape.TensorShape(static_perm.shape)[-1])): |
| 208 | + raise ValueError( |
| 209 | + f"Argument `perm` must be a vector of unique integers from " |
| 210 | + f"0 to {tensor_shape.TensorShape(static_perm.shape)[-1] - 1}.") |
| 211 | + |
| 212 | + def _shape(self): |
| 213 | + perm_shape = tensor_shape.TensorShape(self._perm.shape) |
| 214 | + return perm_shape.concatenate(perm_shape[-1:]) |
| 215 | + |
| 216 | + def _shape_tensor(self): |
| 217 | + perm_shape = prefer_static.shape(self._perm) |
| 218 | + k = perm_shape[-1] |
| 219 | + return prefer_static.concat((perm_shape, [k]), 0) |
| 220 | + |
| 221 | + def _assert_non_singular(self): |
| 222 | + return control_flow_ops.no_op("assert_non_singular") |
| 223 | + |
| 224 | + def _domain_dimension_tensor(self, perm=None): |
| 225 | + perm = perm if perm is not None else self.perm |
| 226 | + return prefer_static.shape(perm)[-1] |
| 227 | + |
| 228 | + def _matmul(self, x, adjoint=False, adjoint_arg=False): |
| 229 | + perm = ops.convert_to_tensor(self.perm) |
| 230 | + if adjoint and not self.is_self_adjoint: |
| 231 | + # TODO(srvasude): invert_permutation doesn't work on batches so we use |
| 232 | + # argsort. |
| 233 | + perm = sort_ops.argsort(perm, axis=-1) |
| 234 | + x = linalg.adjoint(x) if adjoint_arg else x |
| 235 | + |
| 236 | + # We need to broadcast x and the permutation since tf.gather doesn't |
| 237 | + # broadcast. |
| 238 | + broadcast_shape = _ops.broadcast_dynamic_shape( |
| 239 | + prefer_static.shape(x)[:-1], prefer_static.shape(perm)) |
| 240 | + k = prefer_static.shape(x)[-1] |
| 241 | + broadcast_x_shape = prefer_static.concat([broadcast_shape, [k]], axis=-1) |
| 242 | + x = _ops.broadcast_to(x, broadcast_x_shape) |
| 243 | + perm = _ops.broadcast_to(perm, broadcast_shape) |
| 244 | + |
| 245 | + m = prefer_static.shape(x)[-2] |
| 246 | + x = array_ops.reshape(x, [-1, m, k]) |
| 247 | + perm = array_ops.reshape(perm, [-1, m]) |
| 248 | + |
| 249 | + y = array_ops.gather(x, perm, axis=-2, batch_dims=1) |
| 250 | + return array_ops.reshape(y, broadcast_x_shape) |
| 251 | + |
| 252 | + # TODO(srvasude): Permutation parity is equivalent to the determinant. |
| 253 | + |
| 254 | + def _log_abs_determinant(self): |
| 255 | + # Permutation matrices have determinant +/- 1. |
| 256 | + return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype) |
| 257 | + |
| 258 | + def _solve(self, rhs, adjoint=False, adjoint_arg=False): |
| 259 | + # The inverse of a permutation matrix is the transpose matrix. |
| 260 | + # Apply a matmul and flip the adjoint bit. |
| 261 | + return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg) |
| 262 | + |
| 263 | + def _to_dense(self): |
| 264 | + perm = ops.convert_to_tensor(self.perm) |
| 265 | + return _ops.cast(math_ops.equal( |
| 266 | + array_ops.range(0, self._domain_dimension_tensor(perm)), |
| 267 | + perm[..., _ops.newaxis]), self.dtype) |
| 268 | + |
| 269 | + def _diag_part(self): |
| 270 | + perm = ops.convert_to_tensor(self.perm) |
| 271 | + return _ops.cast(math_ops.equal( |
| 272 | + array_ops.range(0, self._domain_dimension_tensor(perm)), |
| 273 | + perm), self.dtype) |
| 274 | + |
| 275 | + def _cond(self): |
| 276 | + # Permutation matrices are rotations which have condition number 1. |
| 277 | + return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype) |
| 278 | + |
| 279 | + @property |
| 280 | + def perm(self): |
| 281 | + return self._perm |
| 282 | + |
| 283 | + @property |
| 284 | + def _composite_tensor_fields(self): |
| 285 | + return ("perm", "dtype") |
| 286 | + |
| 287 | + @property |
| 288 | + def _experimental_parameter_ndims_to_matrix_ndims(self): |
| 289 | + return {"perm": 1} |
| 290 | + |
| 291 | +import numpy as np |
| 292 | +from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg |
| 293 | +from tensorflow_probability.python.internal.backend.numpy import ops as _ops |
| 294 | +from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape |
| 295 | + |
| 296 | +from tensorflow_probability.python.internal.backend.numpy import private |
| 297 | +distribution_util = private.LazyLoader( |
| 298 | + "distribution_util", globals(), |
| 299 | + "tensorflow_probability.substrates.numpy.internal.distribution_util") |
| 300 | +tensorshape_util = private.LazyLoader( |
| 301 | + "tensorshape_util", globals(), |
| 302 | + "tensorflow_probability.substrates.numpy.internal.tensorshape_util") |
| 303 | +prefer_static = private.LazyLoader( |
| 304 | + "prefer_static", globals(), |
| 305 | + "tensorflow_probability.substrates.numpy.internal.prefer_static") |
| 306 | + |
0 commit comments