@@ -567,8 +567,8 @@ def maybe_broadcast_structure(from_structure: 'Any',
567
567
def reparameterize_potential_fn (
568
568
potential_fn : 'PotentialFn' ,
569
569
transport_map_fn : 'TransportMap' ,
570
- init_state : 'State' = None ,
571
- state_structure : 'Any' = None ,
570
+ init_state : 'Optional[ State] ' = None ,
571
+ state_structure : 'Optional[ Any] ' = None ,
572
572
track_volume : 'bool' = True ,
573
573
) -> 'Tuple[PotentialFn, Optional[State]]' :
574
574
"""Performs a change of variables of a potential function.
@@ -651,7 +651,7 @@ def wrapper(*args, **kwargs):
651
651
652
652
def transform_log_prob_fn (log_prob_fn : 'PotentialFn' ,
653
653
bijector : 'BijectorNest' ,
654
- init_state : 'State' = None ) -> 'Any' :
654
+ init_state : 'Optional[ State] ' = None ) -> 'Any' :
655
655
"""Transforms a log-prob function using a bijector.
656
656
657
657
This takes a log-prob function and creates a new log-prob function that now
@@ -1041,7 +1041,7 @@ def metropolis_hastings_step(
1041
1041
current_state : 'State' ,
1042
1042
proposed_state : 'State' ,
1043
1043
energy_change : 'FloatTensor' ,
1044
- log_uniform : 'FloatTensor' = None ,
1044
+ log_uniform : 'Optional[ FloatTensor] ' = None ,
1045
1045
seed = None ) -> 'Tuple[State, MetropolisHastingsExtra]' :
1046
1046
"""Metropolis-Hastings step.
1047
1047
@@ -1086,9 +1086,9 @@ def metropolis_hastings_step(
1086
1086
1087
1087
1088
1088
@util .named_call
1089
- def gaussian_momentum_sample (state : 'State' = None ,
1090
- shape : 'IntTensor' = None ,
1091
- dtype : 'DTypeNest' = None ,
1089
+ def gaussian_momentum_sample (state : 'Optional[ State] ' = None ,
1090
+ shape : 'Optional[ IntTensor] ' = None ,
1091
+ dtype : 'Optional[ DTypeNest] ' = None ,
1092
1092
seed = None ) -> 'State' :
1093
1093
"""Generates a sample from a Gaussian (Normal) momentum distribution.
1094
1094
@@ -1205,17 +1205,17 @@ def _default_hamiltonian_monte_carlo_energy_change_fn(
1205
1205
def hamiltonian_monte_carlo_step (
1206
1206
hmc_state : 'HamiltonianMonteCarloState' ,
1207
1207
target_log_prob_fn : 'PotentialFn' ,
1208
- step_size : 'Any' = None ,
1209
- num_integrator_steps : 'IntTensor' = None ,
1210
- momentum : 'State' = None ,
1211
- kinetic_energy_fn : 'PotentialFn' = None ,
1212
- momentum_sample_fn : 'MomentumSampleFn' = None ,
1208
+ step_size : 'Optional[ Any] ' = None ,
1209
+ num_integrator_steps : 'Optional[ IntTensor] ' = None ,
1210
+ momentum : 'Optional[ State] ' = None ,
1211
+ kinetic_energy_fn : 'Optional[ PotentialFn] ' = None ,
1212
+ momentum_sample_fn : 'Optional[ MomentumSampleFn] ' = None ,
1213
1213
integrator_trace_fn : 'Callable[[IntegratorStepState, IntegratorStepExtras],'
1214
1214
'TensorNest]' = lambda * args : (),
1215
- log_uniform : 'FloatTensor' = None ,
1215
+ log_uniform : 'Optional[ FloatTensor] ' = None ,
1216
1216
integrator_fn = None ,
1217
1217
unroll_integrator : 'bool' = False ,
1218
- max_num_integrator_steps : 'IntTensor' = None ,
1218
+ max_num_integrator_steps : 'Optional[ IntTensor] ' = None ,
1219
1219
energy_change_fn :
1220
1220
'Callable[[IntegratorState, IntegratorState, IntegratorExtras], '
1221
1221
'Tuple[FloatTensor, Any]]' = _default_hamiltonian_monte_carlo_energy_change_fn ,
@@ -1298,7 +1298,7 @@ def orig_target_log_prob_fn(x):
1298
1298
1299
1299
if kinetic_energy_fn is None :
1300
1300
kinetic_energy_fn = make_gaussian_kinetic_energy_fn (
1301
- len (target_log_prob .shape ) if target_log_prob .shape is not None else tf
1301
+ len (target_log_prob .shape ) if target_log_prob .shape is not None else tf # pytype: disable=attribute-error
1302
1302
.rank (target_log_prob ))
1303
1303
1304
1304
if momentum_sample_fn is None :
@@ -1392,7 +1392,7 @@ def hamiltonian_integrator(
1392
1392
integrator_trace_fn : 'Callable[[IntegratorStepState, IntegratorStepExtras],'
1393
1393
'TensorNest]' = lambda * args : (),
1394
1394
unroll : 'bool' = False ,
1395
- max_num_steps : 'IntTensor' = None ,
1395
+ max_num_steps : 'Optional[ IntTensor] ' = None ,
1396
1396
) -> 'Tuple[IntegratorState, IntegratorExtras]' :
1397
1397
"""Intergrates a discretized set of Hamiltonian equations.
1398
1398
@@ -1719,7 +1719,7 @@ def _one_part(state, g, learning_rate):
1719
1719
def gaussian_proposal (
1720
1720
state : 'State' ,
1721
1721
scale : 'FloatNest' = 1. ,
1722
- seed : 'Any' = None ) -> 'Tuple[State, Tuple[Tuple[()], float]]' :
1722
+ seed : 'Optional[ Any] ' = None ) -> 'Tuple[State, Tuple[Tuple[()], float]]' :
1723
1723
"""Axis-aligned gaussian random-walk proposal.
1724
1724
1725
1725
Args:
@@ -1757,7 +1757,7 @@ def maximal_reflection_coupling_proposal(
1757
1757
chain_ndims : 'int' = 0 ,
1758
1758
scale : 'FloatNest' = 1 ,
1759
1759
epsilon : 'FloatTensor' = 1e-20 ,
1760
- seed : 'Any' = None
1760
+ seed : 'Optional[ Any] ' = None
1761
1761
) -> 'Tuple[State, Tuple[MaximalReflectiveCouplingProposalExtra, float]]' :
1762
1762
"""Maximal reflection coupling proposal.
1763
1763
@@ -1900,7 +1900,7 @@ def random_walk_metropolis_step(
1900
1900
rwm_state : 'RandomWalkMetropolisState' ,
1901
1901
target_log_prob_fn : 'PotentialFn' ,
1902
1902
proposal_fn : 'TransitionOperator' ,
1903
- log_uniform : 'FloatTensor' = None ,
1903
+ log_uniform : 'Optional[ FloatTensor] ' = None ,
1904
1904
seed = None ) -> 'Tuple[RandomWalkMetropolisState, RandomWalkMetropolisExtra]' :
1905
1905
"""Random Walk Metropolis Hastings `TransitionOperator`.
1906
1906
@@ -1992,8 +1992,8 @@ def running_variance_init(shape: 'IntTensor',
1992
1992
def running_variance_step (
1993
1993
state : 'RunningVarianceState' ,
1994
1994
vec : 'FloatNest' ,
1995
- axis : 'Union[int, List[int], Tuple[int]]' = None ,
1996
- window_size : 'IntNest' = None ,
1995
+ axis : 'Optional[ Union[int, List[int], Tuple[int] ]]' = None ,
1996
+ window_size : 'Optional[ IntNest] ' = None ,
1997
1997
) -> 'Tuple[RunningVarianceState, Tuple[()]]' :
1998
1998
"""Updates the `RunningVarianceState`.
1999
1999
@@ -2117,8 +2117,8 @@ def running_covariance_init(shape: 'IntTensor',
2117
2117
def running_covariance_step (
2118
2118
state : 'RunningCovarianceState' ,
2119
2119
vec : 'FloatNest' ,
2120
- axis : 'Union[int, List[int], Tuple[int]]' = None ,
2121
- window_size : 'IntNest' = None ,
2120
+ axis : 'Optional[ Union[int, List[int], Tuple[int] ]]' = None ,
2121
+ window_size : 'Optional[ IntNest] ' = None ,
2122
2122
) -> 'Tuple[RunningCovarianceState, Tuple[()]]' :
2123
2123
"""Updates the `RunningCovarianceState`.
2124
2124
@@ -2234,8 +2234,8 @@ def running_mean_init(shape: 'IntTensor',
2234
2234
def running_mean_step (
2235
2235
state : 'RunningMeanState' ,
2236
2236
vec : 'FloatNest' ,
2237
- axis : 'Union[int, List[int], Tuple[int]]' = None ,
2238
- window_size : 'IntNest' = None ,
2237
+ axis : 'Optional[ Union[int, List[int], Tuple[int] ]]' = None ,
2238
+ window_size : 'Optional[ IntNest] ' = None ,
2239
2239
) -> 'Tuple[RunningMeanState, Tuple[()]]' :
2240
2240
"""Updates the `RunningMeanState`.
2241
2241
@@ -2415,7 +2415,7 @@ def running_approximate_auto_covariance_init(
2415
2415
max_lags : 'int' ,
2416
2416
state_shape : 'IntTensor' ,
2417
2417
dtype : 'DTypeNest' ,
2418
- axis : 'Union[int, List[int], Tuple[int]]' = None ,
2418
+ axis : 'Optional[ Union[int, List[int], Tuple[int] ]]' = None ,
2419
2419
) -> 'RunningApproximateAutoCovarianceState' :
2420
2420
"""Initializes `RunningApproximateAutoCovarianceState`.
2421
2421
@@ -2463,7 +2463,7 @@ def _shape_with_lags(shape):
2463
2463
def running_approximate_auto_covariance_step (
2464
2464
state : 'RunningApproximateAutoCovarianceState' ,
2465
2465
vec : 'TensorNest' ,
2466
- axis : 'Union[int, List[int], Tuple[int]]' = None ,
2466
+ axis : 'Optional[ Union[int, List[int], Tuple[int] ]]' = None ,
2467
2467
) -> 'Tuple[RunningApproximateAutoCovarianceState, Tuple[()]]' :
2468
2468
"""Updates `RunningApproximateAutoCovarianceState`.
2469
2469
@@ -2568,7 +2568,7 @@ def _one_part(vec, buf, mean, auto_cov):
2568
2568
2569
2569
2570
2570
def make_surrogate_loss_fn (
2571
- grad_fn : 'GradFn' = None ,
2571
+ grad_fn : 'Optional[ GradFn] ' = None ,
2572
2572
loss_value : 'tf.Tensor' = 0. ,
2573
2573
) -> 'Any' :
2574
2574
"""Creates a surrogate loss function with specified gradients.
0 commit comments