Skip to content

Commit 774c6a4

Browse files
srvasudetensorflower-gardener
authored andcommitted
Add SymmetricMatrixSpace and ConstantDiagonalSymmetricMatrixSpace.
- These can be used for symmetric matrices and any open subset of them (e.g. positive definite matrices) since they will share the same tangent space. - Also refactor spaces / tests to their own files. PiperOrigin-RevId: 529547439
1 parent abc165a commit 774c6a4

File tree

17 files changed

+1546
-772
lines changed

17 files changed

+1546
-772
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ multi_substrate_py_library(
314314
# numpy dep,
315315
# tensorflow dep,
316316
"//tensorflow_probability/python/bijectors:sigmoid",
317-
"//tensorflow_probability/python/experimental/tangent_spaces",
318317
"//tensorflow_probability/python/internal:assert_util",
319318
"//tensorflow_probability/python/internal:batched_rejection_sampler",
320319
"//tensorflow_probability/python/internal:distribution_util",
@@ -494,6 +493,7 @@ multi_substrate_py_library(
494493
# tensorflow dep,
495494
"//tensorflow_probability/python/bijectors:softmax_centered",
496495
"//tensorflow_probability/python/bijectors:softplus",
496+
"//tensorflow_probability/python/experimental/tangent_spaces:simplex",
497497
"//tensorflow_probability/python/internal:assert_util",
498498
"//tensorflow_probability/python/internal:distribution_util",
499499
"//tensorflow_probability/python/internal:dtype_util",
@@ -1631,7 +1631,6 @@ multi_substrate_py_library(
16311631
":distribution",
16321632
# tensorflow dep,
16331633
"//tensorflow_probability/python/bijectors:softmax_centered",
1634-
"//tensorflow_probability/python/experimental/tangent_spaces",
16351634
"//tensorflow_probability/python/internal:assert_util",
16361635
"//tensorflow_probability/python/internal:distribution_util",
16371636
"//tensorflow_probability/python/internal:dtype_util",
@@ -1651,7 +1650,6 @@ multi_substrate_py_library(
16511650
":gamma",
16521651
# tensorflow dep,
16531652
"//tensorflow_probability/python/bijectors:sigmoid",
1654-
"//tensorflow_probability/python/experimental/tangent_spaces",
16551653
"//tensorflow_probability/python/internal:assert_util",
16561654
"//tensorflow_probability/python/internal:distribution_util",
16571655
"//tensorflow_probability/python/internal:dtype_util",
@@ -1926,6 +1924,7 @@ multi_substrate_py_library(
19261924
"//tensorflow_probability/python/bijectors:shift",
19271925
"//tensorflow_probability/python/bijectors:sigmoid",
19281926
"//tensorflow_probability/python/bijectors:softplus",
1927+
"//tensorflow_probability/python/experimental/tangent_spaces:spherical",
19291928
"//tensorflow_probability/python/internal:assert_util",
19301929
"//tensorflow_probability/python/internal:distribution_util",
19311930
"//tensorflow_probability/python/internal:dtype_util",
@@ -2005,6 +2004,7 @@ multi_substrate_py_library(
20052004
"//tensorflow_probability/python/bijectors:exp",
20062005
"//tensorflow_probability/python/bijectors:softmax_centered",
20072006
"//tensorflow_probability/python/bijectors:softplus",
2007+
"//tensorflow_probability/python/experimental/tangent_spaces:simplex",
20082008
"//tensorflow_probability/python/internal:assert_util",
20092009
"//tensorflow_probability/python/internal:dtype_util",
20102010
"//tensorflow_probability/python/internal:parameter_properties",
@@ -2360,6 +2360,7 @@ multi_substrate_py_library(
23602360
"//tensorflow_probability/python/bijectors:sigmoid",
23612361
"//tensorflow_probability/python/bijectors:softmax_centered",
23622362
"//tensorflow_probability/python/bijectors:square",
2363+
"//tensorflow_probability/python/experimental/tangent_spaces:spherical",
23632364
"//tensorflow_probability/python/internal:assert_util",
23642365
"//tensorflow_probability/python/internal:dtype_util",
23652366
"//tensorflow_probability/python/internal:reparameterization",
@@ -2450,6 +2451,7 @@ multi_substrate_py_library(
24502451
"//tensorflow_probability/python/bijectors:softmax_centered",
24512452
"//tensorflow_probability/python/bijectors:softplus",
24522453
"//tensorflow_probability/python/bijectors:square",
2454+
"//tensorflow_probability/python/experimental/tangent_spaces:spherical",
24532455
"//tensorflow_probability/python/internal:assert_util",
24542456
"//tensorflow_probability/python/internal:dtype_util",
24552457
"//tensorflow_probability/python/internal:parameter_properties",

tensorflow_probability/python/distributions/dirichlet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,8 @@ def _default_event_space_bijector(self):
316316
validate_args=self.validate_args)
317317

318318
def _experimental_tangent_space(self, x):
319-
from tensorflow_probability.python.experimental.tangent_spaces import spaces # pylint:disable=g-import-not-at-top
320-
return spaces.ProbabilitySimplexSpace()
319+
from tensorflow_probability.python.experimental.tangent_spaces import simplex # pylint:disable=g-import-not-at-top
320+
return simplex.ProbabilitySimplexSpace()
321321

322322
def _sample_control_dependencies(self, x):
323323
"""Checks the validity of a sample."""

tensorflow_probability/python/distributions/power_spherical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ def _default_event_space_bijector(self):
352352
], validate_args=self.validate_args)
353353

354354
def _experimental_tangent_space(self, x):
355-
from tensorflow_probability.python.experimental.tangent_spaces import spaces # pylint:disable=g-import-not-at-top
356-
return spaces.SphericalSpace()
355+
from tensorflow_probability.python.experimental.tangent_spaces import spherical # pylint:disable=g-import-not-at-top
356+
return spherical.SphericalSpace()
357357

358358
def _parameter_control_dependencies(self, is_init):
359359
if not self.validate_args:

tensorflow_probability/python/distributions/relaxed_onehot_categorical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,6 @@ def _default_event_space_bijector(self):
563563
validate_args=self.validate_args)
564564

565565
def _experimental_tangent_space(self, x):
566-
from tensorflow_probability.python.experimental.tangent_spaces import spaces # pylint:disable=g-import-not-at-top
567-
return spaces.ProbabilitySimplexSpace()
566+
from tensorflow_probability.python.experimental.tangent_spaces import simplex # pylint:disable=g-import-not-at-top
567+
return simplex.ProbabilitySimplexSpace()
568568

tensorflow_probability/python/distributions/spherical_uniform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ def _default_event_space_bijector(self):
199199
], validate_args=self.validate_args)
200200

201201
def _experimental_tangent_space(self, x):
202-
from tensorflow_probability.python.experimental.tangent_spaces import spaces # pylint:disable=g-import-not-at-top
203-
return spaces.SphericalSpace()
202+
from tensorflow_probability.python.experimental.tangent_spaces import spherical # pylint:disable=g-import-not-at-top
203+
return spherical.SphericalSpace()
204204

205205
def _sample_control_dependencies(self, samples):
206206
inner_sample_dim = samples.shape[-1]

tensorflow_probability/python/distributions/von_mises_fisher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,8 @@ def _default_event_space_bijector(self):
495495
], validate_args=self.validate_args)
496496

