Skip to content

Commit b0d85b1

Browse files
srvasudetensorflower-gardener
authored andcommitted
Modify gen_linear_operators to properly map broadcast_dynamic_shape.
- Add `LinearOperatorPermutation`. PiperOrigin-RevId: 465162410
1 parent 1bdfdf3 commit b0d85b1

File tree

11 files changed

+325
-8
lines changed

11 files changed

+325
-8
lines changed

tensorflow_probability/python/internal/backend/jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ GEN_FILENAMES = [
9191
"gen/linear_operator_kronecker",
9292
"gen/linear_operator_lower_triangular",
9393
"gen/linear_operator_low_rank_update",
94+
"gen/linear_operator_permutation",
9495
"gen/linear_operator",
9596
"gen/linear_operator_toeplitz",
9697
"gen/linear_operator_util",

tensorflow_probability/python/internal/backend/meta/gen_linear_operators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
'ops import linalg_ops': 'linalg_impl as linalg_ops',
4646
'ops import math_ops': 'numpy_math as math_ops',
4747
'ops import nn': 'nn',
48+
'ops import sort_ops': 'misc as sort_ops',
4849
'ops import variables as variables_module': 'ops as variables_module',
4950
'ops.linalg import linalg_impl as linalg': 'linalg_impl as linalg'
5051
}
@@ -185,6 +186,8 @@ def gen_module(module_name):
185186

186187
code = code.replace('array_ops.shape', 'prefer_static.shape')
187188
code = code.replace('array_ops.concat', 'prefer_static.concat')
189+
code = code.replace('array_ops.broadcast_dynamic_shape',
190+
'_ops.broadcast_dynamic_shape')
188191
code = code.replace('array_ops.broadcast_static_shape',
189192
'_ops.broadcast_static_shape')
190193
code = code.replace('array_ops.broadcast_to', '_ops.broadcast_to')

