Skip to content

Commit 24301c5

Browse files
srvasudetensorflower-gardener
authored andcommitted
Change tfp/experimental to use internal imports to speed up tests.
- Updates a few tests to avoid importing `distributions` or `bijectors` wholesale. PiperOrigin-RevId: 472567989
1 parent c03f005 commit 24301c5

File tree

114 files changed

+4733
-3633
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

114 files changed

+4733
-3633
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ multi_substrate_py_library(
766766
"//tensorflow_probability/python/internal:distribution_util",
767767
"//tensorflow_probability/python/internal:dtype_util",
768768
"//tensorflow_probability/python/internal:tensorshape_util",
769-
"//tensorflow_probability/python/math/psd_kernels",
769+
"//tensorflow_probability/python/math/psd_kernels:schur_complement",
770770
"//tensorflow_probability/python/util",
771771
],
772772
)
@@ -1025,6 +1025,8 @@ multi_substrate_py_library(
10251025
srcs = ["inverse_gaussian.py"],
10261026
deps = [
10271027
":distribution",
1028+
":inflated",
1029+
":negative_binomial",
10281030
":normal",
10291031
# numpy dep,
10301032
# tensorflow dep,
@@ -1145,6 +1147,8 @@ multi_substrate_py_library(
11451147
name = "joint_distribution_util",
11461148
srcs = ["joint_distribution_util.py"],
11471149
deps = [
1150+
":independent",
1151+
":joint_distribution_auto_batched",
11481152
":joint_distribution_named",
11491153
":joint_distribution_sequential",
11501154
# tensorflow dep,
@@ -1420,7 +1424,6 @@ multi_substrate_py_library(
14201424
"//tensorflow_probability/python/internal:parameter_properties",
14211425
"//tensorflow_probability/python/internal:reparameterization",
14221426
"//tensorflow_probability/python/internal:tensor_util",
1423-
"//tensorflow_probability/python/math",
14241427
"//tensorflow_probability/python/util:seed_stream",
14251428
],
14261429
)
@@ -2156,6 +2159,8 @@ multi_substrate_py_library(
21562159
"//tensorflow_probability/python/internal:reparameterization",
21572160
"//tensorflow_probability/python/internal:tensor_util",
21582161
"//tensorflow_probability/python/internal:tensorshape_util",
2162+
"//tensorflow_probability/python/math/psd_kernels:positive_semidefinite_kernel",
2163+
"//tensorflow_probability/python/math/psd_kernels:schur_complement",
21592164
],
21602165
)
21612166

@@ -2470,7 +2475,7 @@ multi_substrate_py_test(
24702475
"//tensorflow_probability/python/bijectors:scale_matvec_tril",
24712476
"//tensorflow_probability/python/internal:reparameterization",
24722477
"//tensorflow_probability/python/internal:test_util",
2473-
"//tensorflow_probability/python/math",
2478+
"//tensorflow_probability/python/math:linalg",
24742479
],
24752480
)
24762481

@@ -3198,8 +3203,12 @@ multi_substrate_py_test(
31983203
name = "inflated_test",
31993204
srcs = ["inflated_test.py"],
32003205
deps = [
3206+
":inflated",
3207+
":negative_binomial",
3208+
":normal",
32013209
# numpy dep,
3202-
"//tensorflow_probability",
3210+
# tensorflow dep,
3211+
"//tensorflow_probability/python/experimental/util",
32033212
"//tensorflow_probability/python/internal:test_util",
32043213
],
32053214
)
@@ -3230,7 +3239,7 @@ multi_substrate_py_test(
32303239
# scipy dep,
32313240
# tensorflow dep,
32323241
"//tensorflow_probability/python/internal:test_util",
3233-
"//tensorflow_probability/python/math",
3242+
"//tensorflow_probability/python/math:gradient",
32343243
],
32353244
)
32363245

