|
477 | 477 | "num_variational_steps = 200 # @param { isTemplate: true}\n", |
478 | 478 | "num_variational_steps = int(num_variational_steps)\n", |
479 | 479 | "\n", |
480 | | - "seed = tfp.random.sanitize_seed(jax.random.PRNGKey(42), salt='fit_stateless') \n", |
481 | | - "init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3) \n", |
482 | | - "initial_parameters = init_fn(init_seed) \n", |
483 | | - "jd = co2_model.joint_distribution(co2_by_month_training_data) \n", |
| 480 | + "seed = jax.random.PRNGKey(42)\n", |
| 481 | + "init_seed, fit_seed, sample_seed = jax.random.split(seed, 3)\n", |
| 482 | + "initial_parameters = init_fn(init_seed)\n", |
| 483 | + "jd = co2_model.joint_distribution(co2_by_month_training_data)\n", |
484 | 484 | "\n", |
485 | 485 | "# Build and optimize the variational loss function.\n", |
486 | | - "optimized_parameters, elbo_loss_curve = tfp.vi.fit_surrogate_posterior_stateless( \n", |
487 | | - " target_log_prob_fn=jd.log_prob, \n", |
488 | | - " initial_parameters=initial_parameters, \n", |
489 | | - " build_surrogate_posterior_fn=build_surrogate_fn, \n", |
| 486 | + "optimized_parameters, elbo_loss_curve = tfp.vi.fit_surrogate_posterior_stateless(\n", |
| 487 | + " target_log_prob_fn=jd.log_prob,\n", |
| 488 | + " initial_parameters=initial_parameters,\n", |
| 489 | + " build_surrogate_posterior_fn=build_surrogate_fn,\n", |
490 | 490 | " optimizer=optax.adam(0.1), \n", |
491 | 491 | " num_steps=num_variational_steps,\n", |
492 | | - " seed=fit_seed) \n", |
| 492 | + " seed=fit_seed)\n", |
493 | 493 | "plt.plot(elbo_loss_curve)\n", |
494 | 494 | "plt.show()\n", |
495 | 495 | "\n", |
|
0 commit comments