497497
def _experimental_tangent_space(self, x):
498-
from tensorflow_probability.python.experimental.tangent_spaces import spaces # pylint:disable=g-import-not-at-top
499-
return spaces.SphericalSpace()
498+
from tensorflow_probability.python.experimental.tangent_spaces import spherical # pylint:disable=g-import-not-at-top
499+
return spherical.SphericalSpace()
500500

501501
def _parameter_control_dependencies(self, is_init):
502502
if not self.validate_args:

tensorflow_probability/python/experimental/tangent_spaces/BUILD

Lines changed: 123 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,25 @@ multi_substrate_py_library(
3232
name = "tangent_spaces",
3333
srcs = ["__init__.py"],
3434
deps = [
35+
":simplex",
3536
":spaces",
37+
":spherical",
38+
":symmetric_matrix",
39+
],
40+
)
41+
42+
multi_substrate_py_library(
43+
name = "spaces_test_util",
44+
testonly = True,
45+
srcs = ["spaces_test_util.py"],
46+
deps = [
47+
":spaces",
48+
# numpy dep,
49+
# tensorflow dep,
50+
"//tensorflow_probability/python/bijectors:identity",
51+
"//tensorflow_probability/python/internal:tensorshape_util",
52+
"//tensorflow_probability/python/internal:test_util",
53+
"//tensorflow_probability/python/math:gradient",
3654
],
3755
)
3856