tensorflow_probability/python/distributions/autoregressive_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import numpy as np
2222
import tensorflow.compat.v2 as tf
23-
from tensorflow_probability.python import math
2423
from tensorflow_probability.python.bijectors import masked_autoregressive
2524
from tensorflow_probability.python.bijectors import scale_matvec_tril
2625
from tensorflow_probability.python.distributions import autoregressive
@@ -32,6 +31,7 @@
3231
from tensorflow_probability.python.distributions import transformed_distribution
3332
from tensorflow_probability.python.internal import reparameterization
3433
from tensorflow_probability.python.internal import test_util
34+
from tensorflow_probability.python.math import linalg
3535

3636

3737
@test_util.test_all_tf_execution_regimes
@@ -46,7 +46,7 @@ def setUp(self):
4646
def _random_scale_tril(self, event_size):
4747
n = np.int32(event_size * (event_size + 1) // 2)
4848
p = 2. * self._rng.random_sample(n).astype(np.float32) - 1.
49-
return math.fill_triangular(0.25 * p)
49+
return linalg.fill_triangular(0.25 * p)
5050

5151
def _normal_fn(self, affine_bijector):
5252
def _fn(samples):

tensorflow_probability/python/distributions/gaussian_process_regression_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from tensorflow_probability.python.internal import parameter_properties
2626
from tensorflow_probability.python.internal import tensor_util
2727
from tensorflow_probability.python.internal import tensorshape_util
28-
from tensorflow_probability.python.math import psd_kernels as tfpk
28+
from tensorflow_probability.python.math.psd_kernels import schur_complement
2929
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
3030

3131

@@ -534,7 +534,7 @@ def __init__(self,
534534

535535
with tf.name_scope('init'):
536536
if _conditional_kernel is None:
537-
_conditional_kernel = tfpk.SchurComplement(
537+
_conditional_kernel = schur_complement.SchurComplement(
538538
base_kernel=kernel,
539539
fixed_inputs=observation_index_points,
540540
cholesky_fn=cholesky_fn,
@@ -749,7 +749,7 @@ def precompute_regression_model(
749749
if cholesky_fn is None:
750750
cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn(jitter)
751751

752-
conditional_kernel = tfpk.SchurComplement.with_precomputed_divisor(
752+
conditional_kernel = schur_complement.SchurComplement.with_precomputed_divisor(
753753
base_kernel=kernel,
754754
fixed_inputs=observation_index_points,
755755
fixed_inputs_is_missing=observations_is_missing,

tensorflow_probability/python/distributions/inflated_test.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,42 @@
1616
import numpy as np
1717
import tensorflow.compat.v2 as tf
1818

19-
import tensorflow_probability as tfp
20-
21-
from tensorflow_probability.python import distributions as tfd
22-
from tensorflow_probability.python import experimental as tfe
19+
from tensorflow_probability.python.distributions import inflated
20+
from tensorflow_probability.python.distributions import negative_binomial
21+
from tensorflow_probability.python.distributions import normal
22+
from tensorflow_probability.python.experimental import util
23+
from tensorflow_probability.python.experimental.util import trainable
2324
from tensorflow_probability.python.internal import test_util
2425

25-
tfe_util = tfp.experimental.util
26-
2726

2827
class DistributionsTest(test_util.TestCase):
2928

3029
def test_inflated(self):
31-
zinb = tfd.Inflated(
32-
tfd.NegativeBinomial(5.0, probs=0.1), inflated_loc_probs=0.2)
30+
zinb = inflated.Inflated(
31+
negative_binomial.NegativeBinomial(5.0, probs=0.1),
32+
inflated_loc_probs=0.2)
3333
samples = zinb.sample(sample_shape=10, seed=test_util.test_seed())
3434
self.assertEqual((10,), samples.shape)
3535

36-
spike_and_slab = tfd.Inflated(
37-
tfd.Normal(loc=1.0, scale=2.0), inflated_loc_probs=0.5)
36+
spike_and_slab = inflated.Inflated(
37+
normal.Normal(loc=1.0, scale=2.0), inflated_loc_probs=0.5)
3838
lprob = self.evaluate(spike_and_slab.log_prob(99.0))
3939
self.assertLess(lprob, 0.0)
4040

41-
param_props = tfd.Inflated.parameter_properties(dtype=tf.float32)
41+
param_props = inflated.Inflated.parameter_properties(dtype=tf.float32)
4242
self.assertFalse(param_props['distribution'].is_tensor)
4343
self.assertTrue(param_props['inflated_loc_logits'].is_preferred)
4444
self.assertFalse(param_props['inflated_loc_probs'].is_preferred)
4545
self.assertTrue(param_props['inflated_loc'].is_tensor)
4646

4747
def test_inflated_batched(self):
48-
nb = tfd.NegativeBinomial(
48+
nb = negative_binomial.NegativeBinomial(
4949
total_count=np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32),
5050
logits=np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32))
51-
zinb = tfd.Inflated(
52-
nb, inflated_loc_probs=np.array(
53-
[0.2, 0.4, 0.6, 0.8, 1.0], dtype=np.float32))
51+
zinb = inflated.Inflated(
52+
nb,
53+
inflated_loc_probs=np.array([0.2, 0.4, 0.6, 0.8, 1.0],
54+
dtype=np.float32))
5455

5556
lprob = zinb.log_prob([0, 1, 2, 3, 4])
5657
self.assertEqual((5,), lprob.shape)
@@ -59,24 +60,24 @@ def test_inflated_batched(self):
5960
self.assertEqual((5,), samples.shape)
6061

6162
def test_inflated_factory(self):
62-
spike_and_slab_class = tfe.distributions.inflated_factory(
63-
'SpikeAndSlab', tfd.Normal, 0.0)
63+
spike_and_slab_class = inflated.inflated_factory('SpikeAndSlab',
64+
normal.Normal, 0.0)
6465
spike_and_slab = spike_and_slab_class(
6566
inflated_loc_probs=0.3, loc=5.0, scale=2.0)
66-
spike_and_slab2 = tfd.Inflated(
67-
tfd.Normal(loc=5.0, scale=2.0), inflated_loc_probs=0.3)
67+
spike_and_slab2 = inflated.Inflated(
68+
normal.Normal(loc=5.0, scale=2.0), inflated_loc_probs=0.3)
6869
self.assertEqual(
6970
self.evaluate(spike_and_slab.log_prob(7.0)),
7071
self.evaluate(spike_and_slab2.log_prob(7.0)))
7172

7273
def test_zero_inflated_negative_binomial(self):
73-
zinb = tfd.ZeroInflatedNegativeBinomial(
74+
zinb = inflated.ZeroInflatedNegativeBinomial(
7475
inflated_loc_probs=0.2, probs=0.5, total_count=10.0)
7576
self.assertEqual('ZeroInflatedNegativeBinomial', zinb.name)
7677

7778
def test_zinb_is_trainable(self):
78-
init_fn, apply_fn = tfe_util.make_trainable_stateless(
79-
tfd.ZeroInflatedNegativeBinomial,
79+
init_fn, apply_fn = trainable.make_trainable_stateless(
80+
inflated.ZeroInflatedNegativeBinomial,
8081
batch_and_event_shape=[5],
8182
parameter_dtype=tf.float32)
8283
init_obj = init_fn(seed=test_util.test_seed())
@@ -96,8 +97,9 @@ def test_zinb_is_trainable(self):
9697
disable_numpy=True,
9798
reason='Only TF has composite tensors')
9899
def test_zinb_as_composite_tensor(self):
99-
zinb = tfd.ZeroInflatedNegativeBinomial(0.1, total_count=10.0, probs=0.4)
100-
comp_zinb = tfe.as_composite(zinb)
100+
zinb = inflated.ZeroInflatedNegativeBinomial(
101+
0.1, total_count=10.0, probs=0.4)
102+
comp_zinb = util.as_composite(zinb)
101103
unused_as_tensors = tf.nest.flatten(comp_zinb)
102104

103105

tensorflow_probability/python/distributions/inverse_gaussian_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from scipy import stats
1818
import tensorflow.compat.v1 as tf1
1919
import tensorflow.compat.v2 as tf
20-
from tensorflow_probability.python import math as tfm
2120
from tensorflow_probability.python.distributions import inverse_gaussian
2221
from tensorflow_probability.python.internal import test_util
22+
from tensorflow_probability.python.math import gradient
2323

2424

2525
def _scipy_invgauss(loc, concentration):
@@ -373,7 +373,7 @@ def testInverseGaussianSampleMultidimensionalVariance(self):
373373
def testInverseGaussianFullyReparameterized(self):
374374
concentration = tf.constant(4.0)
375375
loc = tf.constant(3.0)
376-
_, [grad_concentration, grad_loc] = tfm.value_and_gradient(
376+
_, [grad_concentration, grad_loc] = gradient.value_and_gradient(
377377
lambda a, b: inverse_gaussian.InverseGaussian(a, b, validate_args=True). # pylint: disable=g-long-lambda
378378
sample(100, seed=test_util.test_seed()),
379379
[concentration, loc])
@@ -393,7 +393,7 @@ def gen_samples(l, c):
393393
2, seed=test_util.test_seed())
394394

395395
samples, [loc_grad, concentration_grad] = self.evaluate(
396-
tfm.value_and_gradient(gen_samples, [loc, concentration]))
396+
gradient.value_and_gradient(gen_samples, [loc, concentration]))
397397
self.assertEqual(samples.shape, (2, 4, 3))
398398
self.assertEqual(concentration_grad.shape, concentration.shape)
399399
self.assertEqual(loc_grad.shape, loc.shape)

tensorflow_probability/python/distributions/student_t_process_regression_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from tensorflow_probability.python.internal import prefer_static as ps
2929
from tensorflow_probability.python.internal import tensor_util
3030
from tensorflow_probability.python.internal import tensorshape_util
31-
from tensorflow_probability.python.math import psd_kernels as tfpk
31+
from tensorflow_probability.python.math.psd_kernels import positive_semidefinite_kernel as psd_kernel
32+
from tensorflow_probability.python.math.psd_kernels import schur_complement as schur_complement_lib
3233

3334

3435
__all__ = [
@@ -102,7 +103,7 @@ def _validate_observation_data(
102103
index_point_count, observation_count))
103104

104105

105-
class DampedSchurComplement(tfpk.AutoCompositeTensorPsdKernel):
106+
class DampedSchurComplement(psd_kernel.AutoCompositeTensorPsdKernel):
106107
"""Schur complement kernel, damped by scalar factors.
107108
108109
This kernel is the same as the SchurComplement kernel, except we multiply by
@@ -398,7 +399,7 @@ def __init__(
398399
if _conditional_kernel is None:
399400
_conditional_kernel = DampedSchurComplement(
400401
df=df,
401-
schur_complement=tfpk.SchurComplement(
402+
schur_complement=schur_complement_lib.SchurComplement(
402403
base_kernel=kernel,
403404
fixed_inputs=self._observation_index_points,
404405
diag_shift=observation_noise_variance),
@@ -606,7 +607,7 @@ def precompute_regression_model(
606607

607608
conditional_kernel = DampedSchurComplement(
608609
df=df,
609-
schur_complement=tfpk.SchurComplement(
610+
schur_complement=schur_complement_lib.SchurComplement(
610611
base_kernel=kernel,
611612
fixed_inputs=observation_index_points,
612613
diag_shift=observation_noise_variance),

0 commit comments

Comments
 (0)