Skip to content

Commit 8cf5587

Browse files
emilyfertigtensorflower-gardener
authored andcommitted
dtype_util.common_dtype handles nested dtypes.
PiperOrigin-RevId: 476199135
1 parent 1fa43ea commit 8cf5587

File tree

2 files changed

+120
-12
lines changed

2 files changed

+120
-12
lines changed

tensorflow_probability/python/internal/dtype_util.py

Lines changed: 84 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,27 +84,101 @@ def __repr__(self):
8484

8585

8686
def common_dtype(args, dtype_hint=None):
87-
"""Returns explict dtype from `args` if there is one."""
87+
"""Returns (nested) explict dtype from `args` if there is one.
88+
89+
Dtypes of all args, and `dtype_hint`, must have the same nested structure if
90+
they are not `None`. `args` itself may be any nested structure; its
91+
structure is flattened and ignored.
92+
93+
Args:
94+
args: A nested structure of objects that may have `dtype`.
95+
dtype_hint: Optional (nested) dtype containing defaults to use in place of
96+
`None`. If `dtype_hint` is not nested and the common dtype of `args` is
97+
nested, `dtype_hint` serves as the default for each element of the common
98+
nested dtype structure.
99+
100+
Returns:
101+
dtype: The (nested) dtype common across all elements of `args`, or `None`.
102+
103+
#### Examples
104+
105+
Usage with non-nested dtype:
106+
107+
```python
108+
x = tf.ones([3, 4], dtype=tf.float64)
109+
y = 4.
110+
z = None
111+
common_dtype([x, y, z], dtype_hint=tf.float32) # ==> tf.float64
112+
common_dtype([y, z], dtype_hint=tf.float32) # ==> tf.float32
113+
114+
# The arg to `common_dtype` can be an arbitrary nested structure; it is
115+
# flattened, and the common dtype of its contents is returned.
116+
common_dtype({'x': x, 'yz': (y, z)})
117+
# ==> tf.float64
118+
```
119+
120+
Usage with nested dtype:
121+
122+
```python
123+
# Define `x` and `y` as JointDistributions with the same nested dtype.
124+
x = tfd.JointDistributionNamed(
125+
{'a': tfd.Uniform(np.float64(0.), 1.),
126+
'b': tfd.JointDistributionSequential(
127+
[tfd.Normal(0., 2.), tfd.Bernoulli(0.4)])})
128+
x.dtype # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
129+
130+
y = tfd.JointDistributionNamed(
131+
{'a': tfd.LogitNormal(np.float64(0.), 1.),
132+
'b': tfd.JointDistributionSequential(
133+
[tfd.Normal(-1., 1.), tfd.Bernoulli(0.6)])})
134+
y.dtype # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
135+
136+
# Pack x and y into an arbitrary nested structure and pass it to
137+
# `common_dtype`.
138+
args0 = [x, y]
139+
common_dtype(args0) # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
140+
141+
# The nested structure of the argument to `common_dtype` is flattened and
142+
# ignored; only the nested structures of the dtypes are relevant.
143+
args1 = {'x': x, 'yz': {'y': y, 'z': None}}
144+
common_dtype(args1) # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
145+
```
146+
"""
147+
148+
def _unify_dtype(current, new):
149+
if current is not None and new is not None and current != new:
150+
if SKIP_DTYPE_CHECKS:
151+
return (np.ones([2], dtype) + np.ones([2], dt)).dtype
152+
raise TypeError
153+
return new if current is None else current
154+
88155
dtype = None
89156
flattened_args = tf.nest.flatten(args)
90157
seen = [_NOT_YET_SEEN] * len(flattened_args)
91158
for i, a in enumerate(flattened_args):
92159
if hasattr(a, 'dtype') and a.dtype:
93-
dt = as_numpy_dtype(a.dtype)
160+
dt = tf.nest.map_structure(
161+
lambda d: d if d is None else as_numpy_dtype(d), a.dtype)
94162
seen[i] = dt
95163
else:
96164
seen[i] = None
97165
continue
98166
if dtype is None:
99167
dtype = dt
100-
elif dtype != dt:
101-
if SKIP_DTYPE_CHECKS:
102-
dtype = (np.ones([2], dtype) + np.ones([2], dt)).dtype
103-
else:
104-
raise TypeError(
105-
'Found incompatible dtypes, {} and {}. Seen so far: {}'.format(
106-
dtype, dt, tf.nest.pack_sequence_as(args, seen)))
107-
return base_dtype(dtype_hint) if dtype is None else base_dtype(dtype)
168+
try:
169+
dtype = tf.nest.map_structure(_unify_dtype, dtype, dt)
170+
except TypeError:
171+
raise TypeError(
172+
'Found incompatible dtypes, {} and {}. Seen so far: {}'.format(
173+
dtype, dt, tf.nest.pack_sequence_as(args, seen))) from None
174+
if dtype_hint is None:
175+
return tf.nest.map_structure(base_dtype, dtype)
176+
if dtype is None:
177+
return tf.nest.map_structure(base_dtype, dtype_hint)
178+
if tf.nest.is_nested(dtype) and not tf.nest.is_nested(dtype_hint):
179+
dtype_hint = tf.nest.map_structure(lambda _: dtype_hint, dtype)
180+
return tf.nest.map_structure(
181+
lambda dt, h: base_dtype(h if dt is None else dt), dtype, dtype_hint)
108182

