
Commit a28c3d5

jburnim authored and tensorflower-gardener committed
Use jit_compile= instead of experimental_compile= in tf.function.
PiperOrigin-RevId: 380008718
1 parent 6b36b64 commit a28c3d5

8 files changed (+12, -12 lines)
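For readers making the same update in their own code, a minimal sketch of the rename this commit applies (the function `scale` is a made-up example; `jit_compile=` is the stable spelling of the former `experimental_compile=` argument of `tf.function`):

import tensorflow as tf

# Before this change the notebooks used the deprecated spelling:
#   @tf.function(autograph=False, experimental_compile=True)

# jit_compile=True asks TensorFlow to compile the traced function with XLA;
# only the argument name changed, not the behavior.
@tf.function(autograph=False, jit_compile=True)
def scale(x):
  return 2.0 * x

print(scale(tf.constant([1.0, 2.0])))  # tf.Tensor([2. 4.], shape=(2,), dtype=float32)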

tensorflow_probability/examples/jupyter_notebooks/Bayesian_Switchpoint_Analysis.ipynb

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@
 "num_results = 10000\n",
 "num_burnin_steps = 3000\n",
 "\n",
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def make_chain(target_log_prob_fn):\n",
 " kernel = tfp.mcmc.TransformedTransitionKernel(\n",
 " inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(\n",

tensorflow_probability/examples/jupyter_notebooks/Eight_Schools.ipynb

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@
 "\n",
 "# Improve performance by tracing the sampler using `tf.function`\n",
 "# and compiling it using XLA.\n",
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def do_sampling():\n",
 " return tfp.mcmc.sample_chain(\n",
 " num_results=num_results,\n",

tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Latent_Variable_Model.ipynb

Lines changed: 1 addition & 1 deletion
@@ -347,7 +347,7 @@
 "\n",
 "optimizer = tf.optimizers.Adam(learning_rate=1.0)\n",
 "\n",
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def train_model():\n",
 " with tf.GradientTape() as tape:\n",
 " loss_value = loss_fn()\n",

tensorflow_probability/examples/jupyter_notebooks/Gaussian_Process_Regression_In_TFP.ipynb

Lines changed: 2 additions & 2 deletions
@@ -542,7 +542,7 @@
 "outputs": [],
 "source": [
 "# Use `tf.function` to trace the loss for more efficient evaluation.\n",
-"@tf.function(autograph=False, experimental_compile=False)\n",
+"@tf.function(autograph=False, jit_compile=False)\n",
 "def target_log_prob(amplitude, length_scale, observation_noise_variance):\n",
 " return gp_joint_model.log_prob({\n",
 " 'amplitude': amplitude,\n",

@@ -790,7 +790,7 @@
 ],
 "source": [
 "# Speed up sampling by tracing with `tf.function`.\n",
-"@tf.function(autograph=False, experimental_compile=False)\n",
+"@tf.function(autograph=False, jit_compile=False)\n",
 "def do_sampling():\n",
 " return tfp.mcmc.sample_chain(\n",
 " kernel=adaptive_sampler,\n",

tensorflow_probability/examples/jupyter_notebooks/Linear_Mixed_Effects_Models.ipynb

Lines changed: 3 additions & 3 deletions
@@ -721,7 +721,7 @@
 " num_leapfrog_steps=3)\n",
 "kernel_results = hmc.bootstrap_results(current_state)\n",
 "\n",
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def one_e_step(current_state, kernel_results):\n",
 " next_state, next_kernel_results = hmc.one_step(\n",
 " current_state=current_state,\n",

@@ -731,7 +731,7 @@
 "optimizer = tf.optimizers.Adam(learning_rate=.01)\n",
 "\n",
 "# Set up M-step (gradient descent).\n",
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def one_m_step(current_state):\n",
 " with tf.GradientTape() as tape:\n",
 " loss = -target_log_prob_fn(*current_state)\n",

@@ -843,7 +843,7 @@
 },
 "outputs": [],
 "source": [
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def run_k_e_steps(k, current_state, kernel_results):\n",
 " _, next_state, next_kernel_results = tf.while_loop(\n",
 " cond=lambda i, state, pkr: i \u003c k,\n",

tensorflow_probability/examples/jupyter_notebooks/STS_approximate_inference_for_models_with_non_Gaussian_observations.ipynb

Lines changed: 1 addition & 1 deletion
@@ -359,7 +359,7 @@
 ],
 "source": [
 "# Speed up sampling by tracing with `tf.function`.\n",
-"@tf.function(autograph=False, experimental_compile=True)\n",
+"@tf.function(autograph=False, jit_compile=True)\n",
 "def do_sampling():\n",
 " return tfp.mcmc.sample_chain(\n",
 " kernel=adaptive_sampler,\n",

tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand.ipynb

Lines changed: 2 additions & 2 deletions
@@ -535,7 +535,7 @@
 "\n",
 "optimizer = tf.optimizers.Adam(learning_rate=.1)\n",
 "# Using fit_surrogate_posterior to build and optimize the variational loss function.\n",
-"@tf.function(experimental_compile=True)\n",
+"@tf.function(jit_compile=True)\n",
 "def train():\n",
 " elbo_loss_curve = tfp.vi.fit_surrogate_posterior(\n",
 " target_log_prob_fn=co2_model.joint_log_prob(\n",

@@ -980,7 +980,7 @@
 "\n",
 "optimizer = tf.optimizers.Adam(learning_rate=.1)\n",
 "# Using fit_surrogate_posterior to build and optimize the variational loss function.\n",
-"@tf.function(experimental_compile=True)\n",
+"@tf.function(jit_compile=True)\n",
 "def train():\n",
 " elbo_loss_curve = tfp.vi.fit_surrogate_posterior(\n",
 " target_log_prob_fn=demand_model.joint_log_prob(\n",

tensorflow_probability/python/experimental/sts_gibbs/benchmarks_test.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def test_benchmark_sampling_with_xla(self):
     samples = tf.function(
         gibbs_sampler.fit_with_gibbs_sampling,
         autograph=False,
-        experimental_compile=True)(
+        jit_compile=True)(
             model,
             tfp.sts.MaskedTimeSeries(observed_time_series[..., tf.newaxis],
                                      is_missing),
