Skip to content

Commit f934d00

Browse files
Update image captioning guide to use TextVectorization
PiperOrigin-RevId: 416179547
1 parent 1283fc9 commit f934d00

File tree

1 file changed

+47
-64
lines changed

1 file changed

+47
-64
lines changed

site/en/tutorials/text/image_captioning.ipynb

Lines changed: 47 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@
261261
"def load_image(image_path):\n",
262262
" img = tf.io.read_file(image_path)\n",
263263
" img = tf.io.decode_jpeg(img, channels=3)\n",
264-
" img = tf.image.resize(img, (299, 299))\n",
264+
" img = tf.keras.layers.Resizing(299, 299)(img)\n",
265265
" img = tf.keras.applications.inception_v3.preprocess_input(img)\n",
266266
" return img, image_path"
267267
]
@@ -361,23 +361,11 @@
361361
"source": [
362362
"## Preprocess and tokenize the captions\n",
363363
"\n",
364-
"* First, you'll tokenize the captions (for example, by splitting on spaces). This gives us a vocabulary of all of the unique words in the data (for example, \"surfing\", \"football\", and so on).\n",
365-
"* Next, you'll limit the vocabulary size to the top 5,000 words (to save memory). You'll replace all other words with the token \"UNK\" (unknown).\n",
366-
"* You then create word-to-index and index-to-word mappings.\n",
367-
"* Finally, you pad all sequences to be the same length as the longest one."
368-
]
369-
},
370-
{
371-
"cell_type": "code",
372-
"execution_count": null,
373-
"metadata": {
374-
"id": "HZfK8RhQRPFj"
375-
},
376-
"outputs": [],
377-
"source": [
378-
"# Find the maximum length of any caption in the dataset\n",
379-
"def calc_max_length(tensor):\n",
380-
" return max(len(t) for t in tensor)"
364+
"You will transform the text captions into integer sequences using the [TextVectorization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer, with the following steps:\n",
365+
"\n",
366+
"* Use [adapt](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization#adapt) to iterate over all captions, split the captions into words, and compute a vocabulary of the top 5,000 words (to save memory).\n",
367+
"* Tokenize all captions by mapping each word to it's index in the vocabulary. All output sequences will be padded to length 50.\n",
368+
"* Create word-to-index and index-to-word mappings to display results."
381369
]
382370
},
383371
{
@@ -388,61 +376,55 @@
388376
},
389377
"outputs": [],
390378
"source": [
391-
"# Choose the top 5000 words from the vocabulary\n",
392-
"top_k = 5000\n",
393-
"tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k,\n",
394-
" oov_token=\"<unk>\",\n",
395-
" filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~')\n",
396-
"tokenizer.fit_on_texts(train_captions)"
397-
]
398-
},
399-
{
400-
"cell_type": "code",
401-
"execution_count": null,
402-
"metadata": {
403-
"id": "8Q44tNQVRPFt"
404-
},
405-
"outputs": [],
406-
"source": [
407-
"tokenizer.word_index['<pad>'] = 0\n",
408-
"tokenizer.index_word[0] = '<pad>'"
379+
"caption_dataset = tf.data.Dataset.from_tensor_slices(train_captions)\n",
380+
"\n",
381+
"# We will override the default standardization of TextVectorization to preserve\n",
382+
"# \"<>\" characters, so we preserve the tokens for the <start> and <end>.\n",
383+
"def standardize(inputs):\n",
384+
" inputs = tf.strings.lower(inputs)\n",
385+
" return tf.strings.regex_replace(inputs,\n",
386+
" r\"!\\\"#$%&\\(\\)\\*\\+.,-/:;=?@\\[\\\\\\]^_`{|}~\", \"\")\n",
387+
"\n",
388+
"# Max word count for a caption.\n",
389+
"max_length = 50\n",
390+
"# Use the top 5000 words for a vocabulary.\n",
391+
"vocabulary_size = 5000\n",
392+
"tokenizer = tf.keras.layers.TextVectorization(\n",
393+
" max_tokens=vocabulary_size,\n",
394+
" standardize=standardize,\n",
395+
" output_sequence_length=max_length)\n",
396+
"# Learn the vocabulary from the caption data.\n",
397+
"tokenizer.adapt(caption_dataset)"
409398
]
410399
},
411400
{
412401
"cell_type": "code",
413402
"execution_count": null,
414403
"metadata": {
415-
"id": "0fpJb5ojRPFv"
404+
"id": "Uaq07VVEu36f"
416405
},
417406
"outputs": [],
418407
"source": [
419408
"# Create the tokenized vectors\n",
420-
"train_seqs = tokenizer.texts_to_sequences(train_captions)"
409+
"cap_vector = caption_dataset.map(lambda x: tokenizer(x))"
421410
]
422411
},
423412
{
424413
"cell_type": "code",
425414
"execution_count": null,
426415
"metadata": {
427-
"id": "AidglIZVRPF4"
428-
},
429-
"outputs": [],
430-
"source": [
431-
"# Pad each vector to the max_length of the captions\n",
432-
"# If you do not provide a max_length value, pad_sequences calculates it automatically\n",
433-
"cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')"
434-
]
435-
},
436-
{
437-
"cell_type": "code",
438-
"execution_count": null,
439-
"metadata": {
440-
"id": "gL0wkttkRPGA"
416+
"id": "8Q44tNQVRPFt"
441417
},
442418
"outputs": [],
443419
"source": [
444-
"# Calculates the max_length, which is used to store the attention weights\n",
445-
"max_length = calc_max_length(train_seqs)"
420+
"# Create mappings for words to indices and indicies to words.\n",
421+
"word_to_index = tf.keras.layers.StringLookup(\n",
422+
" mask_token=\"\",\n",
423+
" vocabulary=tokenizer.get_vocabulary())\n",
424+
"index_to_word = tf.keras.layers.StringLookup(\n",
425+
" mask_token=\"\",\n",
426+
" vocabulary=tokenizer.get_vocabulary(),\n",
427+
" invert=True)"
446428
]
447429
},
448430
{
@@ -531,7 +513,6 @@
531513
"BUFFER_SIZE = 1000\n",
532514
"embedding_dim = 256\n",
533515
"units = 512\n",
534-
"vocab_size = top_k + 1\n",
535516
"num_steps = len(img_name_train) // BATCH_SIZE\n",
536517
"# Shape of the vector extracted from InceptionV3 is (64, 2048)\n",
537518
"# These two variables represent that vector shape\n",
@@ -565,7 +546,7 @@
565546
"\n",
566547
"# Use map to load the numpy files in parallel\n",
567548
"dataset = dataset.map(lambda item1, item2: tf.numpy_function(\n",
568-
" map_func, [item1, item2], [tf.float32, tf.int32]),\n",
549+
" map_func, [item1, item2], [tf.float32, tf.int64]),\n",
569550
" num_parallel_calls=tf.data.AUTOTUNE)\n",
570551
"\n",
571552
"# Shuffle and batch\n",
@@ -713,7 +694,7 @@
713694
"outputs": [],
714695
"source": [
715696
"encoder = CNN_Encoder(embedding_dim)\n",
716-
"decoder = RNN_Decoder(embedding_dim, units, vocab_size)"
697+
"decoder = RNN_Decoder(embedding_dim, units, tokenizer.vocabulary_size())"
717698
]
718699
},
719700
{
@@ -824,7 +805,7 @@
824805
" # because the captions are not related from image to image\n",
825806
" hidden = decoder.reset_state(batch_size=target.shape[0])\n",
826807
"\n",
827-
" dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * target.shape[0], 1)\n",
808+
" dec_input = tf.expand_dims([word_to_index('<start>')] * target.shape[0], 1)\n",
828809
"\n",
829810
" with tf.GradientTape() as tape:\n",
830811
" features = encoder(img_tensor)\n",
@@ -929,7 +910,7 @@
929910
"\n",
930911
" features = encoder(img_tensor_val)\n",
931912
"\n",
932-
" dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)\n",
913+
" dec_input = tf.expand_dims([word_to_index('<start>')], 0)\n",
933914
" result = []\n",
934915
"\n",
935916
" for i in range(max_length):\n",
@@ -940,9 +921,10 @@
940921
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
941922
"\n",
942923
" predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()\n",
943-
" result.append(tokenizer.index_word[predicted_id])\n",
924+
" predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())\n",
925+
" result.append(predicted_word)\n",
944926
"\n",
945-
" if tokenizer.index_word[predicted_id] == '<end>':\n",
927+
" if predicted_word == '<end>':\n",
946928
" return result, attention_plot\n",
947929
"\n",
948930
" dec_input = tf.expand_dims([predicted_id], 0)\n",
@@ -967,7 +949,7 @@
967949
" len_result = len(result)\n",
968950
" for i in range(len_result):\n",
969951
" temp_att = np.resize(attention_plot[i], (8, 8))\n",
970-
" grid_size = max(np.ceil(len_result/2), 2)\n",
952+
" grid_size = max(int(np.ceil(len_result/2)), 2)\n",
971953
" ax = fig.add_subplot(grid_size, grid_size, i+1)\n",
972954
" ax.set_title(result[i])\n",
973955
" img = ax.imshow(temp_image)\n",
@@ -988,8 +970,8 @@
988970
"# captions on the validation set\n",
989971
"rid = np.random.randint(0, len(img_name_val))\n",
990972
"image = img_name_val[rid]\n",
991-
"real_caption = ' '.join([tokenizer.index_word[i]\n",
992-
" for i in cap_val[rid] if i not in [0]])\n",
973+
"real_caption = ' '.join([tf.compat.as_text(index_to_word(i).numpy())\n",
974+
" for i in cap_val[rid] if i not in [0]])\n",
993975
"result, attention_plot = evaluate(image)\n",
994976
"\n",
995977
"print('Real Caption:', real_caption)\n",
@@ -1044,6 +1026,7 @@
10441026
"colab": {
10451027
"collapsed_sections": [],
10461028
"name": "image_captioning.ipynb",
1029+
"provenance": [],
10471030
"toc_visible": true
10481031
},
10491032
"kernelspec": {

0 commit comments

Comments
 (0)