|
51 | 51 | ) |
52 | 52 | from pytensor.tensor.random.op import RandomVariable |
53 | 53 | from pytensor.tensor.random.utils import normalize_size_param |
54 | | -from pytensor.tensor.variable import TensorConstant |
| 54 | +from pytensor.tensor.variable import TensorConstant, TensorVariable |
55 | 55 |
|
56 | 56 | from pymc.logprob.abstract import _logprob_helper |
57 | | -from pymc.logprob.basic import icdf |
| 57 | +from pymc.logprob.basic import TensorLike, icdf |
58 | 58 | from pymc.pytensorf import normalize_rng_param |
59 | 59 |
|
60 | 60 | try: |
@@ -148,7 +148,7 @@ class BoundedContinuous(Continuous): |
148 | 148 | """Base class for bounded continuous distributions.""" |
149 | 149 |
|
150 | 150 | # Indices of the arguments that define the lower and upper bounds of the distribution |
151 | | - bound_args_indices: list[int] | None = None |
| 151 | + bound_args_indices: tuple[int | None, int | None] | None = None |
152 | 152 |
|
153 | 153 |
|
154 | 154 | @_default_transform.register(PositiveContinuous) |
@@ -210,7 +210,9 @@ def assert_negative_support(var, label, distname, value=-1e-6): |
210 | 210 | return Assert(msg)(var, pt.all(pt.ge(var, 0.0))) |
211 | 211 |
|
212 | 212 |
|
213 | | -def get_tau_sigma(tau=None, sigma=None): |
| 213 | +def get_tau_sigma( |
| 214 | + tau: TensorLike | None = None, sigma: TensorLike | None = None |
| 215 | +) -> tuple[TensorVariable, TensorVariable]: |
214 | 216 | r""" |
215 | 217 | Find precision and standard deviation. |
216 | 218 |
|
@@ -239,13 +241,14 @@ def get_tau_sigma(tau=None, sigma=None): |
239 | 241 | sigma = pt.as_tensor_variable(1.0) |
240 | 242 | tau = pt.as_tensor_variable(1.0) |
241 | 243 | elif tau is None: |
| 244 | + assert sigma is not None # Just for type checker |
242 | 245 | sigma = pt.as_tensor_variable(sigma) |
243 | 246 | # Keep tau negative, if sigma was negative, so that it will |
244 | 247 | # fail when used |
245 | 248 | tau = (sigma**-2.0) * pt.sign(sigma) |
246 | 249 | else: |
247 | 250 | tau = pt.as_tensor_variable(tau) |
248 | | - # Keep tau negative, if sigma was negative, so that it will |
| 251 | + # Keep sigma negative, if tau was negative, so that it will |
249 | 252 | # fail when used |
250 | 253 | sigma = pt.abs(tau) ** -0.5 * pt.sign(tau) |
251 | 254 |
|
|
0 commit comments