tensorflow_probability/python/internal/backend/numpy/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,7 @@ LINOP_FILES = [
548548
"linear_operator_kronecker",
549549
"linear_operator_lower_triangular",
550550
"linear_operator_low_rank_update",
551+
"linear_operator_permutation",
551552
"linear_operator",
552553
"linear_operator_toeplitz",
553554
"linear_operator_util",

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_kronecker.py

100755100644
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _shape_tensor(self):
292292
# broadcast_shape checks for compatibility.
293293
batch_shape = self.operators[0].batch_shape_tensor()
294294
for operator in self.operators[1:]:
295-
batch_shape = array_ops.broadcast_dynamic_shape(
295+
batch_shape = _ops.broadcast_dynamic_shape(
296296
batch_shape, operator.batch_shape_tensor())
297297

298298
return prefer_static.concat((batch_shape, matrix_shape), 0)

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_low_rank_update.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,13 +370,13 @@ def _shape(self):
370370
return batch_shape.concatenate(tensor_shape.TensorShape(self.base_operator.shape)[-2:])
371371

372372
def _shape_tensor(self):
373-
batch_shape = array_ops.broadcast_dynamic_shape(
373+
batch_shape = _ops.broadcast_dynamic_shape(
374374
self.base_operator.batch_shape_tensor(),
375375
self.diag_operator.batch_shape_tensor())
376-
batch_shape = array_ops.broadcast_dynamic_shape(
376+
batch_shape = _ops.broadcast_dynamic_shape(
377377
batch_shape,
378378
prefer_static.shape(self.u)[:-2])
379-
batch_shape = array_ops.broadcast_dynamic_shape(
379+
batch_shape = _ops.broadcast_dynamic_shape(
380380
batch_shape,
381381
prefer_static.shape(self.v)[:-2])
382382
return prefer_static.concat(
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved.
2+
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
3+
# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`.
4+
# DO NOT MODIFY DIRECTLY.
5+
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
6+
# pylint: disable=g-import-not-at-top
7+
# pylint: disable=g-direct-tensorflow-import
8+
# pylint: disable=g-bad-import-order
9+
# pylint: disable=unused-import
10+
# pylint: disable=line-too-long
11+
# pylint: disable=reimported
12+
# pylint: disable=g-bool-id-comparison
13+
# pylint: disable=g-statement-before-imports
14+
# pylint: disable=bad-continuation
15+
# pylint: disable=useless-import-alias
16+
# pylint: disable=property-with-parameters
17+
# pylint: disable=trailing-whitespace
18+
# pylint: disable=g-inconsistent-quotes
19+
20+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
21+
#
22+
# Licensed under the Apache License, Version 2.0 (the "License");
23+
# you may not use this file except in compliance with the License.
24+
# You may obtain a copy of the License at
25+
#
26+
# http://www.apache.org/licenses/LICENSE-2.0
27+
#
28+
# Unless required by applicable law or agreed to in writing, software
29+
# distributed under the License is distributed on an "AS IS" BASIS,
30+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31+
# See the License for the specific language governing permissions and
32+
# limitations under the License.
33+
# ==============================================================================
34+
"""`LinearOperator` acting like a permutation matrix."""
35+
36+
import numpy as np
37+
38+
from tensorflow_probability.python.internal.backend.numpy import dtype as dtypes
39+
from tensorflow_probability.python.internal.backend.numpy import ops
40+
from tensorflow_probability.python.internal.backend.numpy import ops
41+
from tensorflow_probability.python.internal.backend.numpy import numpy_array as array_ops
42+
from tensorflow_probability.python.internal.backend.numpy import control_flow as control_flow_ops
43+
from tensorflow_probability.python.internal.backend.numpy import numpy_math as math_ops
44+
from tensorflow_probability.python.internal.backend.numpy import misc as sort_ops
45+
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg
46+
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator
47+
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator_util
48+
# from tensorflow.python.util.tf_export import tf_export
49+
50+
__all__ = ["LinearOperatorPermutation",]
51+
52+
53+
# @tf_export("linalg.LinearOperatorPermutation")
54+
# @linear_operator.make_composite_tensor
55+
class LinearOperatorPermutation(linear_operator.LinearOperator):
56+
"""`LinearOperator` acting like a [batch] of permutation matrices.
57+
58+
This operator acts like a [batch] of permutations with shape
59+
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
60+
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
61+
an `N x N` matrix. This matrix `A` is not materialized, but for
62+
purposes of broadcasting this shape will be relevant.
63+
64+
`LinearOperatorPermutation` is initialized with a (batch) vector.
65+
66+
A permutation, is defined by an integer vector `v` whose values are unique
67+
and are in the range `[0, ... n]`. Applying the permutation on an input
68+
matrix has the folllowing meaning: the value of `v` at index `i`
69+
says to move the `v[i]`-th row of the input matrix to the `i`-th row.
70+
Because all values are unique, this will result in a permutation of the
71+
rows the input matrix. Note, that the permutation vector `v` has the same
72+
semantics as `tf.transpose`.
73+
74+
```python
75+
# Create a 3 x 3 permutation matrix that swaps the last two columns.
76+
vec = [0, 2, 1]
77+
operator = LinearOperatorPermutation(vec)
78+
79+
operator.to_dense()
80+
==> [[1., 0., 0.]
81+
[0., 0., 1.]
82+
[0., 1., 0.]]
83+
84+
tensor_shape.TensorShape(operator.shape)
85+
==> [3, 3]
86+
87+
# This will be zero.
88+
operator.log_abs_determinant()
89+
==> scalar Tensor
90+
91+
x = ... Shape [3, 4] Tensor
92+
operator.matmul(x)
93+
==> Shape [3, 4] Tensor
94+
```
95+
96+
#### Shape compatibility
97+
98+
This operator acts on [batch] matrix with compatible shape.
99+
`x` is a batch matrix with compatible shape for `matmul` and `solve` if
100+
101+
```
102+
tensor_shape.TensorShape(operator.shape) = [B1,...,Bb] + [N, N], with b >= 0
103+
tensor_shape.TensorShape(x.shape) = [C1,...,Cc] + [N, R],
104+
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
105+
```
106+
107+
#### Matrix property hints
108+
109+
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
110+
for `X = non_singular, self_adjoint, positive_definite, square`.
111+
These have the following meaning:
112+
113+
* If `is_X == True`, callers should expect the operator to have the
114+
property `X`. This is a promise that should be fulfilled, but is *not* a
115+
runtime assert. For example, finite floating point precision may result
116+
in these promises being violated.
117+
* If `is_X == False`, callers should expect the operator to not have `X`.
118+
* If `is_X == None` (the default), callers should have no expectation either
119+
way.
120+
"""
121+
122+
def __init__(self,
123+
perm,
124+
dtype=dtypes.float32,
125+
is_non_singular=None,
126+
is_self_adjoint=None,
127+
is_positive_definite=None,
128+
is_square=None,
129+
name="LinearOperatorPermutation"):
130+
r"""Initialize a `LinearOperatorPermutation`.
131+
132+
Args:
133+
perm: Shape `[B1,...,Bb, N]` Integer `Tensor` with `b >= 0`
134+
`N >= 0`. An integer vector that represents the permutation to apply.
135+
Note that this argument is same as `tf.transpose`. However, this
136+
permutation is applied on the rows, while the permutation in
137+
`tf.transpose` is applied on the dimensions of the `Tensor`. `perm`
138+
is required to have unique entries from `{0, 1, ... N-1}`.
139+
dtype: The `dtype` of arguments to this operator. Default: `float32`.
140+
Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
141+
`complex128`.
142+
is_non_singular: Expect that this operator is non-singular.
143+
is_self_adjoint: Expect that this operator is equal to its hermitian
144+
transpose. This is autoset to true
145+
is_positive_definite: Expect that this operator is positive definite,
146+
meaning the quadratic form `x^H A x` has positive real part for all
147+
nonzero `x`. Note that we do not require the operator to be
148+
self-adjoint to be positive-definite. See:
149+
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
150+
This is autoset to false.
151+
is_square: Expect that this operator acts like square [batch] matrices.
152+
This is autoset to true.
153+
name: A name for this `LinearOperator`.
154+
155+
Raises:
156+
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
157+
not `False` or `is_square` is not `True`.
158+
"""
159+
parameters = dict(
160+
perm=perm,
161+
dtype=dtype,
162+
is_non_singular=is_non_singular,
163+
is_self_adjoint=is_self_adjoint,
164+
is_positive_definite=is_positive_definite,
165+
is_square=is_square,
166+
name=name
167+
)
168+
169+
with ops.name_scope(name, values=[perm]):
170+
self._perm = linear_operator_util.convert_nonref_to_tensor(
171+
perm, name="perm")
172+
self._check_perm(self._perm)
173+
174+
# Check and auto-set hints.
175+
if is_non_singular is False: # pylint:disable=g-bool-id-comparison
176+
raise ValueError(f"A Permutation operator is always non-singular. "
177+
f"Expected argument `is_non_singular` to be True. "
178+
f"Received: {is_non_singular}.")
179+
180+
if is_square is False: # pylint:disable=g-bool-id-comparison
181+
raise ValueError(f"A Permutation operator is always square. "
182+
f"Expected argument `is_square` to be True. "
183+
f"Received: {is_square}.")
184+
is_square = True
185+
186+
super(LinearOperatorPermutation, self).__init__(
187+
dtype=dtype,
188+
is_non_singular=is_non_singular,
189+
is_self_adjoint=is_self_adjoint,
190+
is_positive_definite=is_positive_definite,
191+
is_square=is_square,
192+
parameters=parameters,
193+
name=name)
194+
195+
def _check_perm(self, perm):
196+
"""Static check of perm."""
197+
if (tensor_shape.TensorShape(perm.shape).ndims is not None and tensor_shape.TensorShape(perm.shape).ndims < 1):
198+
raise ValueError(f"Argument `perm` must have at least 1 dimension. "
199+
f"Received: {perm}.")
200+
if not np.issubdtype(perm.dtype, np.integer):
201+
raise TypeError(f"Argument `perm` must be integer dtype. "
202+
f"Received: {perm}.")
203+
# Check that the permutation satisfies the uniqueness constraint.
204+
static_perm = ops.get_static_value(perm)
205+
if static_perm is not None:
206+
sorted_perm = np.sort(static_perm, axis=-1)
207+
if np.any(sorted_perm != np.arange(0, tensor_shape.TensorShape(static_perm.shape)[-1])):
208+
raise ValueError(
209+
f"Argument `perm` must be a vector of unique integers from "
210+
f"0 to {tensor_shape.TensorShape(static_perm.shape)[-1] - 1}.")
211+
212+
def _shape(self):
213+
perm_shape = tensor_shape.TensorShape(self._perm.shape)
214+
return perm_shape.concatenate(perm_shape[-1:])
215+
216+
def _shape_tensor(self):
217+
perm_shape = prefer_static.shape(self._perm)
218+
k = perm_shape[-1]
219+
return prefer_static.concat((perm_shape, [k]), 0)
220+
221+
def _assert_non_singular(self):
222+
return control_flow_ops.no_op("assert_non_singular")
223+
224+
def _domain_dimension_tensor(self, perm=None):
225+
perm = perm if perm is not None else self.perm
226+
return prefer_static.shape(perm)[-1]
227+
228+
def _matmul(self, x, adjoint=False, adjoint_arg=False):
229+
perm = ops.convert_to_tensor(self.perm)
230+
if adjoint and not self.is_self_adjoint:
231+
# TODO(srvasude): invert_permutation doesn't work on batches so we use
232+
# argsort.
233+
perm = sort_ops.argsort(perm, axis=-1)
234+
x = linalg.adjoint(x) if adjoint_arg else x
235+
236+
# We need to broadcast x and the permutation since tf.gather doesn't
237+
# broadcast.
238+
broadcast_shape = _ops.broadcast_dynamic_shape(
239+
prefer_static.shape(x)[:-1], prefer_static.shape(perm))
240+
k = prefer_static.shape(x)[-1]
241+
broadcast_x_shape = prefer_static.concat([broadcast_shape, [k]], axis=-1)
242+
x = _ops.broadcast_to(x, broadcast_x_shape)
243+
perm = _ops.broadcast_to(perm, broadcast_shape)
244+
245+
m = prefer_static.shape(x)[-2]
246+
x = array_ops.reshape(x, [-1, m, k])
247+
perm = array_ops.reshape(perm, [-1, m])
248+
249+
y = array_ops.gather(x, perm, axis=-2, batch_dims=1)
250+
return array_ops.reshape(y, broadcast_x_shape)
251+
252+
# TODO(srvasude): Permutation parity is equivalent to the determinant.
253+
254+
def _log_abs_determinant(self):
255+
# Permutation matrices have determinant +/- 1.
256+
return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
257+
258+
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
259+
# The inverse of a permutation matrix is the transpose matrix.
260+
# Apply a matmul and flip the adjoint bit.
261+
return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
262+
263+
def _to_dense(self):
264+
perm = ops.convert_to_tensor(self.perm)
265+
return _ops.cast(math_ops.equal(
266+
array_ops.range(0, self._domain_dimension_tensor(perm)),
267+
perm[..., _ops.newaxis]), self.dtype)
268+
269+
def _diag_part(self):
270+
perm = ops.convert_to_tensor(self.perm)
271+
return _ops.cast(math_ops.equal(
272+
array_ops.range(0, self._domain_dimension_tensor(perm)),
273+
perm), self.dtype)
274+
275+
def _cond(self):
276+
# Permutation matrices are rotations which have condition number 1.
277+
return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)
278+
279+
@property
280+
def perm(self):
281+
return self._perm
282+
283+
@property
284+
def _composite_tensor_fields(self):
285+
return ("perm", "dtype")
286+
287+
@property
288+
def _experimental_parameter_ndims_to_matrix_ndims(self):
289+
return {"perm": 1}
290+
291+
import numpy as np
292+
from tensorflow_probability.python.internal.backend.numpy import linalg_impl as _linalg
293+
from tensorflow_probability.python.internal.backend.numpy import ops as _ops
294+
from tensorflow_probability.python.internal.backend.numpy.gen import tensor_shape
295+
296+
from tensorflow_probability.python.internal.backend.numpy import private
297+
distribution_util = private.LazyLoader(
298+
"distribution_util", globals(),
299+
"tensorflow_probability.substrates.numpy.internal.distribution_util")
300+
tensorshape_util = private.LazyLoader(
301+
"tensorshape_util", globals(),
302+
"tensorflow_probability.substrates.numpy.internal.tensorshape_util")
303+
prefer_static = private.LazyLoader(
304+
"prefer_static", globals(),
305+
"tensorflow_probability.substrates.numpy.internal.prefer_static")
306+

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_toeplitz.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _shape(self):
204204
def _shape_tensor(self, row=None, col=None):
205205
row = self.row if row is None else row
206206
col = self.col if col is None else col
207-
v_shape = array_ops.broadcast_dynamic_shape(
207+
v_shape = _ops.broadcast_dynamic_shape(
208208
prefer_static.shape(row),
209209
prefer_static.shape(col))
210210
k = v_shape[-1]
@@ -262,7 +262,7 @@ def _diag_part(self):
262262
def _to_dense(self):
263263
row = ops.convert_to_tensor(self.row)
264264
col = ops.convert_to_tensor(self.col)
265-
total_shape = array_ops.broadcast_dynamic_shape(
265+
total_shape = _ops.broadcast_dynamic_shape(
266266
prefer_static.shape(row), prefer_static.shape(col))
267267
n = prefer_static.shape(row)[-1]
268268
row = _ops.broadcast_to(row, total_shape)

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
367367
# Since static didn't work, do dynamic, which always copies data.
368368
bcast_batch_shape = prefer_static.shape(batch_matrices[0])[:-2]
369369
for mat in batch_matrices[1:]:
370-
bcast_batch_shape = array_ops.broadcast_dynamic_shape(
370+
bcast_batch_shape = _ops.broadcast_dynamic_shape(
371371
bcast_batch_shape,
372372
prefer_static.shape(mat)[:-2])
373373
for i, mat in enumerate(batch_matrices):

0 commit comments

Comments
 (0)