_IFFT_OP = {1: fft_ops.ifft, 2: fft_ops.ifft2d, 3: fft_ops.ifft3d}


+def exponential_power_convolution_kernel(
+    grid_shape,
+    length_scale,
+    power=None,
+    divisor=None,
+    zero_inflation=None,
+):
+  """Make an exponentiated convolution kernel.
+
+  In signal processing, a [kernel]
+  (https://en.wikipedia.org/wiki/Kernel_(image_processing)) `h` can be convolved
+  with a signal `x` to filter its spectral content.
+
+  This function makes a `d-dimensional` convolution kernel `h` of shape
+  `grid_shape = [N0, N1, ...]`. For `n` a multi-index with `n[i] < Ni / 2`,
+
+  ```h[n] = exp{-sum(|n / (length_scale * grid_shape)|**power) / divisor}.```
+
+  For other `n`, `h` is extended to be circularly symmetric. That is,
+
+  ```h[n0 % N0, ...] = h[(-n0) % N0, ...]```
+
+  Since `h` is circularly symmetric and real valued, `H = FFTd[h]` is the
+  spectrum of a symmetric (real) circulant operator `A`.
+
+  #### Example uses
+
+  ```
+  # Matern one-half kernel, d=1.
+  # Will be positive definite without zero_inflation.
+  h = exponential_power_convolution_kernel(
+      grid_shape=[10], length_scale=[0.1], power=1)
+  A = LinearOperatorCirculant(
+      tf.signal.fft(tf.cast(h, tf.complex64)),
+      is_self_adjoint=True, is_positive_definite=True)
+
+  # Gaussian RBF kernel, d=3.
+  # Needs zero_inflation since `length_scale` is long enough to cause aliasing.
+  h = exponential_power_convolution_kernel(
+      grid_shape=[10, 10, 10], length_scale=[0.1, 0.2, 0.2], power=2,
+      zero_inflation=0.15)
+  A = LinearOperatorCirculant3D(
+      tf.signal.fft3d(tf.cast(h, tf.complex64)),
+      is_self_adjoint=True, is_positive_definite=True)
+  ```
+
+  Args:
+    grid_shape: Length `d` (`d` in {1, 2, 3}) list-like of Python integers. The
+      shape of the grid on which the convolution kernel is defined.
+    length_scale: Length `d` `float` `Tensor`. The scale at which the kernel
+      decays in each direction, as a fraction of `grid_shape`.
+    power: Scalar `Tensor` of same `dtype` as `length_scale`, default `2`.
+      Higher (lower) `power` results in nearby points being more (less)
+      correlated, and far away points being less (more) correlated.
+    divisor: Scalar `Tensor` of same `dtype` as `length_scale`. The slope of
+      decay of `log(kernel)` in terms of fractional grid points, along each
+      axis, at `length_scale`, is `power/divisor`. By default, `divisor` is set
+      to `power`. This means, by default, `power=2` results in an exponentiated
+      quadratic (Gaussian) kernel, and `power=1` is a Matern one-half.
+    zero_inflation: Scalar `Tensor` of same `dtype` as `length_scale`, in
+      `[0, 1]`. Let `delta` be the Kronecker delta. That is,
+      `delta[0, ..., 0] = 1` and all other entries are `0`. Then
+      `zero_inflation` modifies the return value via
+      `h --> (1 - zero_inflation) * h + zero_inflation * delta`. This may be
+      needed to ensure a positive definite kernel, especially if `length_scale`
+      is large enough for aliasing and `power > 1`.
+
+  Returns:
+    `Tensor` of shape `grid_shape` with same `dtype` as `length_scale`.
+  """
+  nd = len(grid_shape)
+
+  length_scale = ops.convert_to_tensor(
+      length_scale, name="length_scale")
+  dtype = length_scale.dtype
+
+  power = 2. if power is None else power
+  power = ops.convert_to_tensor(
+      power, name="power", dtype=dtype)
+  divisor = power if divisor is None else divisor
+  divisor = ops.convert_to_tensor(
+      divisor, name="divisor", dtype=dtype)
+
+  # With K = grid_shape[i], we implicitly assume the grid vertices along the
+  # ith dimension are at:
+  #   0 = 0 / (K - 1), 1 / (K - 1), 2 / (K - 1), ..., (K - 1) / (K - 1) = 1.
+  zero = _ops.cast(0., dtype)
+  one = _ops.cast(1., dtype)
+  ts = [math_ops.linspace(zero, one, num=n) for n in grid_shape]
+
+  log_vals = []
+  for i, x in enumerate(array_ops.meshgrid(*ts, indexing="ij")):
+    # midpoint[i] is the vertex just to the left of 1 / 2.
+    # ifftshift will shift this vertex to position 0.
+    midpoint = ts[i][_ops.cast(
+        math_ops.floor(one / 2. * grid_shape[i]), dtypes.int32)]
+    log_vals.append(-(math_ops.abs(
+        (x - midpoint) / length_scale[i]))**power / divisor)
+  kernel = math_ops.exp(
+      fft_ops.ifftshift(sum(log_vals), axes=[-i for i in range(1, nd + 1)]))
+
+  if zero_inflation:
+    # delta.shape = grid_shape, delta[0, ..., 0] = 1., all other entries are 0.
+    zero_inflation = ops.convert_to_tensor(
+        zero_inflation, name="zero_inflation", dtype=dtype)
+    delta = array_ops.pad(
+        array_ops.reshape(one, [1] * nd), [[0, dim - 1] for dim in grid_shape])
+    kernel = (1. - zero_inflation) * kernel + zero_inflation * delta
+
+  return kernel
+
+
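Editor's note: to illustrate the construction above, here is a minimal standalone NumPy sketch of the `d = 1` case; the helper name `ep_kernel_1d` is hypothetical and not part of this change. For an odd grid size the kernel is exactly circularly symmetric, so its FFT (the spectrum) is real, and with `power=1` it is also positive:

```
import numpy as np

def ep_kernel_1d(n, length_scale, power=2.0, divisor=None):
  # Mirrors the TF code above: grid vertices at k / (n - 1) in [0, 1].
  divisor = power if divisor is None else divisor
  ts = np.linspace(0.0, 1.0, n)
  midpoint = ts[int(np.floor(n / 2))]  # this vertex is shifted to position 0
  log_vals = -np.abs((ts - midpoint) / length_scale)**power / divisor
  return np.exp(np.fft.ifftshift(log_vals))

h = ep_kernel_1d(11, length_scale=0.1, power=1.0)  # Matern one-half
spectrum = np.fft.fft(h)
assert np.max(np.abs(spectrum.imag)) < 1e-10  # circular symmetry => real
assert np.all(spectrum.real > 0)              # positive definite operator
```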
 # TODO(langmore) Add transformations that create common spectrums, e.g.
 #   starting with the convolution kernel
 #   start with half a spectrum, and create a Hermitian one.
@@ -94,9 +206,9 @@ def __init__(self,
    r"""Initialize an `_BaseLinearOperatorCirculant`.

    Args:
-      spectrum:  Shape `[B1,...,Bb, N]` `Tensor`.  Allowed dtypes: `float16`,
-        `float32`, `float64`, `complex64`, `complex128`.  Type can be different
-        than `input_output_dtype`
+      spectrum:  Shape `[B1,...,Bb] + N` `Tensor`, where `rank(N) in {1, 2, 3}`.
+        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
+        `complex128`.  Type can be different than `input_output_dtype`
      block_depth:  Python integer, either 1, 2, or 3.  Will be 1 for circulant,
        2 for block circulant, and 3 for nested block circulant.
      input_output_dtype:  `dtype` for input/output.
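Editor's note: a quick illustration of the revised shape convention, using the public `tf.linalg` API rather than this file's internals. A spectrum of shape `[B1,...,Bb] + N` with `rank(N) = 2` yields a batch of block circulant operators:

```
import numpy as np
import tensorflow as tf

# A batch of three 4x4 spectra: shape [3] + [4, 4], so b = 1 and
# rank(N) = 2, matching the block_depth=2 (block circulant) case.
kernels = np.random.rand(3, 4, 4)
spectrum = tf.signal.fft2d(tf.cast(kernels, tf.complex64))

operator = tf.linalg.LinearOperatorCirculant2D(spectrum)
print(operator.shape)  # (3, 16, 16): three operators, each 16x16
```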
@@ -255,6 +367,33 @@ def _vectorize_then_blockify(self, matrix):
        (vec_leading_shape, self.block_shape_tensor()), 0)
    return array_ops.reshape(vec, final_shape)

+  def _unblockify(self, x):
+    """Flatten the trailing block dimensions."""
+    # Suppose
+    #   x.shape = [v0, v1, v2, v3],
+    #   self.block_depth = 2.
+    # Then
+    #   leading shape = [v0, v1],
+    #   block shape = [v2, v3].
+    # We will reshape x to
+    #   [v0, v1, v2*v3].
+    if tensor_shape.TensorShape(x.shape).is_fully_defined():
+      # x_shape = [v0, v1, v2, v3]
+      x_shape = tensor_shape.TensorShape(x.shape).as_list()
+      # x_leading_shape = [v0, v1]
+      x_leading_shape = x_shape[:-self.block_depth]
+      # x_block_shape = [v2, v3]
+      x_block_shape = x_shape[-self.block_depth:]
+      # flat_shape = [v0, v1, v2*v3]
+      flat_shape = x_leading_shape + [np.prod(x_block_shape)]
+    else:
+      x_shape = prefer_static.shape(x)
+      x_leading_shape = x_shape[:-self.block_depth]
+      x_block_shape = x_shape[-self.block_depth:]
+      flat_shape = prefer_static.concat(
+          (x_leading_shape, [math_ops.reduce_prod(x_block_shape)]), 0)
+    return array_ops.reshape(x, flat_shape)
+
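Editor's note: the reshape `_unblockify` performs is easiest to see in plain NumPy. A minimal sketch for `block_depth = 2`, where the two trailing block dimensions collapse into one:

```
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # [v0, v1, v2, v3]
block_depth = 2
flat_shape = (list(x.shape[:-block_depth])
              + [int(np.prod(x.shape[-block_depth:]))])
x_flat = x.reshape(flat_shape)
print(x_flat.shape)  # (2, 3, 20), i.e. [v0, v1, v2*v3]
```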
  def _unblockify_then_matricize(self, vec):
    """Flatten the block dimensions then reshape to a batch matrix."""
    # Suppose
@@ -268,22 +407,7 @@ def _unblockify_then_matricize(self, vec):

    # Un-blockify:  Flatten block dimensions.  Reshape
    #   [v0, v1, v2, v3] --> [v0, v1, v2*v3].
-    if tensor_shape.TensorShape(vec.shape).is_fully_defined():
-      # vec_shape = [v0, v1, v2, v3]
-      vec_shape = tensor_shape.TensorShape(vec.shape).as_list()
-      # vec_leading_shape = [v0, v1]
-      vec_leading_shape = vec_shape[:-self.block_depth]
-      # vec_block_shape = [v2, v3]
-      vec_block_shape = vec_shape[-self.block_depth:]
-      # flat_shape = [v0, v1, v2*v3]
-      flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
-    else:
-      vec_shape = prefer_static.shape(vec)
-      vec_leading_shape = vec_shape[:-self.block_depth]
-      vec_block_shape = vec_shape[-self.block_depth:]
-      flat_shape = prefer_static.concat(
-          (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
-    vec_flat = array_ops.reshape(vec, flat_shape)
+    vec_flat = self._unblockify(vec)

    # Matricize:  Reshape to batch matrix.
    #   [v0, v1, v2*v3] --> [v1, v2*v3, v0],
@@ -433,6 +557,21 @@ def _broadcast_batch_dims(self, x, spectrum):

    return x, spectrum

+  def _cond(self):
+    # Regardless of whether the operator is real, it is always diagonalizable
+    # by the Fourier basis F.  I.e.  A = F S F^H, with S a diagonal matrix
+    # containing the spectrum.  We then have:
+    #   A A^H = F S S^H F^H = F K F^H,
+    # where K is diagonal with the squared absolute values of the spectrum.
+    # So in all cases, the singular values are the absolute values of the
+    # spectrum, and the condition number is their max over min:
+    abs_singular_values = math_ops.abs(self._unblockify(self.spectrum))
+    return (math_ops.reduce_max(abs_singular_values, axis=-1) /
+            math_ops.reduce_min(abs_singular_values, axis=-1))
+
+  def _eigvals(self):
+    return ops.convert_to_tensor(
+        self._unblockify(self.spectrum))
+
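Editor's note: both methods rest on the standard fact that a circulant matrix built from first column `c` has eigenvalues `fft(c)`; since the operator is normal, the singular values are their absolute values. A NumPy check of the `d = 1` case, not part of this change:

```
import numpy as np

c = np.array([4.0, 1.0, 0.5, 1.0])  # symmetric first column => C self-adjoint
C = np.array([[c[(i - j) % 4] for j in range(4)] for i in range(4)])
spectrum = np.fft.fft(c)

# Eigenvalues of the circulant are the DFT of its first column.
np.testing.assert_allclose(
    np.sort_complex(np.linalg.eigvals(C)), np.sort_complex(spectrum),
    atol=1e-10)

# Condition number = max |eigval| / min |eigval|, as in _cond above.
abs_sv = np.abs(spectrum)
np.testing.assert_allclose(abs_sv.max() / abs_sv.min(), np.linalg.cond(C))
```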
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    x = linalg.adjoint(x) if adjoint_arg else x
    # With F the matrix of a DFT, and F^{-1}, F^H the inverse and Hermitian
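Editor's note: the diagonalization this comment invokes means a circulant matvec reduces to elementwise multiplication in Fourier space. A minimal NumPy analogue of the `block_depth = 1` matmul path, under the assumption `C = F^{-1} diag(spectrum) F`:

```
import numpy as np

c = np.array([2.0, -1.0, 0.5, -1.0])  # first column of a circulant C
C = np.array([[c[(i - j) % 4] for j in range(4)] for i in range(4)])
spectrum = np.fft.fft(c)

x = np.random.rand(4)
# C @ x = ifft(spectrum * fft(x)), computed in O(n log n).
y = np.fft.ifft(spectrum * np.fft.fft(x)).real  # imaginary part ~ 0
np.testing.assert_allclose(C @ x, y, atol=1e-10)
```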
@@ -805,9 +944,6 @@ def __init__(self,
        parameters=parameters,
        name=name)

-  def _eigvals(self):
-    return ops.convert_to_tensor(self.spectrum)
-

# @tf_export("linalg.LinearOperatorCirculant2D")
# @linear_operator.make_composite_tensor