Skip to content

Commit ee8fbbe

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Allow jnp.bfloat16 arrays to be correctly recognized as floats.
PiperOrigin-RevId: 492022573
1 parent 4c03126 commit ee8fbbe

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

tensorflow_probability/python/internal/dtype_util.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@
4747
SKIP_DTYPE_CHECKS = False
4848

4949

50+
_issubdtype = np.issubdtype
51+
if JAX_MODE:
52+
# jnp.issubdtype handles custom JAX types like bfloat16
53+
import jax.numpy as jnp # pylint: disable=g-import-not-at-top
54+
_issubdtype = jnp.issubdtype
55+
56+
5057
def is_numpy_compatible(dtype):
5158
"""Returns if dtype has a corresponding NumPy dtype."""
5259
if JAX_MODE or NUMPY_MODE:
@@ -270,23 +277,23 @@ def is_complex(dtype):
270277
dtype = tf.as_dtype(dtype)
271278
if hasattr(dtype, 'is_complex'):
272279
return dtype.is_complex
273-
return np.issubdtype(np.dtype(dtype), np.complexfloating)
280+
return _issubdtype(np.dtype(dtype), np.complexfloating)
274281

275282

276283
def is_floating(dtype):
277284
"""Returns whether this is a (non-quantized, real) floating point type."""
278285
dtype = tf.as_dtype(dtype)
279286
if hasattr(dtype, 'is_floating'):
280287
return dtype.is_floating
281-
return np.issubdtype(np.dtype(dtype), np.floating)
288+
return _issubdtype(np.dtype(dtype), np.floating)
282289

283290

284291
def is_integer(dtype):
285292
"""Returns whether this is a (non-quantized) integer type."""
286293
dtype = tf.as_dtype(dtype)
287294
if hasattr(dtype, 'is_integer') and not callable(dtype.is_integer):
288295
return dtype.is_integer
289-
return np.issubdtype(np.dtype(dtype), np.integer)
296+
return _issubdtype(np.dtype(dtype), np.integer)
290297

291298

292299
def max(dtype): # pylint: disable=redefined-builtin

tensorflow_probability/python/internal/dtype_util_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow_probability.python.internal import test_util
2626

2727

28+
NUMPY_MODE = False
2829
JAX_MODE = False
2930

3031

@@ -281,6 +282,14 @@ def test_size(self):
281282
self.assertEqual(dtype_util.size(np.float32), 4)
282283
self.assertEqual(dtype_util.size(np.float64), 8)
283284

285+
@parameterized.named_parameters(
286+
('float32', tf.float32, True),
287+
('bfloat16', 'bfloat16', True),
288+
('not_int8', tf.int8, False))
289+
def test_is_floating(self, dtype, expected):
290+
if NUMPY_MODE and dtype == 'bfloat16':
291+
self.skipTest('No bfloat16 in numpy')
292+
self.assertEqual(dtype_util.is_floating(dtype), expected)
284293

285294
if __name__ == '__main__':
286295
test_util.main()

0 commit comments

Comments
 (0)