Skip to content

Commit 3f5749a

Browse files
emilyfertig and tensorflower-gardener
authored and committed
Update generated LinearOperator files in the Numpy backend.
PiperOrigin-RevId: 456412003
1 parent cb2a94c commit 3f5749a

File tree

6 files changed

+201
-32
lines changed

6 files changed

+201
-32
lines changed

tensorflow_probability/python/internal/backend/numpy/gen/adjoint_registrations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,14 @@ def _adjoint_kronecker(kronecker_operator):
128128

129129

130130
@linear_operator_algebra.RegisterAdjoint(
131-
linear_operator_circulant.LinearOperatorCirculant)
131+
linear_operator_circulant._BaseLinearOperatorCirculant) # pylint: disable=protected-access
132132
def _adjoint_circulant(circulant_operator):
133133
spectrum = circulant_operator.spectrum
134134
if np.issubdtype(spectrum.dtype, np.complexfloating):
135135
spectrum = math_ops.conj(spectrum)
136136

137137
# Conjugating the spectrum is sufficient to get the adjoint.
138-
return linear_operator_circulant.LinearOperatorCirculant(
138+
return circulant_operator.__class__(
139139
spectrum=spectrum,
140140
is_non_singular=circulant_operator.is_non_singular,
141141
is_self_adjoint=circulant_operator.is_self_adjoint,

tensorflow_probability/python/internal/backend/numpy/gen/inverse_registrations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,10 @@ def _inverse_kronecker(kronecker_operator):
222222

223223

224224
@linear_operator_algebra.RegisterInverse(
225-
linear_operator_circulant.LinearOperatorCirculant)
225+
linear_operator_circulant._BaseLinearOperatorCirculant) # pylint: disable=protected-access
226226
def _inverse_circulant(circulant_operator):
227227
# Inverting the spectrum is sufficient to get the inverse.
228-
return linear_operator_circulant.LinearOperatorCirculant(
228+
return circulant_operator.__class__(
229229
spectrum=1. / circulant_operator.spectrum,
230230
is_non_singular=circulant_operator.is_non_singular,
231231
is_self_adjoint=circulant_operator.is_self_adjoint,

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from tensorflow_probability.python.internal.backend.numpy import deprecation
6262
# from tensorflow_probability.python.internal.backend.numpy import dispatch
6363
from tensorflow_probability.python.internal.backend.numpy import nest
64+
from tensorflow_probability.python.internal.backend.numpy import variable_utils
6465
# from tensorflow.python.util.tf_export import tf_export
6566

6667
__all__ = ["LinearOperator"]
@@ -1211,6 +1212,28 @@ def _type_spec(self):
12111212
# `@make_composite_tensor` decorator.
12121213
pass
12131214

1215+
def _convert_variables_to_tensors(self):
  """Returns a copy of this operator with Variables replaced by Tensors.

  The operator is flattened into its components through its type spec, every
  variable among those components is converted with
  `variable_utils.convert_variables_to_tensors`, and the operator is rebuilt
  from the converted components.

  Rebuilding with `self._type_spec._from_components` technically breaks the
  `CompositeTensor` contract: the spec describes a nested structure that may
  hold `ResourceVariable`s, while the structure handed back to it holds only
  `Tensor`s. This is nevertheless valid, because `LinearOperator`'s
  `_from_components` simply forwards the nested structure to `__init__`, and
  any `LinearOperator` that can be built from `ResourceVariable`s can equally
  be built from `Tensor`s.

  Returns:
    tensor_operator: `self` with all internal Variables converted to Tensors.
  """
  # pylint: disable=protected-access
  spec = self._type_spec
  flattened = spec._to_components(self)
  return spec._from_components(
      variable_utils.convert_variables_to_tensors(flattened))
  # pylint: enable=protected-access
1236+
12141237
def __getitem__(self, slices):
12151238
return slicing.batch_slice(self, params_overrides={}, slices=slices)
12161239

tensorflow_probability/python/internal/backend/numpy/gen/linear_operator_circulant.py

Lines changed: 158 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,118 @@
7070
_IFFT_OP = {1: fft_ops.ifft, 2: fft_ops.ifft2d, 3: fft_ops.ifft3d}
7171

7272

73+
def exponential_power_convolution_kernel(
    grid_shape,
    length_scale,
    power=None,
    divisor=None,
    zero_inflation=None,
):
  """Make an exponentiated convolution kernel.

  In signal processing, a [kernel]
  (https://en.wikipedia.org/wiki/Kernel_(image_processing)) `h` can be convolved
  with a signal `x` to filter its spectral content.

  This function makes a `d-dimensional` convolution kernel `h` of shape
  `grid_shape = [N0, N1, ...]`. For `n` a multi-index with `n[i] < Ni / 2`,

  ```h[n] = exp{-sum(|n / (length_scale * grid_shape)|**power) / divisor}.```

  For other `n`, `h` is extended to be circularly symmetric. That is

  ```h[n0 % N0, ...] = h[(-n0) % N0, ...]```

  Since `h` is circularly symmetric and real valued, `H = FFTd[h]` is the
  spectrum of a symmetric (real) circulant operator `A`.

  #### Example uses

  ```
  # Matern one-half kernel, d=1.
  # Will be positive definite without zero_inflation.
  h = exponential_power_convolution_kernel(
      grid_shape=[10], length_scale=[0.1], power=1)
  A = LinearOperatorCirculant(
      tf.signal.fft(tf.cast(h, tf.complex64)),
      is_self_adjoint=True, is_positive_definite=True)

  # Gaussian RBF kernel, d=3.
  # Needs zero_inflation since `length_scale` is long enough to cause aliasing.
  h = exponential_power_convolution_kernel(
      grid_shape=[10, 10, 10], length_scale=[0.1, 0.2, 0.2], power=2,
      zero_inflation=0.15)
  A = LinearOperatorCirculant3D(
      tf.signal.fft3d(tf.cast(h, tf.complex64)),
      is_self_adjoint=True, is_positive_definite=True)
  ```

  Args:
    grid_shape: Length `d` (`d` in {1, 2, 3}) list-like of Python integers. The
      shape of the grid on which the convolution kernel is defined.
    length_scale: Length `d` `float` `Tensor`. The scale at which the kernel
      decays in each direction, as a fraction of `grid_shape`.
    power: Scalar `Tensor` of same `dtype` as `length_scale`, default `2`.
      Higher (lower) `power` results in nearby points being more (less)
      correlated, and far away points being less (more) correlated.
    divisor: Scalar `Tensor` of same `dtype` as `length_scale`. The slope of
      decay of `log(kernel)` in terms of fractional grid points, along each
      axis, at `length_scale`, is `power/divisor`. By default, `divisor` is set
      to `power`. This means, by default, `power=2` results in an exponentiated
      quadratic (Gaussian) kernel, and `power=1` is a Matern one-half.
    zero_inflation: Scalar `Tensor` of same `dtype` as `length_scale`, in
      `[0, 1]`. Let `delta` be the Kronecker delta. That is,
      `delta[0, ..., 0] = 1` and all other entries are `0`. Then
      `zero_inflation` modifies the return value via
      `h --> (1 - zero_inflation) * h + zero_inflation * delta`. This may be
      needed to ensure a positive definite kernel, especially if `length_scale`
      is large enough for aliasing and `power > 1`. The default, `None`, means
      no zero inflation is applied.

  Returns:
    `Tensor` of shape `grid_shape` with same `dtype` as `length_scale`.
  """
  nd = len(grid_shape)

  length_scale = ops.convert_to_tensor(
      length_scale, name="length_scale")
  dtype = length_scale.dtype

  power = 2. if power is None else power
  power = ops.convert_to_tensor(
      power, name="power", dtype=dtype)
  divisor = power if divisor is None else divisor
  divisor = ops.convert_to_tensor(
      divisor, name="divisor", dtype=dtype)

  # With K = grid_shape[i], we implicitly assume the grid vertices along the
  # ith dimension are at:
  #   0 = 0 / (K - 1), 1 / (K - 1), 2 / (K - 1), ..., (K - 1) / (K - 1) = 1.
  zero = _ops.cast(0., dtype)
  one = _ops.cast(1., dtype)
  ts = [math_ops.linspace(zero, one, num=n) for n in grid_shape]

  log_vals = []
  for i, x in enumerate(array_ops.meshgrid(*ts, indexing="ij")):
    # midpoint is the vertex at index floor(grid_shape[i] / 2) along axis i.
    # ifftshift will shift this vertex to position 0.
    midpoint = ts[i][_ops.cast(
        math_ops.floor(one / 2. * grid_shape[i]), dtypes.int32)]
    log_vals.append(-(math_ops.abs(
        (x - midpoint) / length_scale[i]))**power / divisor)
  kernel = math_ops.exp(
      fft_ops.ifftshift(sum(log_vals), axes=[-i for i in range(1, nd + 1)]))

  # Compare against None instead of relying on truthiness: `zero_inflation` may
  # be a Tensor, whose boolean value is not always defined. An explicit 0 takes
  # this branch too, and the blend below then leaves the kernel values as is.
  if zero_inflation is not None:
    # delta has shape grid_shape, with delta[0, ..., 0] = 1. and all other
    # entries equal to 0.
    zero_inflation = ops.convert_to_tensor(
        zero_inflation, name="zero_inflation", dtype=dtype)
    delta = array_ops.pad(
        array_ops.reshape(one, [1] * nd), [[0, dim - 1] for dim in grid_shape])
    kernel = (1. - zero_inflation) * kernel + zero_inflation * delta

  return kernel
183+
184+
73185
# TODO(langmore) Add transformations that create common spectrums, e.g.
74186
# starting with the convolution kernel
75187
# start with half a spectrum, and create a Hermitian one.
@@ -94,9 +206,9 @@ def __init__(self,
94206
r"""Initialize an `_BaseLinearOperatorCirculant`.
95207
96208
Args:
97-
spectrum: Shape `[B1,...,Bb, N]` `Tensor`. Allowed dtypes: `float16`,
98-
`float32`, `float64`, `complex64`, `complex128`. Type can be different
99-
than `input_output_dtype`
209+
spectrum: Shape `[B1,...,Bb] + N` `Tensor`, where `rank(N) in {1, 2, 3}`.
210+
Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
211+
`complex128`. Type can be different than `input_output_dtype`.
100212
block_depth: Python integer, either 1, 2, or 3. Will be 1 for circulant,
101213
2 for block circulant, and 3 for nested block circulant.
102214
input_output_dtype: `dtype` for input/output.
@@ -255,6 +367,33 @@ def _vectorize_then_blockify(self, matrix):
255367
(vec_leading_shape, self.block_shape_tensor()), 0)
256368
return array_ops.reshape(vec, final_shape)
257369

370+
def _unblockify(self, x):
  """Collapse the last `self.block_depth` dimensions of `x` into one.

  For example, with `self.block_depth = 2` and `x` of shape
  `[v0, v1, v2, v3]`, the leading shape is `[v0, v1]`, the block shape is
  `[v2, v3]`, and the result has shape `[v0, v1, v2*v3]`.
  """
  depth = self.block_depth
  static_shape = tensor_shape.TensorShape(x.shape)
  if static_shape.is_fully_defined():
    # Every dimension is known statically, so use plain Python lists.
    dims = static_shape.as_list()
    # E.g. [v0, v1] + [v2*v3].
    flat_shape = dims[:-depth] + [np.prod(dims[-depth:])]
  else:
    # Otherwise build the target shape from the dynamic shape of `x`.
    dims = prefer_static.shape(x)
    block_size = math_ops.reduce_prod(dims[-depth:])
    flat_shape = prefer_static.concat((dims[:-depth], [block_size]), 0)
  return array_ops.reshape(x, flat_shape)
396+
258397
def _unblockify_then_matricize(self, vec):
259398
"""Flatten the block dimensions then reshape to a batch matrix."""
260399
# Suppose
@@ -268,22 +407,7 @@ def _unblockify_then_matricize(self, vec):
268407

269408
# Un-blockify: Flatten block dimensions. Reshape
270409
# [v0, v1, v2, v3] --> [v0, v1, v2*v3].
271-
if tensor_shape.TensorShape(vec.shape).is_fully_defined():
272-
# vec_shape = [v0, v1, v2, v3]
273-
vec_shape = tensor_shape.TensorShape(vec.shape).as_list()
274-
# vec_leading_shape = [v0, v1]
275-
vec_leading_shape = vec_shape[:-self.block_depth]
276-
# vec_block_shape = [v2, v3]
277-
vec_block_shape = vec_shape[-self.block_depth:]
278-
# flat_shape = [v0, v1, v2*v3]
279-
flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
280-
else:
281-
vec_shape = prefer_static.shape(vec)
282-
vec_leading_shape = vec_shape[:-self.block_depth]
283-
vec_block_shape = vec_shape[-self.block_depth:]
284-
flat_shape = prefer_static.concat(
285-
(vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
286-
vec_flat = array_ops.reshape(vec, flat_shape)
410+
vec_flat = self._unblockify(vec)
287411

288412
# Matricize: Reshape to batch matrix.
289413
# [v0, v1, v2*v3] --> [v1, v2*v3, v0],
@@ -433,6 +557,21 @@ def _broadcast_batch_dims(self, x, spectrum):
433557

434558
return x, spectrum
435559

560+
def _cond(self):
  """Ratio of the largest to the smallest spectrum magnitude."""
  # Whether or not the operator is real, the Fourier basis F diagonalizes it:
  #   A = F S F^H, with S the diagonal matrix holding the spectrum.
  # It follows that
  #   A A^H = F S S^H F^H = F K F^H,
  # where K is diagonal with the squared absolute values of the spectrum.
  # So the magnitudes of the spectrum entries are used as the singular values.
  magnitudes = math_ops.abs(self._unblockify(self.spectrum))
  largest = math_ops.reduce_max(magnitudes, axis=-1)
  smallest = math_ops.reduce_min(magnitudes, axis=-1)
  return largest / smallest
570+
571+
def _eigvals(self):
  """Returns the spectrum with its block dimensions flattened, as a Tensor."""
  flat_spectrum = self._unblockify(self.spectrum)
  return ops.convert_to_tensor(flat_spectrum)
574+
436575
def _matmul(self, x, adjoint=False, adjoint_arg=False):
437576
x = linalg.adjoint(x) if adjoint_arg else x
438577
# With F the matrix of a DFT, and F^{-1}, F^H the inverse and Hermitian
@@ -805,9 +944,6 @@ def __init__(self,
805944
parameters=parameters,
806945
name=name)
807946

808-
def _eigvals(self):
809-
return ops.convert_to_tensor(self.spectrum)
810-
811947

812948
# @tf_export("linalg.LinearOperatorCirculant2D")
813949
# @linear_operator.make_composite_tensor

tensorflow_probability/python/internal/backend/numpy/gen/matmul_registrations.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,15 @@ def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
218218
# Circulant.
219219

220220

221+
# pylint: disable=protected-access
221222
@linear_operator_algebra.RegisterMatmul(
222-
linear_operator_circulant.LinearOperatorCirculant,
223-
linear_operator_circulant.LinearOperatorCirculant)
223+
linear_operator_circulant._BaseLinearOperatorCirculant,
224+
linear_operator_circulant._BaseLinearOperatorCirculant)
224225
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
225-
return linear_operator_circulant.LinearOperatorCirculant(
226+
if not isinstance(linop_a, linop_b.__class__):
227+
return _matmul_linear_operator(linop_a, linop_b)
228+
229+
return linop_a.__class__(
226230
spectrum=linop_a.spectrum * linop_b.spectrum,
227231
is_non_singular=registrations_util.combined_non_singular_hint(
228232
linop_a, linop_b),
@@ -232,6 +236,7 @@ def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
232236
registrations_util.combined_commuting_positive_definite_hint(
233237
linop_a, linop_b)),
234238
is_square=True)
239+
# pylint: enable=protected-access
235240

236241
# Block Diag
237242

tensorflow_probability/python/internal/backend/numpy/gen/solve_registrations.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,15 @@ def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
190190
# Circulant.
191191

192192

193+
# pylint: disable=protected-access
193194
@linear_operator_algebra.RegisterSolve(
194-
linear_operator_circulant.LinearOperatorCirculant,
195-
linear_operator_circulant.LinearOperatorCirculant)
195+
linear_operator_circulant._BaseLinearOperatorCirculant,
196+
linear_operator_circulant._BaseLinearOperatorCirculant)
196197
def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
197-
return linear_operator_circulant.LinearOperatorCirculant(
198+
if not isinstance(linop_a, linop_b.__class__):
199+
return _solve_linear_operator(linop_a, linop_b)
200+
201+
return linop_a.__class__(
198202
spectrum=linop_b.spectrum / linop_a.spectrum,
199203
is_non_singular=registrations_util.combined_non_singular_hint(
200204
linop_a, linop_b),
@@ -204,6 +208,7 @@ def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
204208
registrations_util.combined_commuting_positive_definite_hint(
205209
linop_a, linop_b)),
206210
is_square=True)
211+
# pylint: enable=protected-access
207212

208213

209214
# Block Diag

0 commit comments

Comments
 (0)