We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f8107c1 commit 8e72c11Copy full SHA for 8e72c11
tensorflow_probability/python/experimental/mcmc/windowed_sampling.py
@@ -269,8 +269,9 @@ def step_broadcast(step_size):
269
shard_axis_names = pinned_model.experimental_shard_axis_names
270
if any(tf.nest.flatten(shard_axis_names)):
271
shard_axis_names = nest.flatten_up_to(
272
- initial_transformed_position, pinned_model._model_flatten( # pylint: disable=protected-access
273
- shard_axis_names))
+ initial_transformed_position,
+ list(pinned_model._model_flatten(shard_axis_names))) # pylint: disable=protected-access
274
+
275
else:
276
# No active shard axis names
277
shard_axis_names = None
0 commit comments