17
17
from __future__ import division
18
18
from __future__ import print_function
19
19
20
+ import abc
20
21
import functools
21
22
import numpy as np
22
23
import six
@@ -81,10 +82,10 @@ def _tensorize(d, dtype=None, name=None, as_ref=False):
81
82
return d ._value (dtype , name , as_ref ) # pylint: disable=protected-access
82
83
83
84
84
- class TensorMetaClass (type ):
85
+ class TensorMetaClass (abc . ABCMeta ):
85
86
"""A type of class which will make objects which act like Tensors."""
86
87
87
- def __new__ (mcs , name , bases , attrs ):
88
+ def __new__ (mcs , name , bases , attrs ): # pylint: disable=bad-mcs-classmethod-argument
88
89
operators = set (tf .Tensor .OVERLOADABLE_OPERATORS )
89
90
operators .difference_update ({'__eq__' , '__ne__' })
90
91
operators .update ({'__iter__' })
@@ -109,15 +110,16 @@ def __new__(mcs, name, bases, attrs):
109
110
attrs .update (
110
111
(attr , getattr (tf .Tensor , attr ))
111
112
for attr in {'__bool__' , '__array_priority__' , '__nonzero__' })
112
- cls = super (TensorMetaClass , mcs ).__new__ (mcs , name , bases , attrs )
113
+ cls = super (TensorMetaClass , mcs ).__new__ (mcs , name , bases , attrs ) # pylint: disable=too-many-function-args
113
114
tf .register_tensor_conversion_function (cls , conversion_func = _tensorize )
114
115
return cls
115
116
116
117
117
118
NONE_SPECIFIED = 'None'
118
119
119
120
120
- class DeferredTensor (six .with_metaclass (TensorMetaClass , tf .Module )):
121
+ class DeferredTensor (six .with_metaclass (
122
+ TensorMetaClass , tf .Module , tf .__internal__ .CompositeTensor )):
121
123
"""Variable tracking object which applies function upon `convert_to_tensor`.
122
124
123
125
#### Example
@@ -379,6 +381,39 @@ def __array__(self, dtype=None):
379
381
'numpy array.' )
380
382
return np .array (self ._value (dtype = dtype ))
381
383
384
+ def _get_input_spec (self ):
385
+ if isinstance (self .pretransformed_input , tf .__internal__ .CompositeTensor ):
386
+ return self .pretransformed_input ._type_spec # pylint: disable=protected-access
387
+ if isinstance (self .pretransformed_input , tf .Variable ):
388
+ return resource_variable_ops .VariableSpec (
389
+ self .pretransformed_input .shape ,
390
+ dtype = self .pretransformed_input .dtype ,
391
+ trainable = self .pretransformed_input .trainable )
392
+ return tf .TensorSpec .from_tensor (self .pretransformed_input )
393
+
394
+ @property
395
+ def _type_spec (self ):
396
+ input_spec = self ._get_input_spec ()
397
+ transform_or_spec = getattr (self ._transform_fn , '_type_spec' ,
398
+ self ._transform_fn )
399
+
400
+ # Extract Variables from also_track.
401
+ if self .also_track is None :
402
+ also_track_spec = None
403
+ else :
404
+ also_track_vars = tf .nest .flatten (
405
+ tf .nest .map_structure (
406
+ lambda x : x .variables if isinstance (x , tf .Module ) else x ,
407
+ self .also_track ))
408
+ also_track_spec = tf .nest .map_structure (
409
+ lambda x : resource_variable_ops .VariableSpec ( # pylint: disable=g-long-lambda
410
+ x .shape , x .dtype , trainable = x .trainable ),
411
+ also_track_vars )
412
+
413
+ return _DeferredTensorSpec (
414
+ input_spec , transform_or_spec , dtype = self .dtype , shape = self .shape ,
415
+ name = self .name , also_track_spec = also_track_spec )
416
+
382
417
383
418
class TransformedVariable (DeferredTensor ):
384
419
"""Variable tracking object which applies a bijector upon `convert_to_tensor`.
@@ -455,11 +490,7 @@ def __init__(self, initial_value, bijector, dtype=None, name=None, **kwargs):
455
490
which is the initial value for the `TransformedVariable`. The underlying
456
491
untransformed `tf.Variable` will be initialized with
457
492
`bijector.inverse(initial_value)`. Can also be a callable with no
458
- argument that returns the initial value when called. Note: if
459
- `initial_value` is a `TransformedVariable` then the instantiated object
460
- does not create a new `tf.Variable`, but rather points to the underlying
461
- `Variable` and chains the `bijector` arg with the underlying bijector as
462
- `tfb.Chain([bijector, initial_value.bijector])`.
493
+ argument that returns the initial value when called.
463
494
bijector: A `Bijector`-like instance which defines the transformations
464
495
applied to the underlying `tf.Variable`.
465
496
dtype: `tf.dtype.DType` instance or otherwise valid `dtype` value to
@@ -479,16 +510,25 @@ def __init__(self, initial_value, bijector, dtype=None, name=None, **kwargs):
479
510
480
511
if callable (initial_value ):
481
512
initial_value = initial_value ()
482
- initial_value = tf .convert_to_tensor (
483
- initial_value , dtype_hint = bijector .dtype , dtype = dtype )
513
+
514
+ # Extra kwarg that TypeSpec._from_components uses to re-build the object
515
+ # without re-initializing the variable.
516
+ pretransformed_input = kwargs .pop ('pretransformed_input' , None )
517
+ if pretransformed_input is None :
518
+ initial_value = tf .convert_to_tensor (
519
+ initial_value , dtype_hint = bijector .dtype , dtype = dtype )
520
+ pretransformed_input = tf .Variable (
521
+ initial_value = bijector .inverse (initial_value ),
522
+ name = name ,
523
+ dtype = dtype ,
524
+ ** kwargs )
525
+ shape = initial_value .shape
526
+ else :
527
+ shape = bijector .forward_event_shape (pretransformed_input .shape )
484
528
super (TransformedVariable , self ).__init__ (
485
- pretransformed_input = tf .Variable (
486
- initial_value = bijector .inverse (initial_value ),
487
- name = name ,
488
- dtype = dtype ,
489
- ** kwargs ),
529
+ pretransformed_input = pretransformed_input ,
490
530
transform_fn = bijector ,
491
- shape = initial_value . shape ,
531
+ shape = shape ,
492
532
name = bijector .name )
493
533
self ._bijector = bijector
494
534
@@ -529,6 +569,13 @@ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
529
569
name = name ,
530
570
read_value = read_value )
531
571
572
+ @property
573
+ def _type_spec (self ):
574
+ input_spec = self ._get_input_spec ()
575
+ transform_or_spec = getattr (self .bijector , '_type_spec' , self .bijector )
576
+ return _TransformedVariableSpec (
577
+ input_spec , transform_or_spec , self .dtype , self .name )
578
+
532
579
533
580
class _DeferredTensorSpecBase (object ):
534
581
"""Common methods for '_DeferredTensorSpec' and '_TransformedVariableSpec."""
0 commit comments