Skip to content

Commit c6e4fab

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
Enable AutoCompositeTensor for simple distributions.
PiperOrigin-RevId: 378226943
1 parent 0fc11b3 commit c6e4fab

File tree

88 files changed

+252
-103
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+252
-103
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ multi_substrate_py_library(
384384
srcs = ["chi.py"],
385385
deps = [
386386
":chi2",
387+
":distribution",
387388
":kullback_leibler",
388389
":transformed_distribution",
389390
# numpy dep,
@@ -610,8 +611,11 @@ multi_substrate_py_library(
610611
# tensorflow dep,
611612
"//tensorflow_probability/python/bijectors:identity",
612613
"//tensorflow_probability/python/bijectors:softplus",
614+
"//tensorflow_probability/python/internal:assert_util",
613615
"//tensorflow_probability/python/internal:dtype_util",
614616
"//tensorflow_probability/python/internal:parameter_properties",
617+
"//tensorflow_probability/python/internal:prefer_static",
618+
"//tensorflow_probability/python/internal:reparameterization",
615619
"//tensorflow_probability/python/internal:samplers",
616620
"//tensorflow_probability/python/internal:tensor_util",
617621
],
@@ -806,6 +810,7 @@ multi_substrate_py_library(
806810
name = "gumbel",
807811
srcs = ["gumbel.py"],
808812
deps = [
813+
":distribution",
809814
":kullback_leibler",
810815
":transformed_distribution",
811816
":uniform",
@@ -826,6 +831,7 @@ multi_substrate_py_library(
826831
name = "gev",
827832
srcs = ["gev.py"],
828833
deps = [
834+
":distribution",
829835
":kullback_leibler",
830836
":transformed_distribution",
831837
":uniform",
@@ -995,6 +1001,7 @@ multi_substrate_py_library(
9951001
name = "johnson_su",
9961002
srcs = ["johnson_su.py"],
9971003
deps = [
1004+
":distribution",
9981005
":normal",
9991006
":transformed_distribution",
10001007
# tensorflow dep,
@@ -1132,6 +1139,7 @@ multi_substrate_py_library(
11321139
name = "kumaraswamy",
11331140
srcs = ["kumaraswamy.py"],
11341141
deps = [
1142+
":distribution",
11351143
":transformed_distribution",
11361144
":uniform",
11371145
# numpy dep,
@@ -1278,6 +1286,7 @@ multi_substrate_py_library(
12781286
name = "lognormal",
12791287
srcs = ["lognormal.py"],
12801288
deps = [
1289+
":distribution",
12811290
":normal",
12821291
":transformed_distribution",
12831292
# numpy dep,
@@ -1294,6 +1303,7 @@ multi_substrate_py_library(
12941303
name = "logitnormal",
12951304
srcs = ["logitnormal.py"],
12961305
deps = [
1306+
":distribution",
12971307
":normal",
12981308
":transformed_distribution",
12991309
# numpy dep,
@@ -1426,6 +1436,7 @@ multi_substrate_py_library(
14261436
name = "moyal",
14271437
srcs = ["moyal.py"],
14281438
deps = [
1439+
":distribution",
14291440
":kullback_leibler",
14301441
":transformed_distribution",
14311442
":uniform",
@@ -1500,6 +1511,7 @@ multi_substrate_py_library(
15001511
name = "mvn_linear_operator",
15011512
srcs = ["mvn_linear_operator.py"],
15021513
deps = [
1514+
":distribution",
15031515
":kullback_leibler",
15041516
":normal",
15051517
":sample",
@@ -2219,6 +2231,7 @@ multi_substrate_py_library(
22192231
name = "vector_exponential_linear_operator",
22202232
srcs = ["vector_exponential_linear_operator.py"],
22212233
deps = [
2234+
":distribution",
22222235
":exponential",
22232236
":sample",
22242237
":transformed_distribution",
@@ -2227,6 +2240,7 @@ multi_substrate_py_library(
22272240
"//tensorflow_probability/python/bijectors:scale_matvec_linear_operator",
22282241
"//tensorflow_probability/python/bijectors:shift",
22292242
"//tensorflow_probability/python/bijectors:softplus",
2243+
"//tensorflow_probability/python/internal:assert_util",
22302244
"//tensorflow_probability/python/internal:distribution_util",
22312245
"//tensorflow_probability/python/internal:dtype_util",
22322246
"//tensorflow_probability/python/internal:tensorshape_util",
@@ -2287,6 +2301,7 @@ multi_substrate_py_library(
22872301
srcs = ["weibull.py"],
22882302
deps = [
22892303
":distribution",
2304+
":transformed_distribution",
22902305
# numpy dep,
22912306
# tensorflow dep,
22922307
"//tensorflow_probability/python/bijectors:chain",
@@ -3382,6 +3397,7 @@ multi_substrate_py_test(
33823397
name = "mvn_diag_plus_low_rank_covariance_test",
33833398
srcs = ["mvn_diag_plus_low_rank_covariance_test.py"],
33843399
deps = [
3400+
":distribution",
33853401
":mvn_diag_plus_low_rank_covariance",
33863402
# numpy dep,
33873403
# tensorflow dep,

tensorflow_probability/python/distributions/bates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
}
5959

6060

61-
class Bates(distribution.Distribution):
61+
class Bates(distribution.AutoCompositeTensorDistribution):
6262
"""Bates distribution.
6363
6464
The Bates distribution is the distribution of the average of `total_count`

tensorflow_probability/python/distributions/bernoulli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from tensorflow_probability.python.internal import tensor_util
3333

3434

35-
class Bernoulli(distribution.Distribution):
35+
class Bernoulli(distribution.AutoCompositeTensorDistribution):
3636
"""Bernoulli distribution.
3737
3838
The Bernoulli distribution with `probs` parameter, i.e., the probability of a

tensorflow_probability/python/distributions/beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
4949

5050

51-
class Beta(distribution.Distribution):
51+
class Beta(distribution.AutoCompositeTensorDistribution):
5252
"""Beta distribution.
5353
5454
The Beta distribution is defined over the `(0, 1)` interval using parameters

tensorflow_probability/python/distributions/beta_binomial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"""
5151

5252

53-
class BetaBinomial(distribution.Distribution):
53+
class BetaBinomial(distribution.AutoCompositeTensorDistribution):
5454
"""Beta-Binomial compound distribution.
5555
5656
The Beta-Binomial distribution is parameterized by (a batch of) `total_count`

tensorflow_probability/python/distributions/beta_quotient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
]
4242

4343

44-
class BetaQuotient(distribution.Distribution):
44+
class BetaQuotient(distribution.AutoCompositeTensorDistribution):
4545
"""BetaQuotient distribution.
4646
4747
The Beta Quotient distribution is defined over the positive reals, as

tensorflow_probability/python/distributions/binomial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def _random_binomial(
266266
return sampler_impl(**params)
267267

268268

269-
class Binomial(distribution.Distribution):
269+
class Binomial(distribution.AutoCompositeTensorDistribution):
270270
"""Binomial distribution.
271271
272272
This distribution is parameterized by `probs`, a (batch of) probabilities for

tensorflow_probability/python/distributions/categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _broadcast_cat_event_and_params(event, params, base_dtype):
6262
return event, params
6363

6464

65-
class Categorical(distribution.Distribution):
65+
class Categorical(distribution.AutoCompositeTensorDistribution):
6666
"""Categorical distribution over integers.
6767
6868
The Categorical distribution is parameterized by either probabilities or

tensorflow_probability/python/distributions/cauchy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
]
4141

4242

43-
class Cauchy(distribution.Distribution):
43+
class Cauchy(distribution.AutoCompositeTensorDistribution):
4444
"""The Cauchy distribution with location `loc` and scale `scale`.
4545
4646
#### Mathematical details

tensorflow_probability/python/distributions/chi.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
2727
from tensorflow_probability.python.bijectors import square as square_bijector
2828
from tensorflow_probability.python.distributions import chi2
29+
from tensorflow_probability.python.distributions import distribution
2930
from tensorflow_probability.python.distributions import kullback_leibler
3031
from tensorflow_probability.python.distributions import transformed_distribution
3132
from tensorflow_probability.python.internal import assert_util
@@ -34,7 +35,10 @@
3435
from tensorflow_probability.python.internal import tensor_util
3536

3637

37-
class Chi(transformed_distribution.TransformedDistribution):
38+
# TODO(b/182603117): Remove `AutoCompositeTensor` subclass when
39+
# `TransformedDistribution` is converted to `CompositeTensor`.
40+
class Chi(transformed_distribution.TransformedDistribution,
41+
distribution.AutoCompositeTensorDistribution):
3842
"""Chi distribution.
3943
4044
The Chi distribution is defined over nonnegative real numbers and uses a

0 commit comments

Comments
 (0)