@@ -84,27 +84,101 @@ def __repr__(self):
84
84
85
85
86
86
def common_dtype (args , dtype_hint = None ):
87
- """Returns explict dtype from `args` if there is one."""
87
+ """Returns (nested) explict dtype from `args` if there is one.
88
+
89
+ Dtypes of all args, and `dtype_hint`, must have the same nested structure if
90
+ they are not `None`. `args` itself may be any nested structure; its
91
+ structure is flattened and ignored.
92
+
93
+ Args:
94
+ args: A nested structure of objects that may have `dtype`.
95
+ dtype_hint: Optional (nested) dtype containing defaults to use in place of
96
+ `None`. If `dtype_hint` is not nested and the common dtype of `args` is
97
+ nested, `dtype_hint` serves as the default for each element of the common
98
+ nested dtype structure.
99
+
100
+ Returns:
101
+ dtype: The (nested) dtype common across all elements of `args`, or `None`.
102
+
103
+ #### Examples
104
+
105
+ Usage with non-nested dtype:
106
+
107
+ ```python
108
+ x = tf.ones([3, 4], dtype=tf.float64)
109
+ y = 4.
110
+ z = None
111
+ common_dtype([x, y, z], dtype_hint=tf.float32) # ==> tf.float64
112
+ common_dtype([y, z], dtype_hint=tf.float32) # ==> tf.float32
113
+
114
+ # The arg to `common_dtype` can be an arbitrary nested structure; it is
115
+ # flattened, and the common dtype of its contents is returned.
116
+ common_dtype({'x': x, 'yz': (y, z)})
117
+ # ==> tf.float64
118
+ ```
119
+
120
+ Usage with nested dtype:
121
+
122
+ ```python
123
+ # Define `x` and `y` as JointDistributions with the same nested dtype.
124
+ x = tfd.JointDistributionNamed(
125
+ {'a': tfd.Uniform(np.float64(0.), 1.),
126
+ 'b': tfd.JointDistributionSequential(
127
+ [tfd.Normal(0., 2.), tfd.Bernoulli(0.4)])})
128
+ x.dtype # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
129
+
130
+ y = tfd.JointDistributionNamed(
131
+ {'a': tfd.LogitNormal(np.float64(0.), 1.),
132
+ 'b': tfd.JointDistributionSequential(
133
+ [tfd.Normal(-1., 1.), tfd.Bernoulli(0.6)])})
134
+ y.dtype # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
135
+
136
+ # Pack x and y into an arbitrary nested structure and pass it to
137
+ # `common_dtype`.
138
+ args0 = [x, y]
139
+ common_dtype(args0) # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
140
+
141
+ # The nested structure of the argument to `common_dtype` is flattened and
142
+ # ignored; only the nested structures of the dtypes are relevant.
143
+ args1 = {'x': x, 'yz': {'y': y, 'z': None}}
144
+ common_dtype(args1) # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
145
+ ```
146
+ """
147
+
148
+ def _unify_dtype (current , new ):
149
+ if current is not None and new is not None and current != new :
150
+ if SKIP_DTYPE_CHECKS :
151
+ return (np .ones ([2 ], dtype ) + np .ones ([2 ], dt )).dtype
152
+ raise TypeError
153
+ return new if current is None else current
154
+
88
155
dtype = None
89
156
flattened_args = tf .nest .flatten (args )
90
157
seen = [_NOT_YET_SEEN ] * len (flattened_args )
91
158
for i , a in enumerate (flattened_args ):
92
159
if hasattr (a , 'dtype' ) and a .dtype :
93
- dt = as_numpy_dtype (a .dtype )
160
+ dt = tf .nest .map_structure (
161
+ lambda d : d if d is None else as_numpy_dtype (d ), a .dtype )
94
162
seen [i ] = dt
95
163
else :
96
164
seen [i ] = None
97
165
continue
98
166
if dtype is None :
99
167
dtype = dt
100
- elif dtype != dt :
101
- if SKIP_DTYPE_CHECKS :
102
- dtype = (np .ones ([2 ], dtype ) + np .ones ([2 ], dt )).dtype
103
- else :
104
- raise TypeError (
105
- 'Found incompatible dtypes, {} and {}. Seen so far: {}' .format (
106
- dtype , dt , tf .nest .pack_sequence_as (args , seen )))
107
- return base_dtype (dtype_hint ) if dtype is None else base_dtype (dtype )
168
+ try :
169
+ dtype = tf .nest .map_structure (_unify_dtype , dtype , dt )
170
+ except TypeError :
171
+ raise TypeError (
172
+ 'Found incompatible dtypes, {} and {}. Seen so far: {}' .format (
173
+ dtype , dt , tf .nest .pack_sequence_as (args , seen ))) from None
174
+ if dtype_hint is None :
175
+ return tf .nest .map_structure (base_dtype , dtype )
176
+ if dtype is None :
177
+ return tf .nest .map_structure (base_dtype , dtype_hint )
178
+ if tf .nest .is_nested (dtype ) and not tf .nest .is_nested (dtype_hint ):
179
+ dtype_hint = tf .nest .map_structure (lambda _ : dtype_hint , dtype )
180
+ return tf .nest .map_structure (
181
+ lambda dt , h : base_dtype (h if dt is None else dt ), dtype , dtype_hint )
108
182
109
183
110
184
def convert_to_dtype (tensor_or_dtype , dtype = None , dtype_hint = None ):
0 commit comments