Commit 355b0bc

randolf-scholz authored and pytorchmergebot committed
[typing] Add type hints to @property and @lazy_property in torch.distributions. (pytorch#144110)
Fixes pytorch#76772, pytorch#144196. Extends pytorch#144106.

- Added type annotations to `lazy_property`.
- Added type annotations to all `@property` and `@lazy_property` members inside the `torch.distributions` module.
- Added a simple type-check unit test to ensure type inference is working.
- Replaced deprecated annotations like `typing.List` with the corresponding builtin counterparts.
- Simplified `torch.Tensor` hints to plain `Tensor`, since signatures otherwise become very verbose.

Pull Request resolved: pytorch#144110
Approved by: https://github.com/Skylion007
1 parent aa69d73 commit 355b0bc

45 files changed: +353 −294 lines
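For context on how these annotations make inferred attribute types visible to static checkers, below is a minimal sketch of a generically typed caching descriptor in the spirit of `torch.distributions.utils.lazy_property`. It illustrates the pattern only and is not the exact implementation from this commit:

from typing import Any, Callable, Generic, TypeVar, overload

T = TypeVar("T")


class lazy_property(Generic[T]):
    # Cache the wrapped method's result on first instance access.
    def __init__(self, wrapped: Callable[[Any], T]) -> None:
        self.wrapped = wrapped

    @overload
    def __get__(self, instance: None, obj_type: Any = None) -> "lazy_property[T]": ...

    @overload
    def __get__(self, instance: Any, obj_type: Any = None) -> T: ...

    def __get__(self, instance: Any, obj_type: Any = None) -> Any:
        if instance is None:
            # Accessed on the class: return the descriptor itself.
            return self
        value = self.wrapped(instance)
        # Non-data descriptor: caching the computed value on the
        # instance shadows this descriptor on all later lookups.
        setattr(instance, self.wrapped.__name__, value)
        return value

With `__get__` overloaded this way, a method annotated `-> Tensor` and wrapped in the descriptor is inferred as `Tensor` on instance access, which is what the new unit test asserts.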

test/typing/pass/distributions.py
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+from typing_extensions import assert_type
+
+import torch
+from torch import distributions, Tensor
+
+
+dist = distributions.Normal(0, 1)
+assert_type(dist.mean, Tensor)
+
+dist = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
+assert_type(dist.covariance_matrix, Tensor)
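`assert_type` from `typing_extensions` is a no-op at runtime; the assertion only fires when a static type checker analyzes the file. A hedged sketch of extending the same test to another distribution annotated in this commit (the `Bernoulli` lines below are illustrative, not part of the diff):

import torch
from torch import distributions, Tensor
from typing_extensions import assert_type

bern = distributions.Bernoulli(probs=torch.tensor(0.3))
assert_type(bern.probs, Tensor)  # `probs` is a lazy_property annotated -> Tensor
assert_type(bern.param_shape, torch.Size)  # `param_shape` is annotated -> torch.Size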

torch/distributions/bernoulli.py
Lines changed: 8 additions & 8 deletions

@@ -2,7 +2,7 @@
 from numbers import Number

 import torch
-from torch import nan
+from torch import nan, Tensor
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import (
@@ -76,29 +76,29 @@ def _new(self, *args, **kwargs):
         return self._param.new(*args, **kwargs)

     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         return self.probs

     @property
-    def mode(self):
+    def mode(self) -> Tensor:
         mode = (self.probs >= 0.5).to(self.probs)
         mode[self.probs == 0.5] = nan
         return mode

     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         return self.probs * (1 - self.probs)

     @lazy_property
-    def logits(self):
+    def logits(self) -> Tensor:
         return probs_to_logits(self.probs, is_binary=True)

     @lazy_property
-    def probs(self):
+    def probs(self) -> Tensor:
         return logits_to_probs(self.logits, is_binary=True)

     @property
-    def param_shape(self):
+    def param_shape(self) -> torch.Size:
         return self._param.size()

     def sample(self, sample_shape=torch.Size()):
@@ -125,7 +125,7 @@ def enumerate_support(self, expand=True):
         return values

     @property
-    def _natural_params(self):
+    def _natural_params(self) -> tuple[Tensor]:
         return (torch.logit(self.probs),)

     def _log_normalizer(self, x):
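Note the return type `tuple[Tensor]`: with builtin generics, a tuple annotation is fixed-length unless `...` is used, so this declares exactly one natural parameter. A quick reference (illustrative only):

from torch import Tensor

one_param: tuple[Tensor]           # exactly one element, as in Bernoulli._natural_params
two_params: tuple[Tensor, Tensor]  # exactly two, as in Beta._natural_params below
any_number: tuple[Tensor, ...]     # variable length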

torch/distributions/beta.py
Lines changed: 8 additions & 7 deletions

@@ -2,6 +2,7 @@
 from numbers import Number, Real

 import torch
+from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.dirichlet import Dirichlet
 from torch.distributions.exp_family import ExponentialFamily
@@ -62,19 +63,19 @@ def expand(self, batch_shape, _instance=None):
         return new

     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         return self.concentration1 / (self.concentration1 + self.concentration0)

     @property
-    def mode(self):
+    def mode(self) -> Tensor:
         return self._dirichlet.mode[..., 0]

     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         total = self.concentration1 + self.concentration0
         return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

-    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
+    def rsample(self, sample_shape: _size = ()) -> Tensor:
         return self._dirichlet.rsample(sample_shape).select(-1, 0)

     def log_prob(self, value):
@@ -87,23 +88,23 @@ def entropy(self):
         return self._dirichlet.entropy()

     @property
-    def concentration1(self):
+    def concentration1(self) -> Tensor:
         result = self._dirichlet.concentration[..., 0]
         if isinstance(result, Number):
             return torch.tensor([result])
         else:
             return result

     @property
-    def concentration0(self):
+    def concentration0(self) -> Tensor:
         result = self._dirichlet.concentration[..., 1]
         if isinstance(result, Number):
             return torch.tensor([result])
         else:
             return result

     @property
-    def _natural_params(self):
+    def _natural_params(self) -> tuple[Tensor, Tensor]:
         return (self.concentration1, self.concentration0)

     def _log_normalizer(self, x, y):

torch/distributions/binomial.py
Lines changed: 7 additions & 6 deletions

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import torch
+from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import (
@@ -92,27 +93,27 @@ def support(self):
         return constraints.integer_interval(0, self.total_count)

     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         return self.total_count * self.probs

     @property
-    def mode(self):
+    def mode(self) -> Tensor:
         return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count)

     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         return self.total_count * self.probs * (1 - self.probs)

     @lazy_property
-    def logits(self):
+    def logits(self) -> Tensor:
         return probs_to_logits(self.probs, is_binary=True)

     @lazy_property
-    def probs(self):
+    def probs(self) -> Tensor:
         return logits_to_probs(self.logits, is_binary=True)

     @property
-    def param_shape(self):
+    def param_shape(self) -> torch.Size:
         return self._param.size()

     def sample(self, sample_shape=torch.Size()):

torch/distributions/categorical.py
Lines changed: 8 additions & 8 deletions

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import torch
-from torch import nan
+from torch import nan, Tensor
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits
@@ -94,19 +94,19 @@ def support(self):
         return constraints.integer_interval(0, self._num_events - 1)

     @lazy_property
-    def logits(self):
+    def logits(self) -> Tensor:
         return probs_to_logits(self.probs)

     @lazy_property
-    def probs(self):
+    def probs(self) -> Tensor:
         return logits_to_probs(self.logits)

     @property
-    def param_shape(self):
+    def param_shape(self) -> torch.Size:
         return self._param.size()

     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         return torch.full(
             self._extended_shape(),
             nan,
@@ -115,11 +115,11 @@ def mean(self):
         )

     @property
-    def mode(self):
-        return self.probs.argmax(axis=-1)
+    def mode(self) -> Tensor:
+        return self.probs.argmax(dim=-1)

     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         return torch.full(
             self._extended_shape(),
             nan,

torch/distributions/cauchy.py
Lines changed: 5 additions & 5 deletions

@@ -3,7 +3,7 @@
 from numbers import Number

 import torch
-from torch import inf, nan
+from torch import inf, nan, Tensor
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import broadcast_all
@@ -52,22 +52,22 @@ def expand(self, batch_shape, _instance=None):
         return new

     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         return torch.full(
             self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device
         )

     @property
-    def mode(self):
+    def mode(self) -> Tensor:
         return self.loc

     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         return torch.full(
             self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device
         )

-    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
+    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
         shape = self._extended_shape(sample_shape)
         eps = self.loc.new(shape).cauchy_()
         return self.loc + eps * self.scale

torch/distributions/chi2.py
Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.gamma import Gamma

@@ -31,5 +32,5 @@ def expand(self, batch_shape, _instance=None):
         return super().expand(batch_shape, new)

     @property
-    def df(self):
+    def df(self) -> Tensor:
         return self.concentration * 2

torch/distributions/constraints.py
Lines changed: 8 additions & 8 deletions

@@ -119,13 +119,13 @@ def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
         super().__init__()

     @property
-    def is_discrete(self):
+    def is_discrete(self) -> bool:  # type: ignore[override]
         if self._is_discrete is NotImplemented:
             raise NotImplementedError(".is_discrete cannot be determined statically")
         return self._is_discrete

     @property
-    def event_dim(self):
+    def event_dim(self) -> int:  # type: ignore[override]
         if self._event_dim is NotImplemented:
             raise NotImplementedError(".event_dim cannot be determined statically")
         return self._event_dim
@@ -233,11 +233,11 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims):
         super().__init__()

     @property
-    def is_discrete(self):
+    def is_discrete(self) -> bool:  # type: ignore[override]
         return self.base_constraint.is_discrete

     @property
-    def event_dim(self):
+    def event_dim(self) -> int:  # type: ignore[override]
         return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

     def check(self, value):
@@ -599,11 +599,11 @@ def __init__(self, cseq, dim=0, lengths=None):
         super().__init__()

     @property
-    def is_discrete(self):
+    def is_discrete(self) -> bool:  # type: ignore[override]
         return any(c.is_discrete for c in self.cseq)

     @property
-    def event_dim(self):
+    def event_dim(self) -> int:  # type: ignore[override]
         return max(c.event_dim for c in self.cseq)

     def check(self, value):
@@ -631,11 +631,11 @@ def __init__(self, cseq, dim=0):
         super().__init__()

     @property
-    def is_discrete(self):
+    def is_discrete(self) -> bool:  # type: ignore[override]
         return any(c.is_discrete for c in self.cseq)

     @property
-    def event_dim(self):
+    def event_dim(self) -> int:  # type: ignore[override]
         dim = max(c.event_dim for c in self.cseq)
         if self.dim + dim < 0:
             dim += 1
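The `# type: ignore[override]` comments are needed because the base `Constraint` class declares `is_discrete` and `event_dim` as plain class attributes, and mypy rejects overriding a writable attribute with a read-only property. A simplified reproduction of the pattern (class names here are hypothetical):

class Base:
    # Plain class attributes; mypy infers types bool and int.
    is_discrete = False
    event_dim = 0


class Derived(Base):
    @property
    def is_discrete(self) -> bool:  # type: ignore[override]
        # Without the ignore, mypy reports an incompatible override:
        # a read-only property cannot replace a writable attribute.
        return True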

torch/distributions/continuous_bernoulli.py
Lines changed: 10 additions & 9 deletions

@@ -3,6 +3,7 @@
 from numbers import Number

 import torch
+from torch import Tensor
 from torch.distributions import constraints
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import (
@@ -52,7 +53,7 @@ class ContinuousBernoulli(ExponentialFamily):

     def __init__(
         self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
-    ):
+    ) -> None:
         if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
@@ -127,7 +128,7 @@ def _cont_bern_log_norm(self):
         return torch.where(self._outside_unstable_region(), log_norm, taylor)

     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         cut_probs = self._cut_probs()
         mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
             torch.log1p(-cut_probs) - torch.log(cut_probs)
@@ -137,11 +138,11 @@ def mean(self):
         return torch.where(self._outside_unstable_region(), mus, taylor)

     @property
-    def stddev(self):
+    def stddev(self) -> Tensor:
         return torch.sqrt(self.variance)

     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         cut_probs = self._cut_probs()
         vars = cut_probs * (cut_probs - 1.0) / torch.pow(
             1.0 - 2.0 * cut_probs, 2
@@ -151,15 +152,15 @@ def variance(self):
         return torch.where(self._outside_unstable_region(), vars, taylor)

     @lazy_property
-    def logits(self):
+    def logits(self) -> Tensor:
         return probs_to_logits(self.probs, is_binary=True)

     @lazy_property
-    def probs(self):
+    def probs(self) -> Tensor:
         return clamp_probs(logits_to_probs(self.logits, is_binary=True))

     @property
-    def param_shape(self):
+    def param_shape(self) -> torch.Size:
         return self._param.size()

     def sample(self, sample_shape=torch.Size()):
@@ -168,7 +169,7 @@ def sample(self, sample_shape=torch.Size()):
         with torch.no_grad():
             return self.icdf(u)

-    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
+    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
         shape = self._extended_shape(sample_shape)
         u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
         return self.icdf(u)
@@ -220,7 +221,7 @@ def entropy(self):
         )

     @property
-    def _natural_params(self):
+    def _natural_params(self) -> tuple[Tensor]:
         return (self.logits,)

     def _log_normalizer(self, x):
