Skip to content

Commit a836d66

Browse files
Merge branch 'tensorflow:main' into betaincgrad
2 parents 885cdca + 35119de commit a836d66

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+582
-159
lines changed

spinoffs/oryx/oryx/core/trace_util.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import threading
1818
from typing import Any, Dict, Generator, List
1919

20-
import jax
2120
from jax import abstract_arrays
2221
from jax import api_util
2322
from jax import core as jax_core
@@ -63,8 +62,6 @@ def wrapped(*args, **kwargs):
6362
flat_args, in_tree = tree_util.tree_flatten(args)
6463
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
6564
flat_avals = safe_map(get_shaped_aval, flat_args)
66-
if not jax.config.omnistaging_enabled:
67-
raise ValueError('Oryx must be used with JAX omnistaging enabled.')
6865
if dynamic:
6966
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
7067
flat_fun,

tensorflow_probability/python/bijectors/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,6 @@ multi_substrate_py_test(
12541254
size = "medium",
12551255
srcs = ["correlation_cholesky_test.py"],
12561256
jax_size = "large",
1257-
numpy_tags = ["notap"],
12581257
shard_count = 2,
12591258
deps = [
12601259
":bijector_test_util",
@@ -1636,7 +1635,6 @@ multi_substrate_py_test(
16361635
name = "permute_test",
16371636
size = "small",
16381637
srcs = ["permute_test.py"],
1639-
numpy_tags = ["notap"],
16401638
deps = [
16411639
":bijector_test_util",
16421640
":bijectors",
@@ -1705,7 +1703,6 @@ multi_substrate_py_test(
17051703
name = "rational_quadratic_spline_test",
17061704
size = "medium",
17071705
srcs = ["rational_quadratic_spline_test.py"],
1708-
numpy_tags = ["notap"],
17091706
tags = ["hypothesis"],
17101707
deps = [
17111708
":bijector_test_util",

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from tensorflow_probability.python.internal import tensor_util
3333
from tensorflow_probability.python.internal import tensorshape_util
3434
from tensorflow_probability.python.internal import test_util
35-
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
3635

3736
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
3837

@@ -789,13 +788,6 @@ def testCompositeTensor(self, bijector_name, data):
789788
'bijectors.')
790789
self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')
791790

792-
if not tf.executing_eagerly():
793-
bijector = tf.nest.map_structure(
794-
lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda
795-
if isinstance(x, DeferredTensor) else x),
796-
bijector,
797-
expand_composites=True)
798-
799791
self.assertIsInstance(bijector, tf.__internal__.CompositeTensor)
800792
flat = tf.nest.flatten(bijector, expand_composites=True)
801793
unflat = tf.nest.pack_sequence_as(bijector, flat, expand_composites=True)
@@ -833,6 +825,8 @@ def testCompositeTensor(self, bijector_name, data):
833825
grads = tape.gradient(ys, wrt_vars)
834826
assert_no_none_grad(bijector, 'forward', wrt_vars, grads)
835827

828+
self.assertConvertVariablesToTensorsWorks(bijector)
829+
836830

837831
def ensure_nonzero(x):
838832
return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x)

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,8 @@ def testInstanceCache(self):
904904
z = instance_cache_bijector.forward(x)
905905
self.assertIsNot(y, z)
906906

907-
@test_util.jax_disable_test_missing_functionality('keras')
907+
@test_util.disable_test_for_backend(
908+
disable_numpy=True, disable_jax=True, reason='keras')
908909
@parameterized.named_parameters(
909910
('Keras', True),
910911
('NoKeras', False))
@@ -1093,13 +1094,15 @@ def test_caches(self):
10931094
@test_util.test_all_tf_execution_regimes
10941095
class TfModuleTest(test_util.TestCase):
10951096

1097+
@test_util.numpy_disable_variable_test
10961098
@test_util.jax_disable_variable_test
10971099
def test_variable_tracking(self):
10981100
x = tf.Variable(1.)
10991101
b = ForwardOnlyBijector(scale=x, validate_args=True)
11001102
self.assertIsInstance(b, tf.Module)
11011103
self.assertEqual((x,), b.trainable_variables)
11021104

1105+
@test_util.numpy_disable_variable_test
11031106
@test_util.jax_disable_variable_test
11041107
def test_gradient(self):
11051108
x = tf.Variable(1.)

tensorflow_probability/python/bijectors/correlation_cholesky_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def sample_mcmc_chain():
227227
cdf=beta_dist.cdf,
228228
false_fail_rate=1e-9))
229229

230+
@test_util.numpy_disable_gradient_test
230231
def testTheoreticalFldj(self):
231232
bijector = tfb.CorrelationCholesky()
232233
x = np.linspace(-50, 50, num=30).reshape(5, 6).astype(np.float64)

tensorflow_probability/python/bijectors/permute_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ def testBijectiveAndFiniteAxis(self):
8383
bijector, x, y, eval_func=self.evaluate, event_ndims=2, rtol=1e-6,
8484
atol=0)
8585

86-
@test_util.jax_disable_test_missing_functionality(
87-
'Test specific to Keras with losing shape information.')
86+
@test_util.disable_test_for_backend(
87+
disable_numpy=True, disable_jax=True,
88+
reason='Test specific to Keras with losing shape information.')
8889
def testPreservesShape(self):
8990
# TODO(b/131157549, b/131124359): Test should not be needed. Consider
9091
# deleting when underlying issue with constant eager tensors is fixed.

tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def testDegenerateSplines(self):
144144
np.zeros_like(xs),
145145
self.evaluate(bijector.forward_log_det_jacobian(xs, event_ndims=0)))
146146

