Skip to content

Commit dbc999f

Browse files
committed
update to checkpoint callback options (save_frequency)
introduced the number of batches ('n_batches') option for the save frequency instead of 'batch_size'. Using 'batch_size' works in this tutorial because the length of the training data is 1000 which coincidentally results in a rounded value of ~32 when it is divided by the 'batch_size'. In cases when the number of samples is not 1000, this will result in the model saving at different epoch frequencies other than after every 5 epochs. the definition of 'save_freq' (https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint#args) clearly refers to the number of batches ('n_batches' in this context) and not the number of samples in a batch ('batch_size').
1 parent f696366 commit dbc999f

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

site/en/tutorials/keras/save_and_load.ipynb

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,17 @@
385385
"\n",
386386
"batch_size = 32\n",
387387
"\n",
388+
"# calculate the number of batches per epoch\n",
389+
"import math\n",
390+
"n_batches = len(train_images) / batch_size\n",
391+
"n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n",
392+
"\n",
388393
"# Create a callback that saves the model's weights every 5 epochs\n",
389394
"cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
390395
" filepath=checkpoint_path, \n",
391396
" verbose=1, \n",
392397
" save_weights_only=True,\n",
393-
" save_freq=5*batch_size)\n",
398+
" save_freq=5*n_batches)\n",
394399
"\n",
395400
"# Create a new model instance\n",
396401
"model = create_model()\n",

0 commit comments

Comments
 (0)