@@ -65,6 +65,27 @@ def __repr__(self):
65
65
self .shape , self .dtype , self .name )
66
66
67
67
68
+ class VariableSpec (TensorSpec ):
69
+ """Stub for VariableSpec to simplify tests in np/jax backends."""
70
+
71
+ def __init__ (self , shape , dtype = tf .float32 , trainable = True , name = None ):
72
+ super (VariableSpec , self ).__init__ (shape , dtype , name = name )
73
+ self .trainable = trainable
74
+
75
+ @classmethod
76
+ def from_value (cls , variable ):
77
+ return cls (
78
+ variable .shape , variable .dtype , variable .trainable , variable .name )
79
+
80
+ def __eq__ (self , other ):
81
+ return (super (VariableSpec , self ).__eq__ (other )
82
+ and self .trainable == other .trainable )
83
+
84
+ def __repr__ (self ):
85
+ return 'VariableSpec(shape={}, dtype={}, trainable={}, name={})' .format (
86
+ self .shape , self .dtype , self .trainable , self .name )
87
+
88
+
68
89
class LeafList (list ):
69
90
_tfp_nest_expansion_force_leaf = ()
70
91
@@ -297,12 +318,13 @@ def fn(arg1, arg2):
297
318
'b' : TensorSpec ([], tf .int64 , name = 'c2t/b' )}
298
319
},{
299
320
'testcase_name' : '_tensor_with_hint' ,
300
- 'value' : [TensorSpec ([], tf .int32 )],
321
+ 'value' : [TensorSpec ([], tf .int32 , name = 'tensor' )],
301
322
'dtype_hint' : [tf .float32 ],
302
323
'expected' : [TensorSpec ([], tf .int32 , name = 'tensor' )]
303
324
},{
304
325
'testcase_name' : '_tensor_struct' ,
305
- 'value' : [TensorSpec ([], tf .int32 ), TensorSpec ([], tf .float32 )],
326
+ 'value' : [TensorSpec ([], tf .int32 , name = 'tensor' ),
327
+ TensorSpec ([], tf .float32 , name = 'tensor' )],
306
328
'dtype_hint' : [tf .float32 , tf .float32 ],
307
329
'expected' : [TensorSpec ([], tf .int32 , name = 'tensor' ),
308
330
TensorSpec ([], tf .float32 , name = 'tensor_1' )]
@@ -318,20 +340,46 @@ def fn(arg1, arg2):
318
340
'name' : None ,
319
341
'expected' : [TensorSpec ([], tf .float32 , name = 'Const' ),
320
342
TensorSpec ([], tf .float32 , name = 'Const_1' )]
343
+ },{
344
+ 'testcase_name' : '_tensor_and_variable_struct' ,
345
+ 'value' : [TensorSpec ([], tf .int32 , name = 'tensor' ),
346
+ VariableSpec ([], tf .float32 , trainable = False , name = 'variable' )],
347
+ 'dtype_hint' : [tf .float32 , tf .float32 ],
348
+ 'convert_ref' : False ,
349
+ 'expected' : [
350
+ TensorSpec ([], tf .int32 , name = 'tensor' ),
351
+ VariableSpec ([], tf .float32 , trainable = False , name = 'variable:0' )]
352
+ },{
353
+ 'testcase_name' : '_tensor_and_variable_struct_convert_ref' ,
354
+ 'value' : [VariableSpec ([], tf .int32 , name = 'variable' ),
355
+ TensorSpec ([], tf .float32 , name = 'tensor' )],
356
+ 'dtype_hint' : [tf .float32 , tf .float32 ],
357
+ 'expected' : [TensorSpec ([], tf .int32 , name = 'c2t/ReadVariableOp' ),
358
+ TensorSpec ([], tf .float32 , name = 'tensor' )]
321
359
})
322
360
def testConvertToNestedTensor (
323
- self , value , dtype = None , dtype_hint = None , name = 'c2t' , expected = None ):
324
- # Convert specs to tensors
361
+ self , value , dtype = None , dtype_hint = None , name = 'c2t' , convert_ref = True ,
362
+ expected = None ):
363
+ # Convert specs to tensors or variables.
325
364
def maybe_spec_to_tensor (x ):
365
+ if isinstance (x , VariableSpec ):
366
+ return tf .Variable (
367
+ tf .zeros (x .shape , x .dtype ), trainable = x .trainable , name = x .name )
326
368
if isinstance (x , TensorSpec ):
327
- return tf .zeros (x .shape , x .dtype , name = 'tensor' )
369
+ return tf .zeros (x .shape , x .dtype , name = x . name )
328
370
return x
329
371
value = nest .map_structure (maybe_spec_to_tensor , value )
330
372
331
373
# Grab shape/dtype from convert_to_nested_tensor for comparison.
374
+ def spec_from_value (x ):
375
+ if isinstance (x , tf .Variable ):
376
+ return VariableSpec .from_value (x )
377
+ return TensorSpec .from_tensor (x )
378
+
332
379
observed = nest .map_structure (
333
- TensorSpec .from_tensor ,
334
- nest_util .convert_to_nested_tensor (value , dtype , dtype_hint , name = name ))
380
+ spec_from_value ,
381
+ nest_util .convert_to_nested_tensor (
382
+ value , dtype , dtype_hint , convert_ref = convert_ref , name = name ))
335
383
self .assertAllEqualNested (observed , expected )
336
384
337
385
@parameterized .named_parameters ({
0 commit comments