147+
@test_util.numpy_disable_gradient_test
147148
def testTheoreticalFldjSimple(self):
148149
bijector = tfb.RationalQuadraticSpline(
149150
bin_widths=[1., 1],

tensorflow_probability/python/distributions/BUILD

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ multi_substrate_py_library(
3636
srcs = ["__init__.py"],
3737
substrates_omit_deps = [
3838
":pixel_cnn",
39-
":mixture",
4039
],
4140
deps = [
4241
":autoregressive",
@@ -2669,7 +2668,6 @@ multi_substrate_py_test(
26692668
size = "medium",
26702669
srcs = ["distribution_test.py"],
26712670
jax_size = "medium",
2672-
numpy_tags = ["notap"],
26732671
deps = [
26742672
# numpy dep,
26752673
# tensorflow dep,
@@ -2919,7 +2917,6 @@ multi_substrate_py_test(
29192917
multi_substrate_py_test(
29202918
name = "half_cauchy_test",
29212919
srcs = ["half_cauchy_test.py"],
2922-
numpy_tags = ["notap"],
29232920
shard_count = 5,
29242921
deps = [
29252922
# numpy dep,
@@ -2933,7 +2930,6 @@ multi_substrate_py_test(
29332930
multi_substrate_py_test(
29342931
name = "half_normal_test",
29352932
srcs = ["half_normal_test.py"],
2936-
numpy_tags = ["notap"],
29372933
deps = [
29382934
# numpy dep,
29392935
# scipy dep,
@@ -3158,7 +3154,6 @@ multi_substrate_py_test(
31583154
multi_substrate_py_test(
31593155
name = "lambertw_f_test",
31603156
srcs = ["lambertw_f_test.py"],
3161-
numpy_tags = ["notap"],
31623157
deps = [
31633158
":lambertw_f",
31643159
# absl/testing:parameterized dep,
@@ -3229,7 +3224,6 @@ multi_substrate_py_test(
32293224
multi_substrate_py_test(
32303225
name = "loglogistic_test",
32313226
srcs = ["loglogistic_test.py"],
3232-
numpy_tags = ["notap"],
32333227
deps = [
32343228
# numpy dep,
32353229
# scipy dep,
@@ -3322,7 +3316,6 @@ multi_substrate_py_test(
33223316
name = "mixture_test",
33233317
size = "medium",
33243318
srcs = ["mixture_test.py"],
3325-
jax_tags = ["notap"],
33263319
numpy_tags = ["notap"],
33273320
deps = [
33283321
# hypothesis dep,
@@ -3341,7 +3334,6 @@ multi_substrate_py_test(
33413334
size = "medium",
33423335
srcs = ["mixture_same_family_test.py"],
33433336
jax_size = "large",
3344-
numpy_tags = ["notap"],
33453337
shard_count = 3,
33463338
deps = [
33473339
# hypothesis dep,
@@ -3372,7 +3364,6 @@ multi_substrate_py_test(
33723364
size = "medium",
33733365
srcs = ["multinomial_test.py"],
33743366
jax_size = "medium",
3375-
numpy_tags = ["notap"],
33763367
shard_count = 5,
33773368
deps = [
33783369
# hypothesis dep,
@@ -3428,7 +3419,6 @@ multi_substrate_py_test(
34283419
multi_substrate_py_test(
34293420
name = "mvn_diag_test",
34303421
srcs = ["mvn_diag_test.py"],
3431-
numpy_tags = ["notap"],
34323422
deps = [
34333423
# numpy dep,
34343424
# scipy dep,
@@ -3456,7 +3446,6 @@ multi_substrate_py_test(
34563446
multi_substrate_py_test(
34573447
name = "mvn_linear_operator_test",
34583448
srcs = ["mvn_linear_operator_test.py"],
3459-
numpy_tags = ["notap"],
34603449
deps = [
34613450
# numpy dep,
34623451
# scipy dep,
@@ -3483,7 +3472,6 @@ multi_substrate_py_test(
34833472
name = "mvn_tril_test",
34843473
size = "medium",
34853474
srcs = ["mvn_tril_test.py"],
3486-
numpy_tags = ["notap"],
34873475
shard_count = 7,
34883476
tags = ["nomsan"],
34893477
deps = [

tensorflow_probability/python/distributions/distribution_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ def __init__(self, extra_param, *args, **kwargs):
787787
@test_util.test_all_tf_execution_regimes
788788
class TfModuleTest(test_util.TestCase):
789789

790+
@test_util.numpy_disable_variable_test
790791
@test_util.jax_disable_variable_test
791792
def test_variable_tracking_works(self):
792793
scale = tf.Variable(1.)

tensorflow_probability/python/distributions/gaussian_process.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from tensorflow_probability.python.internal import tensorshape_util
3939
from tensorflow_probability.python.math.psd_kernels.internal import util as psd_kernels_util
4040
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
41+
from tensorflow.python.util import variable_utils # pylint: disable=g-direct-tensorflow-import
42+
4143

4244
__all__ = [
4345
'GaussianProcess',
@@ -767,6 +769,14 @@ def _type_spec(self):
767769
omit_kwargs=('parameters', '_check_marginal_cholesky_fn'),
768770
non_identifying_kwargs=('name',))
769771

772+
def _convert_variables_to_tensors(self):
773+
# pylint: disable=protected-access
774+
components = self._type_spec._to_components(self)
775+
tensor_components = variable_utils.convert_variables_to_tensors(
776+
components)
777+
return self._type_spec._from_components(tensor_components)
778+
# pylint: enable=protected-access
779+
770780

771781
@auto_composite_tensor.type_spec_register(
772782
'tfp.distributions.GaussianProcess_ACTTypeSpec')

0 commit comments

Comments
 (0)