Skip to content

Commit 6cc612f

Browse files
brianwa84tensorflower-gardener
authored andcommitted
Avoid packing dtypes into ndarrays for comparison (which is what assertAllEqual will do).
PiperOrigin-RevId: 567640584
1 parent dff8111 commit 6cc612f

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

tensorflow_probability/python/internal/dtype_util_test.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,37 +74,44 @@ def testCommonStructuredDtype(self):
7474
w = structured_dtype_obj(None)
7575

7676
# Check that structured dtypes unify correctly.
77-
self.assertAllEqualNested(
77+
self.assertAllAssertsNested(
78+
self.assertEqual,
7879
dtype_util.common_dtype([w, x, y, z]),
7980
{'a': tf.float32, 'b': (None, tf.float64)})
8081

8182
# Check that dict `args` works and that `dtype_hint` works.
8283
dtype_hint = {'a': tf.int32, 'b': (tf.int32, None)}
83-
self.assertAllEqualNested(
84+
self.assertAllAssertsNested(
85+
self.assertEqual,
8486
dtype_util.common_dtype(
8587
{'x': x, 'y': y, 'z': z}, dtype_hint=dtype_hint),
8688
{'a': tf.float32, 'b': (tf.int32, tf.float64)})
87-
self.assertAllEqualNested(
89+
self.assertAllAssertsNested(
90+
self.assertEqual,
8891
dtype_util.common_dtype([w], dtype_hint=dtype_hint),
8992
dtype_hint)
9093

9194
# Check that non-nested dtype_hint broadcasts.
92-
self.assertAllEqualNested(
95+
self.assertAllAssertsNested(
96+
self.assertEqual,
9397
dtype_util.common_dtype([y, z], dtype_hint=tf.int32),
9498
{'a': tf.int32, 'b': (tf.int32, tf.float64)})
9599

96100
# Check that structured `dtype_hint` behaves as expected.
97101
s = {'a': [tf.ones([3], tf.float32), 4.],
98102
'b': (np.float64(2.), None)}
99-
self.assertAllEqualNested(
103+
self.assertAllAssertsNested(
104+
self.assertEqual,
100105
dtype_util.common_dtype([x, s], dtype_hint=z.dtype),
101106
{'a': tf.float32, 'b': (tf.float64, None)})
102-
self.assertAllEqualNested(
107+
self.assertAllAssertsNested(
108+
self.assertEqual,
103109
dtype_util.common_dtype([y, s], dtype_hint=z.dtype),
104110
{'a': tf.float32, 'b': (tf.float64, tf.float64)})
105111

106112
t = {'a': [[1., 2., 3.]], 'b': {'c': np.float64(1.), 'd': np.float64(2.)}}
107-
self.assertAllEqualNested(
113+
self.assertAllAssertsNested(
114+
self.assertEqual,
108115
dtype_util.common_dtype(
109116
[w, t],
110117
dtype_hint={'a': tf.float32, 'b': tf.float32}),

0 commit comments

Comments
 (0)