Skip to content

Commit 988f023

Browse files
jburnimtensorflower-gardener
authored andcommitted
Add Python 3.12 support and add better error msg for missing tf_keras
With TF 2.16, users of TFP-on-TF must install `tf-keras` in addition to `tensorflow` -- so this change adds a custom error message if `tf-keras` (or `tf-keras-nightly`) is not installed. PiperOrigin-RevId: 614280621
1 parent 9a14b9b commit 988f023

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

.github/workflows/continuous-integration.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
runs-on: ubuntu-latest
2121
strategy:
2222
matrix:
23-
python-version: [3.9]
23+
python-version: [3.12]
2424
steps:
2525
- name: Checkout
2626
uses: actions/checkout@v1
@@ -38,7 +38,7 @@ jobs:
3838
runs-on: ubuntu-latest
3939
strategy:
4040
matrix:
41-
python-version: [3.9]
41+
python-version: [3.12]
4242
shard: [0, 1, 2, 3, 4]
4343
env:
4444
SHARD: ${{ matrix.shard }}

setup.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@
4848
else:
4949
TFDS_PACKAGE = 'tfds-nightly'
5050

51+
if release:
52+
TF_PACKAGE = 'tensorflow >= 2.15'
53+
KERAS_PACKAGE = 'tf-keras >= 2.15'
54+
else:
55+
TF_PACKAGE = 'tf-nightly'
56+
KERAS_PACKAGE = 'tf-keras-nightly'
57+
5158

5259
class BinaryDistribution(Distribution):
5360
"""This class is needed in order to create OS specific wheels."""
@@ -91,6 +98,7 @@ def has_ext_modules(self):
9198
'Programming Language :: Python :: 3.9',
9299
'Programming Language :: Python :: 3.10',
93100
'Programming Language :: Python :: 3.11',
101+
'Programming Language :: Python :: 3.12',
94102
'Topic :: Scientific/Engineering',
95103
'Topic :: Scientific/Engineering :: Mathematics',
96104
'Topic :: Scientific/Engineering :: Artificial Intelligence',
@@ -101,6 +109,7 @@ def has_ext_modules(self):
101109
keywords='tensorflow probability statistics bayesian machine learning',
102110
extras_require={ # e.g. `pip install tfp-nightly[jax]`
103111
'jax': ['jax', 'jaxlib'],
112+
'tf': [TF_PACKAGE, KERAS_PACKAGE],
104113
'tfds': [TFDS_PACKAGE],
105114
}
106115
)

tensorflow_probability/python/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _validate_tf_environment(package):
3535
inadequate.
3636
"""
3737
try:
38-
import tensorflow.compat.v1 as tf
38+
import tensorflow as tf
3939
except (ImportError, ModuleNotFoundError):
4040
# Print more informative error message, then reraise.
4141
print('\n\nFailed to import TensorFlow. Please note that TensorFlow is not '
@@ -51,7 +51,7 @@ def _validate_tf_environment(package):
5151
#
5252
# Update this whenever we need to depend on a newer TensorFlow release.
5353
#
54-
required_tensorflow_version = '2.14'
54+
required_tensorflow_version = '2.15'
5555
# required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport
5656

5757
if (distutils.version.LooseVersion(tf.__version__) <
@@ -74,6 +74,18 @@ def _validate_tf_environment(package):
7474
'For more detail, see https://github.com/tensorflow/community/pull/287.'
7575
)
7676

77+
if required_tensorflow_version[0] == '2':
78+
try:
79+
import tf_keras # pylint: disable=unused-import
80+
except (ImportError, ModuleNotFoundError):
81+
# Print more informative error message, then reraise.
82+
print('\n\nFailed to import TF-Keras. Please note that TF-Keras is not '
83+
'installed by default when you install TensorFlow Probability. '
84+
'This is so that JAX-only users do not have to install TensorFlow '
85+
'or TF-Keras. To use TensorFlow Probability with TensorFlow, '
86+
'please install the tf-keras or tf-keras-nightly package.\n\n')
87+
raise
88+
7789

7890
# Declare these explicitly to appease pytype, which otherwise misses them,
7991
# presumably due to lazy loading.

0 commit comments

Comments
 (0)