Commit 9efb5fb

rchen152 authored and tensorflower-gardener committed
Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376266982
1 parent 05bb482 commit 9efb5fb
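
PEP 484 deprecates "implicit Optional": a parameter annotated as `X` but defaulting to `None` actually accepts `X` or `None`, and should be annotated `Optional[X]`. Every hunk below applies that one mechanical fix. A minimal sketch of the pattern, with a hypothetical `State` alias standing in for fun_mc's type names:

from typing import Optional

State = dict  # stand-in alias for illustration only

# Before: the annotation claims init_state is always a State, yet the
# default is None -- checkers such as pytype flag the mismatch.
def before(init_state: 'State' = None): ...

# After: Optional[State] == Union[State, None] matches the default.
def after(init_state: 'Optional[State]' = None): ...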

File tree

2 files changed: +30 -30 lines changed


spinoffs/fun_mc/fun_mc/fun_mc_lib.py

Lines changed: 28 additions & 28 deletions
@@ -567,8 +567,8 @@ def maybe_broadcast_structure(from_structure: 'Any',
 def reparameterize_potential_fn(
     potential_fn: 'PotentialFn',
     transport_map_fn: 'TransportMap',
-    init_state: 'State' = None,
-    state_structure: 'Any' = None,
+    init_state: 'Optional[State]' = None,
+    state_structure: 'Optional[Any]' = None,
     track_volume: 'bool' = True,
 ) -> 'Tuple[PotentialFn, Optional[State]]':
   """Performs a change of variables of a potential function.
@@ -651,7 +651,7 @@ def wrapper(*args, **kwargs):
 
 def transform_log_prob_fn(log_prob_fn: 'PotentialFn',
                           bijector: 'BijectorNest',
-                          init_state: 'State' = None) -> 'Any':
+                          init_state: 'Optional[State]' = None) -> 'Any':
   """Transforms a log-prob function using a bijector.
 
   This takes a log-prob function and creates a new log-prob function that now
@@ -1041,7 +1041,7 @@ def metropolis_hastings_step(
     current_state: 'State',
     proposed_state: 'State',
     energy_change: 'FloatTensor',
-    log_uniform: 'FloatTensor' = None,
+    log_uniform: 'Optional[FloatTensor]' = None,
     seed=None) -> 'Tuple[State, MetropolisHastingsExtra]':
   """Metropolis-Hastings step.
 
@@ -1086,9 +1086,9 @@ def metropolis_hastings_step(
 
 
 @util.named_call
-def gaussian_momentum_sample(state: 'State' = None,
-                             shape: 'IntTensor' = None,
-                             dtype: 'DTypeNest' = None,
+def gaussian_momentum_sample(state: 'Optional[State]' = None,
+                             shape: 'Optional[IntTensor]' = None,
+                             dtype: 'Optional[DTypeNest]' = None,
                              seed=None) -> 'State':
   """Generates a sample from a Gaussian (Normal) momentum distribution.
 
@@ -1205,17 +1205,17 @@ def _default_hamiltonian_monte_carlo_energy_change_fn(
 def hamiltonian_monte_carlo_step(
     hmc_state: 'HamiltonianMonteCarloState',
     target_log_prob_fn: 'PotentialFn',
-    step_size: 'Any' = None,
-    num_integrator_steps: 'IntTensor' = None,
-    momentum: 'State' = None,
-    kinetic_energy_fn: 'PotentialFn' = None,
-    momentum_sample_fn: 'MomentumSampleFn' = None,
+    step_size: 'Optional[Any]' = None,
+    num_integrator_steps: 'Optional[IntTensor]' = None,
+    momentum: 'Optional[State]' = None,
+    kinetic_energy_fn: 'Optional[PotentialFn]' = None,
+    momentum_sample_fn: 'Optional[MomentumSampleFn]' = None,
     integrator_trace_fn: 'Callable[[IntegratorStepState, IntegratorStepExtras],'
                          'TensorNest]' = lambda *args: (),
-    log_uniform: 'FloatTensor' = None,
+    log_uniform: 'Optional[FloatTensor]' = None,
     integrator_fn=None,
     unroll_integrator: 'bool' = False,
-    max_num_integrator_steps: 'IntTensor' = None,
+    max_num_integrator_steps: 'Optional[IntTensor]' = None,
     energy_change_fn:
     'Callable[[IntegratorState, IntegratorState, IntegratorExtras], '
     'Tuple[FloatTensor, Any]]' = _default_hamiltonian_monte_carlo_energy_change_fn,
@@ -1298,7 +1298,7 @@ def orig_target_log_prob_fn(x):
 
   if kinetic_energy_fn is None:
     kinetic_energy_fn = make_gaussian_kinetic_energy_fn(
-        len(target_log_prob.shape) if target_log_prob.shape is not None else tf
+        len(target_log_prob.shape) if target_log_prob.shape is not None else tf  # pytype: disable=attribute-error
        .rank(target_log_prob))
 
   if momentum_sample_fn is None:
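
The only non-annotation change above adds a line-level pytype pragma, presumably to silence a false positive on the chained `.shape` access surfaced by the stricter annotations; `# pytype: disable=attribute-error` suppresses that one error class on that line only. A minimal sketch of the same idiom, using plain TensorFlow rather than fun_mc's aliases:

import tensorflow as tf

def static_or_dynamic_rank(t):
  # Prefer the statically known rank; fall back to the dynamic rank op.
  # The guard makes this safe at runtime, but a checker may still flag
  # the attribute access, hence the line-scoped pragma.
  return len(t.shape) if t.shape is not None else tf.rank(t)  # pytype: disable=attribute-error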
@@ -1392,7 +1392,7 @@ def hamiltonian_integrator(
     integrator_trace_fn: 'Callable[[IntegratorStepState, IntegratorStepExtras],'
                          'TensorNest]' = lambda *args: (),
     unroll: 'bool' = False,
-    max_num_steps: 'IntTensor' = None,
+    max_num_steps: 'Optional[IntTensor]' = None,
 ) -> 'Tuple[IntegratorState, IntegratorExtras]':
   """Intergrates a discretized set of Hamiltonian equations.
 
@@ -1719,7 +1719,7 @@ def _one_part(state, g, learning_rate):
 def gaussian_proposal(
     state: 'State',
     scale: 'FloatNest' = 1.,
-    seed: 'Any' = None) -> 'Tuple[State, Tuple[Tuple[()], float]]':
+    seed: 'Optional[Any]' = None) -> 'Tuple[State, Tuple[Tuple[()], float]]':
   """Axis-aligned gaussian random-walk proposal.
 
   Args:
@@ -1757,7 +1757,7 @@ def maximal_reflection_coupling_proposal(
     chain_ndims: 'int' = 0,
     scale: 'FloatNest' = 1,
     epsilon: 'FloatTensor' = 1e-20,
-    seed: 'Any' = None
+    seed: 'Optional[Any]' = None
 ) -> 'Tuple[State, Tuple[MaximalReflectiveCouplingProposalExtra, float]]':
   """Maximal reflection coupling proposal.
 
@@ -1900,7 +1900,7 @@ def random_walk_metropolis_step(
     rwm_state: 'RandomWalkMetropolisState',
     target_log_prob_fn: 'PotentialFn',
     proposal_fn: 'TransitionOperator',
-    log_uniform: 'FloatTensor' = None,
+    log_uniform: 'Optional[FloatTensor]' = None,
     seed=None) -> 'Tuple[RandomWalkMetropolisState, RandomWalkMetropolisExtra]':
   """Random Walk Metropolis Hastings `TransitionOperator`.
 
@@ -1992,8 +1992,8 @@ def running_variance_init(shape: 'IntTensor',
 def running_variance_step(
     state: 'RunningVarianceState',
     vec: 'FloatNest',
-    axis: 'Union[int, List[int], Tuple[int]]' = None,
-    window_size: 'IntNest' = None,
+    axis: 'Optional[Union[int, List[int], Tuple[int]]]' = None,
+    window_size: 'Optional[IntNest]' = None,
 ) -> 'Tuple[RunningVarianceState, Tuple[()]]':
   """Updates the `RunningVarianceState`.
 
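For the `axis` parameters here and below, `Optional` wraps an annotation that is already a `Union`. Since `Optional[X]` is shorthand for `Union[X, None]`, the result flattens into a single union rather than nesting. A quick illustrative check (not part of the commit):

from typing import List, Optional, Tuple, Union

# Optional[Union[...]] flattens: None simply joins the existing Union.
AxisT = Optional[Union[int, List[int], Tuple[int]]]
assert AxisT == Union[int, List[int], Tuple[int], None]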
@@ -2117,8 +2117,8 @@ def running_covariance_init(shape: 'IntTensor',
 def running_covariance_step(
     state: 'RunningCovarianceState',
     vec: 'FloatNest',
-    axis: 'Union[int, List[int], Tuple[int]]' = None,
-    window_size: 'IntNest' = None,
+    axis: 'Optional[Union[int, List[int], Tuple[int]]]' = None,
+    window_size: 'Optional[IntNest]' = None,
 ) -> 'Tuple[RunningCovarianceState, Tuple[()]]':
   """Updates the `RunningCovarianceState`.
 
@@ -2234,8 +2234,8 @@ def running_mean_init(shape: 'IntTensor',
 def running_mean_step(
     state: 'RunningMeanState',
     vec: 'FloatNest',
-    axis: 'Union[int, List[int], Tuple[int]]' = None,
-    window_size: 'IntNest' = None,
+    axis: 'Optional[Union[int, List[int], Tuple[int]]]' = None,
+    window_size: 'Optional[IntNest]' = None,
 ) -> 'Tuple[RunningMeanState, Tuple[()]]':
   """Updates the `RunningMeanState`.
 
@@ -2415,7 +2415,7 @@ def running_approximate_auto_covariance_init(
     max_lags: 'int',
     state_shape: 'IntTensor',
     dtype: 'DTypeNest',
-    axis: 'Union[int, List[int], Tuple[int]]' = None,
+    axis: 'Optional[Union[int, List[int], Tuple[int]]]' = None,
 ) -> 'RunningApproximateAutoCovarianceState':
   """Initializes `RunningApproximateAutoCovarianceState`.
 
@@ -2463,7 +2463,7 @@ def _shape_with_lags(shape):
 def running_approximate_auto_covariance_step(
     state: 'RunningApproximateAutoCovarianceState',
     vec: 'TensorNest',
-    axis: 'Union[int, List[int], Tuple[int]]' = None,
+    axis: 'Optional[Union[int, List[int], Tuple[int]]]' = None,
 ) -> 'Tuple[RunningApproximateAutoCovarianceState, Tuple[()]]':
   """Updates `RunningApproximateAutoCovarianceState`.
 
@@ -2568,7 +2568,7 @@ def _one_part(vec, buf, mean, auto_cov):
 
 
 def make_surrogate_loss_fn(
-    grad_fn: 'GradFn' = None,
+    grad_fn: 'Optional[GradFn]' = None,
     loss_value: 'tf.Tensor' = 0.,
 ) -> 'Any':
   """Creates a surrogate loss function with specified gradients.

spinoffs/oryx/oryx/experimental/matching/jax_rewrite.py

Lines changed: 2 additions & 2 deletions
@@ -192,7 +192,7 @@ def f(x):
 """
 import functools
 
-from typing import Any, Callable, Dict, Iterator, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union
 
 import dataclasses
 import jax
@@ -380,7 +380,7 @@ def __str__(self):
 class Params(matcher.Pattern):
   """An immutable dictionary used to represent parameters of JAX primitives."""
 
-  def __init__(self, params: Dict[str, Any] = None, **kwargs: Any):
+  def __init__(self, params: Optional[Dict[str, Any]] = None, **kwargs: Any):
     """The constructor for a `Params` object.
 
     A `Params` object is an immutable dictionary, meant to encapsulate the
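
In jax_rewrite.py the same fix lands on a constructor argument, and here `Optional` also has to join the `typing` import. A minimal sketch of how such a constructor can normalize the now-explicit None default (illustrative only, not oryx's actual Params implementation):

from typing import Any, Dict, Optional

class ParamsSketch:
  """Hypothetical stand-in for oryx's Params."""

  def __init__(self, params: Optional[Dict[str, Any]] = None, **kwargs: Any):
    # The annotation now admits the None default explicitly; normalize
    # it to an empty mapping before merging keyword entries.
    self._data = {**(params or {}), **kwargs}

p = ParamsSketch(None, axis=0)  # valid under Optional[Dict[str, Any]]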
