Skip to content

Commit 8e72c11

Browse files
jburnimtensorflower-gardener
authored andcommitted
Fix failure in windowed_sampling_test.jax in OSS.
PiperOrigin-RevId: 453202511
1 parent f8107c1 commit 8e72c11

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tensorflow_probability/python/experimental/mcmc/windowed_sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,9 @@ def step_broadcast(step_size):
269269
shard_axis_names = pinned_model.experimental_shard_axis_names
270270
if any(tf.nest.flatten(shard_axis_names)):
271271
shard_axis_names = nest.flatten_up_to(
272-
initial_transformed_position, pinned_model._model_flatten( # pylint: disable=protected-access
273-
shard_axis_names))
272+
initial_transformed_position,
273+
list(pinned_model._model_flatten(shard_axis_names))) # pylint: disable=protected-access
274+
274275
else:
275276
# No active shard axis names
276277
shard_axis_names = None

0 commit comments

Comments
 (0)