Skip to content

Commit 7d04d5d

Browse files
Merge pull request #2327 from mahsanghani:patch-2
PiperOrigin-RevId: 673915278
2 parents 7df1042 + 9555b66 commit 7d04d5d

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

site/en/tutorials/distribute/keras.ipynb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@
363363
"# Define the checkpoint directory to store the checkpoints.\n",
364364
"checkpoint_dir = './training_checkpoints'\n",
365365
"# Define the name of the checkpoint files.\n",
366-
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")"
366+
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch:04d}.weights.h5\")"
367367
]
368368
},
369369
{
@@ -396,7 +396,7 @@
396396
"# Define a callback for printing the learning rate at the end of each epoch.\n",
397397
"class PrintLR(tf.keras.callbacks.Callback):\n",
398398
" def on_epoch_end(self, epoch, logs=None):\n",
399-
" print('\\nLearning rate for epoch {} is {}'.format( epoch + 1, model.optimizer.lr.numpy()))"
399+
" print('\\nLearning rate for epoch {} is {}'.format(epoch + 1, model.optimizer.learning_rate.numpy()))"
400400
]
401401
},
402402
{
@@ -486,7 +486,10 @@
486486
},
487487
"outputs": [],
488488
"source": [
489-
"model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n",
489+
"import pathlib\n",
490+
"latest_checkpoint = sorted(pathlib.Path(checkpoint_dir).glob('*'))[-1]\n",
491+
"\n",
492+
"model.load_weights(latest_checkpoint)\n",
490493
"\n",
491494
"eval_loss, eval_acc = model.evaluate(eval_dataset)\n",
492495
"\n",

0 commit comments

Comments
 (0)