@@ -45,6 +45,13 @@ def _initialize_arrays(initial_values,
45
45
lambda ta , t : ta .write (0 , t ), trace_arrays , initial_values )
46
46
47
47
48
+ def _convert_variables_to_tensors (values ):
49
+ """Read `tf.Variables` in `values` and keep other objects unchanged."""
50
+ return tf .nest .map_structure (
51
+ lambda x : tf .convert_to_tensor (x ) if isinstance (x , tf .Variable ) else x ,
52
+ values )
53
+
54
+
48
55
def smart_for_loop (loop_num_iter , body_fn , initial_loop_vars ,
49
56
parallel_iterations = 10 , unroll_threshold = 1 , name = None ):
50
57
"""Construct a for loop, preferring a python loop if `n` is statically known.
@@ -127,7 +134,7 @@ def trace_scan(loop_fn,
127
134
elems: A `Tensor` that is split along the first dimension and each element
128
135
of which is passed to `loop_fn`.
129
136
trace_fn: A callable that takes in the return value of `loop_fn` and returns
130
- a `Tensor` or a nested collection of `Tensor`s.
137
+ a `Tensor`, 'Variable' or a nested collection of `Tensor`s or 'Variable' s.
131
138
trace_criterion_fn: Optional callable that takes in the return value of
132
139
`loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
133
140
If `None`, all steps are traced.
@@ -182,7 +189,8 @@ def trace_scan(loop_fn,
182
189
dynamic_size , initial_size = False , length
183
190
else :
184
191
dynamic_size , initial_size = True , 0
185
- initial_trace = trace_fn (initial_state )
192
+ # Convert variables returned by trace_fn to tensors.
193
+ initial_trace = _convert_variables_to_tensors (trace_fn (initial_state ))
186
194
flat_initial_trace = tf .nest .flatten (initial_trace , expand_composites = True )
187
195
trace_arrays = []
188
196
for trace_elt in flat_initial_trace :
@@ -195,9 +203,9 @@ def trace_scan(loop_fn,
195
203
196
204
# Helper for writing a (structured) state to (structured) arrays.
197
205
def trace_one_step (num_steps_traced , trace_arrays , state ):
198
- return [ ta . write ( num_steps_traced , x ) for ta , x in
199
- zip (trace_arrays ,
200
- tf .nest .flatten (trace_fn ( state ) , expand_composites = True ))]
206
+ trace = _convert_variables_to_tensors ( trace_fn ( state ))
207
+ return [ ta . write ( num_steps_traced , x ) for ta , x in zip (
208
+ trace_arrays , tf .nest .flatten (trace , expand_composites = True ))]
201
209
202
210
def _body (i , state , num_steps_traced , trace_arrays ):
203
211
elem = elems_array .read (i )
0 commit comments