109183

110184
def convert_to_dtype(tensor_or_dtype, dtype=None, dtype_hint=None):

tensorflow_probability/python/internal/dtype_util_test.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ============================================================================
1515
"""Tests for dtype_util."""
1616

17-
import collections
17+
import dataclasses
1818

1919
# Dependency imports
2020
from absl.testing import parameterized
@@ -51,16 +51,50 @@ def testCommonDtypeAcceptsNone(self):
5151
tf.float16, dtype_util.common_dtype(
5252
[x, None], dtype_hint=tf.float32))
5353

54-
fake_tensor = collections.namedtuple('fake_tensor', ['dtype'])
54+
fake_tensor = dataclasses.make_dataclass('fake_tensor', ['dtype'])
5555
self.assertEqual(
5656
tf.float16, dtype_util.common_dtype(
5757
[fake_tensor(dtype=None), None, x], dtype_hint=tf.float32))
58+
self.assertEqual(
59+
tf.float32, dtype_util.common_dtype(
60+
[fake_tensor(dtype=tf.float32), None]))
5861

5962
def testCommonDtypeFromLinop(self):
6063
x = tf.linalg.LinearOperatorDiag(tf.ones(3, tf.float16))
6164
self.assertEqual(
6265
tf.float16, dtype_util.common_dtype([x], dtype_hint=tf.float32))
6366

67+
def testCommonStructuredDtype(self):
68+
structured_dtype_obj = dataclasses.make_dataclass(
69+
'structured_dtype_obj', ['dtype'])
70+
x = structured_dtype_obj({'a': tf.float32, 'b': (None, None)})
71+
y = structured_dtype_obj({'a': None, 'b': (None, tf.float64)})
72+
z = structured_dtype_obj({'a': None, 'b': (None, None)})
73+
w = structured_dtype_obj(None)
74+
75+
# Check that structured dtypes unify correctly.
76+
self.assertAllEqualNested(
77+
dtype_util.common_dtype([w, x, y, z]),
78+
{'a': tf.float32, 'b': (None, tf.float64)})
79+
80+
# Check that dict `args` works and that `dtype_hint` works.
81+
dtype_hint = {'a': tf.int32, 'b': (tf.int32, None)}
82+
self.assertAllEqualNested(
83+
dtype_util.common_dtype(
84+
{'x': x, 'y': y, 'z': z}, dtype_hint=dtype_hint),
85+
{'a': tf.float32, 'b': (tf.int32, tf.float64)})
86+
self.assertAllEqualNested(
87+
dtype_util.common_dtype([w], dtype_hint=dtype_hint),
88+
dtype_hint)
89+
90+
# Check that non-nested dtype_hint broadcasts.
91+
self.assertAllEqualNested(
92+
dtype_util.common_dtype([y, z], dtype_hint=tf.int32),
93+
{'a': tf.int32, 'b': (tf.int32, tf.float64)})
94+
95+
with self.assertRaisesRegex(TypeError, 'Found incompatible dtypes'):
96+
dtype_util.common_dtype([x, structured_dtype_obj(dtype_hint)])
97+
6498
@parameterized.named_parameters(
6599
dict(testcase_name='Float32', dtype=tf.float32,
66100
expected_minval=np.float32(-3.4028235e+38)),

0 commit comments

Comments
 (0)