|
261 | 261 | "def load_image(image_path):\n",
|
262 | 262 | " img = tf.io.read_file(image_path)\n",
|
263 | 263 | " img = tf.io.decode_jpeg(img, channels=3)\n",
|
264 |
| - " img = tf.image.resize(img, (299, 299))\n", |
| 264 | + " img = tf.keras.layers.Resizing(299, 299)(img)\n", |
265 | 265 | " img = tf.keras.applications.inception_v3.preprocess_input(img)\n",
|
266 | 266 | " return img, image_path"
|
267 | 267 | ]
|
|
361 | 361 | "source": [
|
362 | 362 | "## Preprocess and tokenize the captions\n",
|
363 | 363 | "\n",
|
364 |
| - "* First, you'll tokenize the captions (for example, by splitting on spaces). This gives us a vocabulary of all of the unique words in the data (for example, \"surfing\", \"football\", and so on).\n", |
365 |
| - "* Next, you'll limit the vocabulary size to the top 5,000 words (to save memory). You'll replace all other words with the token \"UNK\" (unknown).\n", |
366 |
| - "* You then create word-to-index and index-to-word mappings.\n", |
367 |
| - "* Finally, you pad all sequences to be the same length as the longest one." |
368 |
| - ] |
369 |
| - }, |
370 |
| - { |
371 |
| - "cell_type": "code", |
372 |
| - "execution_count": null, |
373 |
| - "metadata": { |
374 |
| - "id": "HZfK8RhQRPFj" |
375 |
| - }, |
376 |
| - "outputs": [], |
377 |
| - "source": [ |
378 |
| - "# Find the maximum length of any caption in the dataset\n", |
379 |
| - "def calc_max_length(tensor):\n", |
380 |
| - " return max(len(t) for t in tensor)" |
| 364 | + "You will transform the text captions into integer sequences using the [TextVectorization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer, with the following steps:\n", |
| 365 | + "\n", |
| 366 | + "* Use [adapt](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization#adapt) to iterate over all captions, split the captions into words, and compute a vocabulary of the top 5,000 words (to save memory).\n", |
| 367 | + "* Tokenize all captions by mapping each word to its index in the vocabulary. All output sequences will be padded to length 50.\n", |
| 368 | + "* Create word-to-index and index-to-word mappings to display results." |
381 | 369 | ]
|
382 | 370 | },
|
383 | 371 | {
|
|
388 | 376 | },
|
389 | 377 | "outputs": [],
|
390 | 378 | "source": [
|
391 |
| - "# Choose the top 5000 words from the vocabulary\n", |
392 |
| - "top_k = 5000\n", |
393 |
| - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k,\n", |
394 |
| - " oov_token=\"<unk>\",\n", |
395 |
| - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~')\n", |
396 |
| - "tokenizer.fit_on_texts(train_captions)" |
397 |
| - ] |
398 |
| - }, |
399 |
| - { |
400 |
| - "cell_type": "code", |
401 |
| - "execution_count": null, |
402 |
| - "metadata": { |
403 |
| - "id": "8Q44tNQVRPFt" |
404 |
| - }, |
405 |
| - "outputs": [], |
406 |
| - "source": [ |
407 |
| - "tokenizer.word_index['<pad>'] = 0\n", |
408 |
| - "tokenizer.index_word[0] = '<pad>'" |
| 379 | + "caption_dataset = tf.data.Dataset.from_tensor_slices(train_captions)\n", |
| 380 | + "\n", |
| 381 | + "# We will override the default standardization of TextVectorization to preserve\n", |
| 382 | + "# \"<>\" characters, so that the <start> and <end> tokens are preserved.\n", |
| 383 | + "def standardize(inputs):\n", |
| 384 | + " inputs = tf.strings.lower(inputs)\n", |
| 385 | + " return tf.strings.regex_replace(inputs,\n", |
| 386 | + " r\"!\\\"#$%&\\(\\)\\*\\+.,-/:;=?@\\[\\\\\\]^_`{|}~\", \"\")\n", |
| 387 | + "\n", |
| 388 | + "# Max word count for a caption.\n", |
| 389 | + "max_length = 50\n", |
| 390 | + "# Use the top 5000 words for a vocabulary.\n", |
| 391 | + "vocabulary_size = 5000\n", |
| 392 | + "tokenizer = tf.keras.layers.TextVectorization(\n", |
| 393 | + " max_tokens=vocabulary_size,\n", |
| 394 | + " standardize=standardize,\n", |
| 395 | + " output_sequence_length=max_length)\n", |
| 396 | + "# Learn the vocabulary from the caption data.\n", |
| 397 | + "tokenizer.adapt(caption_dataset)" |
409 | 398 | ]
|
410 | 399 | },
|
411 | 400 | {
|
412 | 401 | "cell_type": "code",
|
413 | 402 | "execution_count": null,
|
414 | 403 | "metadata": {
|
415 |
| - "id": "0fpJb5ojRPFv" |
| 404 | + "id": "Uaq07VVEu36f" |
416 | 405 | },
|
417 | 406 | "outputs": [],
|
418 | 407 | "source": [
|
419 | 408 | "# Create the tokenized vectors\n",
|
420 |
| - "train_seqs = tokenizer.texts_to_sequences(train_captions)" |
| 409 | + "cap_vector = caption_dataset.map(lambda x: tokenizer(x))" |
421 | 410 | ]
|
422 | 411 | },
|
423 | 412 | {
|
424 | 413 | "cell_type": "code",
|
425 | 414 | "execution_count": null,
|
426 | 415 | "metadata": {
|
427 |
| - "id": "AidglIZVRPF4" |
428 |
| - }, |
429 |
| - "outputs": [], |
430 |
| - "source": [ |
431 |
| - "# Pad each vector to the max_length of the captions\n", |
432 |
| - "# If you do not provide a max_length value, pad_sequences calculates it automatically\n", |
433 |
| - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" |
434 |
| - ] |
435 |
| - }, |
436 |
| - { |
437 |
| - "cell_type": "code", |
438 |
| - "execution_count": null, |
439 |
| - "metadata": { |
440 |
| - "id": "gL0wkttkRPGA" |
| 416 | + "id": "8Q44tNQVRPFt" |
441 | 417 | },
|
442 | 418 | "outputs": [],
|
443 | 419 | "source": [
|
444 |
| - "# Calculates the max_length, which is used to store the attention weights\n", |
445 |
| - "max_length = calc_max_length(train_seqs)" |
| 420 | + "# Create mappings for words to indices and indices to words.\n", |
| 421 | + "word_to_index = tf.keras.layers.StringLookup(\n", |
| 422 | + " mask_token=\"\",\n", |
| 423 | + " vocabulary=tokenizer.get_vocabulary())\n", |
| 424 | + "index_to_word = tf.keras.layers.StringLookup(\n", |
| 425 | + " mask_token=\"\",\n", |
| 426 | + " vocabulary=tokenizer.get_vocabulary(),\n", |
| 427 | + " invert=True)" |
446 | 428 | ]
|
447 | 429 | },
|
448 | 430 | {
|
|
531 | 513 | "BUFFER_SIZE = 1000\n",
|
532 | 514 | "embedding_dim = 256\n",
|
533 | 515 | "units = 512\n",
|
534 |
| - "vocab_size = top_k + 1\n", |
535 | 516 | "num_steps = len(img_name_train) // BATCH_SIZE\n",
|
536 | 517 | "# Shape of the vector extracted from InceptionV3 is (64, 2048)\n",
|
537 | 518 | "# These two variables represent that vector shape\n",
|
|
565 | 546 | "\n",
|
566 | 547 | "# Use map to load the numpy files in parallel\n",
|
567 | 548 | "dataset = dataset.map(lambda item1, item2: tf.numpy_function(\n",
|
568 |
| - " map_func, [item1, item2], [tf.float32, tf.int32]),\n", |
| 549 | + " map_func, [item1, item2], [tf.float32, tf.int64]),\n", |
569 | 550 | " num_parallel_calls=tf.data.AUTOTUNE)\n",
|
570 | 551 | "\n",
|
571 | 552 | "# Shuffle and batch\n",
|
|
713 | 694 | "outputs": [],
|
714 | 695 | "source": [
|
715 | 696 | "encoder = CNN_Encoder(embedding_dim)\n",
|
716 |
| - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" |
| 697 | + "decoder = RNN_Decoder(embedding_dim, units, tokenizer.vocabulary_size())" |
717 | 698 | ]
|
718 | 699 | },
|
719 | 700 | {
|
|
824 | 805 | " # because the captions are not related from image to image\n",
|
825 | 806 | " hidden = decoder.reset_state(batch_size=target.shape[0])\n",
|
826 | 807 | "\n",
|
827 |
| - " dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * target.shape[0], 1)\n", |
| 808 | + " dec_input = tf.expand_dims([word_to_index('<start>')] * target.shape[0], 1)\n", |
828 | 809 | "\n",
|
829 | 810 | " with tf.GradientTape() as tape:\n",
|
830 | 811 | " features = encoder(img_tensor)\n",
|
|
929 | 910 | "\n",
|
930 | 911 | " features = encoder(img_tensor_val)\n",
|
931 | 912 | "\n",
|
932 |
| - " dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)\n", |
| 913 | + " dec_input = tf.expand_dims([word_to_index('<start>')], 0)\n", |
933 | 914 | " result = []\n",
|
934 | 915 | "\n",
|
935 | 916 | " for i in range(max_length):\n",
|
|
940 | 921 | " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
|
941 | 922 | "\n",
|
942 | 923 | " predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()\n",
|
943 |
| - " result.append(tokenizer.index_word[predicted_id])\n", |
| 924 | + " predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())\n", |
| 925 | + " result.append(predicted_word)\n", |
944 | 926 | "\n",
|
945 |
| - " if tokenizer.index_word[predicted_id] == '<end>':\n", |
| 927 | + " if predicted_word == '<end>':\n", |
946 | 928 | " return result, attention_plot\n",
|
947 | 929 | "\n",
|
948 | 930 | " dec_input = tf.expand_dims([predicted_id], 0)\n",
|
|
967 | 949 | " len_result = len(result)\n",
|
968 | 950 | " for i in range(len_result):\n",
|
969 | 951 | " temp_att = np.resize(attention_plot[i], (8, 8))\n",
|
970 |
| - " grid_size = max(np.ceil(len_result/2), 2)\n", |
| 952 | + " grid_size = max(int(np.ceil(len_result/2)), 2)\n", |
971 | 953 | " ax = fig.add_subplot(grid_size, grid_size, i+1)\n",
|
972 | 954 | " ax.set_title(result[i])\n",
|
973 | 955 | " img = ax.imshow(temp_image)\n",
|
|
988 | 970 | "# captions on the validation set\n",
|
989 | 971 | "rid = np.random.randint(0, len(img_name_val))\n",
|
990 | 972 | "image = img_name_val[rid]\n",
|
991 |
| - "real_caption = ' '.join([tokenizer.index_word[i]\n", |
992 |
| - " for i in cap_val[rid] if i not in [0]])\n", |
| 973 | + "real_caption = ' '.join([tf.compat.as_text(index_to_word(i).numpy())\n", |
| 974 | + " for i in cap_val[rid] if i not in [0]])\n", |
993 | 975 | "result, attention_plot = evaluate(image)\n",
|
994 | 976 | "\n",
|
995 | 977 | "print('Real Caption:', real_caption)\n",
|
|
1044 | 1026 | "colab": {
|
1045 | 1027 | "collapsed_sections": [],
|
1046 | 1028 | "name": "image_captioning.ipynb",
|
| 1029 | + "provenance": [], |
1047 | 1030 | "toc_visible": true
|
1048 | 1031 | },
|
1049 | 1032 | "kernelspec": {
|
|
0 commit comments