Skip to content

Commit 677fd25

Browse files
JXRivertensorflower-gardener
authored andcommitted
Prepare tensorflow_probablity to make ResourceVariable as CompositeTensor.
PiperOrigin-RevId: 464132524
1 parent 3800556 commit 677fd25

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

tensorflow_probability/python/internal/loop_util.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def _initialize_arrays(initial_values,
4545
lambda ta, t: ta.write(0, t), trace_arrays, initial_values)
4646

4747

48+
def _convert_variables_to_tensors(values):
49+
"""Read `tf.Variables` in `values` and keep other objects unchanged."""
50+
return tf.nest.map_structure(
51+
lambda x: tf.convert_to_tensor(x) if isinstance(x, tf.Variable) else x,
52+
values)
53+
54+
4855
def smart_for_loop(loop_num_iter, body_fn, initial_loop_vars,
4956
parallel_iterations=10, unroll_threshold=1, name=None):
5057
"""Construct a for loop, preferring a python loop if `n` is statically known.
@@ -127,7 +134,7 @@ def trace_scan(loop_fn,
127134
elems: A `Tensor` that is split along the first dimension and each element
128135
of which is passed to `loop_fn`.
129136
trace_fn: A callable that takes in the return value of `loop_fn` and returns
130-
a `Tensor` or a nested collection of `Tensor`s.
137+
a `Tensor`, 'Variable' or a nested collection of `Tensor`s or 'Variable's.
131138
trace_criterion_fn: Optional callable that takes in the return value of
132139
`loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
133140
If `None`, all steps are traced.
@@ -182,7 +189,8 @@ def trace_scan(loop_fn,
182189
dynamic_size, initial_size = False, length
183190
else:
184191
dynamic_size, initial_size = True, 0
185-
initial_trace = trace_fn(initial_state)
192+
# Convert variables returned by trace_fn to tensors.
193+
initial_trace = _convert_variables_to_tensors(trace_fn(initial_state))
186194
flat_initial_trace = tf.nest.flatten(initial_trace, expand_composites=True)
187195
trace_arrays = []
188196
for trace_elt in flat_initial_trace:
@@ -195,9 +203,9 @@ def trace_scan(loop_fn,
195203

196204
# Helper for writing a (structured) state to (structured) arrays.
197205
def trace_one_step(num_steps_traced, trace_arrays, state):
198-
return [ta.write(num_steps_traced, x) for ta, x in
199-
zip(trace_arrays,
200-
tf.nest.flatten(trace_fn(state), expand_composites=True))]
206+
trace = _convert_variables_to_tensors(trace_fn(state))
207+
return [ta.write(num_steps_traced, x) for ta, x in zip(
208+
trace_arrays, tf.nest.flatten(trace, expand_composites=True))]
201209

202210
def _body(i, state, num_steps_traced, trace_arrays):
203211
elem = elems_array.read(i)

tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def testTypeObjectLeakage(self):
292292
layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
293293
x = tf.constant([-.23, 1.23, 1.42])
294294
dist = layer(x)
295+
# Investigate why the second layer call creates a few more weakrefs.
296+
# These weakrefs could potentially come from tf.function.variables.
297+
dist = layer(x)
295298
gc.collect()
296299
before_objs = len(gc.get_objects())
297300
for _ in range(int(1e2)):

0 commit comments

Comments
 (0)