@@ -42,9 +60,7 @@ multi_substrate_py_library(
4260
deps = [
4361
# numpy dep,
4462
# tensorflow dep,
45-
"//tensorflow_probability/python/experimental/linalg:linear_operator_row_block",
4663
"//tensorflow_probability/python/internal:distribution_util",
47-
"//tensorflow_probability/python/internal:dtype_util",
4864
"//tensorflow_probability/python/internal:nest_util",
4965
"//tensorflow_probability/python/internal:prefer_static",
5066
"//tensorflow_probability/python/internal:tensor_util",
@@ -56,16 +72,16 @@ multi_substrate_py_test(
5672
name = "spaces_test",
5773
size = "medium",
5874
srcs = ["spaces_test.py"],
59-
shard_count = 10,
75+
shard_count = 3,
6076
tags = [
6177
"tf1-broken",
6278
],
6379
deps = [
6480
":spaces",
81+
":spaces_test_util",
6582
# numpy dep,
6683
# tensorflow dep,
6784
"//tensorflow_probability/python/bijectors:exp",
68-
"//tensorflow_probability/python/bijectors:identity",
6985
"//tensorflow_probability/python/bijectors:reshape",
7086
"//tensorflow_probability/python/bijectors:scale",
7187
"//tensorflow_probability/python/bijectors:scale_matvec_tril",
@@ -74,3 +90,106 @@ multi_substrate_py_test(
7490
"//tensorflow_probability/python/internal:test_util",
7591
],
7692
)
93+
94+
multi_substrate_py_library(
95+
name = "simplex",
96+
srcs = ["simplex.py"],
97+
deps = [
98+
":spaces",
99+
# numpy dep,
100+
# tensorflow dep,
101+
"//tensorflow_probability/python/experimental/linalg:linear_operator_row_block",
102+
"//tensorflow_probability/python/internal:distribution_util",
103+
"//tensorflow_probability/python/internal:dtype_util",
104+
"//tensorflow_probability/python/internal:nest_util",
105+
"//tensorflow_probability/python/internal:prefer_static",
106+
],
107+
)
108+
109+
multi_substrate_py_test(
110+
name = "simplex_test",
111+
size = "medium",
112+
srcs = ["simplex_test.py"],
113+
shard_count = 3,
114+
tags = [
115+
"tf1-broken",
116+
],
117+
deps = [
118+
":simplex",
119+
":spaces_test_util",
120+
# numpy dep,
121+
# scipy dep,
122+
# tensorflow dep,
123+
"//tensorflow_probability/python/bijectors:exp",
124+
"//tensorflow_probability/python/bijectors:identity",
125+
"//tensorflow_probability/python/bijectors:scale",
126+
"//tensorflow_probability/python/bijectors:scale_matvec_tril",
127+
"//tensorflow_probability/python/internal:test_util",
128+
],
129+
)
130+
131+
multi_substrate_py_library(
132+
name = "spherical",
133+
srcs = ["spherical.py"],
134+
deps = [
135+
":spaces",
136+
# tensorflow dep,
137+
"//tensorflow_probability/python/experimental/linalg:linear_operator_row_block",
138+
"//tensorflow_probability/python/internal:distribution_util",
139+
"//tensorflow_probability/python/internal:prefer_static",
140+
],
141+
)
142+
143+
multi_substrate_py_test(
144+
name = "spherical_test",
145+
size = "medium",
146+
srcs = ["spherical_test.py"],
147+
shard_count = 3,
148+
tags = [
149+
"tf1-broken",
150+
],
151+
deps = [
152+
":spaces_test_util",
153+
":spherical",
154+
# numpy dep,
155+
# tensorflow dep,
156+
"//tensorflow_probability/python/bijectors:exp",
157+
"//tensorflow_probability/python/bijectors:scale",
158+
"//tensorflow_probability/python/bijectors:scale_matvec_tril",
159+
"//tensorflow_probability/python/internal:test_util",
160+
],
161+
)
162+
163+
multi_substrate_py_library(
164+
name = "symmetric_matrix",
165+
srcs = ["symmetric_matrix.py"],
166+
deps = [
167+
":spaces",
168+
# numpy dep,
169+
# tensorflow dep,
170+
"//tensorflow_probability/python/internal:dtype_util",
171+
"//tensorflow_probability/python/internal:prefer_static",
172+
"//tensorflow_probability/python/math:linalg",
173+
],
174+
)
175+
176+
multi_substrate_py_test(
177+
name = "symmetric_matrix_test",
178+
size = "medium",
179+
srcs = ["symmetric_matrix_test.py"],
180+
shard_count = 3,
181+
tags = [
182+
"tf1-broken",
183+
],
184+
deps = [
185+
":spaces_test_util",
186+
":symmetric_matrix",
187+
# numpy dep,
188+
# tensorflow dep,
189+
"//tensorflow_probability/python/bijectors:exp",
190+
"//tensorflow_probability/python/bijectors:reshape",
191+
"//tensorflow_probability/python/bijectors:scale",
192+
"//tensorflow_probability/python/internal:test_util",
193+
"//tensorflow_probability/python/math:linalg",
194+
],
195+
)

tensorflow_probability/python/experimental/tangent_spaces/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,25 @@
1414
# ============================================================================
1515
"""TensorFlow Probability experimental tangent spaces package."""
1616

17+
from tensorflow_probability.python.experimental.tangent_spaces.simplex import ProbabilitySimplexSpace
1718
from tensorflow_probability.python.experimental.tangent_spaces.spaces import AxisAlignedSpace
1819
from tensorflow_probability.python.experimental.tangent_spaces.spaces import FullSpace
1920
from tensorflow_probability.python.experimental.tangent_spaces.spaces import GeneralSpace
20-
from tensorflow_probability.python.experimental.tangent_spaces.spaces import ProbabilitySimplexSpace
21-
from tensorflow_probability.python.experimental.tangent_spaces.spaces import SphericalSpace
2221
from tensorflow_probability.python.experimental.tangent_spaces.spaces import TangentSpace
2322
from tensorflow_probability.python.experimental.tangent_spaces.spaces import UnspecifiedTangentSpaceError
2423
from tensorflow_probability.python.experimental.tangent_spaces.spaces import ZeroSpace
24+
from tensorflow_probability.python.experimental.tangent_spaces.spherical import SphericalSpace
25+
from tensorflow_probability.python.experimental.tangent_spaces.symmetric_matrix import ConstantDiagonalSymmetricMatrixSpace
26+
from tensorflow_probability.python.experimental.tangent_spaces.symmetric_matrix import SymmetricMatrixSpace
2527

2628
__all__ = [
2729
'AxisAlignedSpace',
30+
'ConstantDiagonalSymmetricMatrixSpace',
2831
'FullSpace',
2932
'GeneralSpace',
3033
'ProbabilitySimplexSpace',
3134
'SphericalSpace',
35+
'SymmetricMatrixSpace',
3236
'TangentSpace',
3337
'UnspecifiedTangentSpaceError',
3438
'ZeroSpace',
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2023 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Tangent Spaces related to simplices."""
17+
18+
import numpy as np
19+
import tensorflow.compat.v2 as tf
20+
from tensorflow_probability.python.experimental.linalg import linear_operator_row_block as lorb
21+
from tensorflow_probability.python.experimental.tangent_spaces import spaces
22+
from tensorflow_probability.python.internal import distribution_util
23+
from tensorflow_probability.python.internal import dtype_util
24+
from tensorflow_probability.python.internal import prefer_static as ps
25+
26+
27+
class ProbabilitySimplexSpace(spaces.TangentSpace):
28+
"""Tangent space of M for Simplex distributions in R^n."""
29+
30+
def compute_basis(self, x):
31+
"""Returns a `TangentSpace` of a n-simplex."""
32+
# The tangent space of the simplex satisfies `{x | <1, x> = 0}`, where `1`
33+
# is the vector of all `1`s. This can be seen by the fact that `1` is
34+
# orthogonal to the unit simplex.
35+
# We can do this by using the basis: e_i - e_n, 1 <= i <= n - 1. For n = 4,
36+
# this looks like:
37+
# [[1, 0., 0., -1],
38+
# [0, 1., 0., -1],
39+
# [0, 0., 1., -1]]
40+
dim = ps.shape(x)[-1]
41+
block1 = tf.linalg.LinearOperatorIdentity(num_rows=dim - 1, dtype=x.dtype)
42+
block2 = tf.linalg.LinearOperatorFullMatrix(
43+
-tf.ones([dim - 1, 1], dtype=x.dtype))
44+
simplex_basis_linop = lorb.LinearOperatorRowBlock([block1, block2])
45+
return spaces.LinearOperatorBasis(simplex_basis_linop)
46+
47+
def _transform_general(self, x, f, **kwargs):
48+
basis = self.compute_basis(x)
49+
# Note that B @ B.T results in the matrix I + 11^T, where 1 is the vector of
50+
# all ones. By the matrix determinant lemma we have det(I + 11^T) = n + 1,
51+
# or the dimension of the ambient space.
52+
dim = ps.shape(x)[-1]
53+
result = dtype_util.as_numpy_dtype(x.dtype)(0.5 * np.log(dim))
54+
new_basis_tensor = spaces.compute_new_basis_tensor(f, x, basis)
55+
new_log_volume = spaces.volume_coefficient(
56+
distribution_util.move_dimension(new_basis_tensor, 0, -2))
57+
result = new_log_volume - result
58+
return result, spaces.GeneralSpace(
59+
spaces.DenseBasis(new_basis_tensor), computed_log_volume=new_log_volume)
60+
61+
def _transform_coordinatewise(self, x, f, **kwargs):
62+
# Compute the diagonal. New matrix is Linop that we can easily write.
63+
dim = ps.shape(x)[-1]
64+
diag_jacobian = spaces.coordinatewise_jvp(f, x)
65+
# Multiplying the basis written in block form as [I, 1] by the diagonal
66+
# results in this operator:
67+
block1 = tf.linalg.LinearOperatorDiag(diag_jacobian[..., :-1])
68+
block2 = tf.linalg.LinearOperatorFullMatrix(
69+
diag_jacobian[..., -1:, tf.newaxis] *
70+
tf.ones([dim - 1, 1], dtype=x.dtype))
71+
linop = lorb.LinearOperatorRowBlock([block1, block2])
72+
73+
# The volume can be calculated again by the matrix determinant lemma:
74+
# det(D**2 + d_n**2 11^T) = (1 + d_n**2 1(D^-1)**21^T) * det(D**2)
75+
# = (\sum d_i**-2) * \prod d_i**2
76+
log_diag_jacobian = tf.math.log(tf.math.abs(diag_jacobian))
77+
log_volume = tf.math.reduce_sum(log_diag_jacobian, axis=-1)
78+
log_volume = log_volume + 0.5 * tf.math.reduce_logsumexp(
79+
-2. * log_diag_jacobian, axis=-1) - 0.5 * np.log(dim)
80+
return log_volume, spaces.GeneralSpace(
81+
spaces.LinearOperatorBasis(linop), computed_log_volume=log_volume)
82+
83+

0 commit comments

Comments
 (0)