Skip to content

Commit 5c5a466

Browse files
Merge pull request #2302 from balaganesh102004:master
PiperOrigin-RevId: 667491664
2 parents 57d0938 + 426ab2a commit 5c5a466

File tree

1 file changed

+67
-15
lines changed

1 file changed

+67
-15
lines changed

site/en/guide/ragged_tensor.ipynb

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -674,14 +674,14 @@
674674
"source": [
675675
"### Keras\n",
676676
"\n",
677-
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. Ragged tensors may be passed as inputs to a Keras model by setting `ragged=True` on `tf.keras.Input` or `tf.keras.layers.InputLayer`. Ragged tensors may also be passed between Keras layers, and returned by Keras models. The following example shows a toy LSTM model that is trained using ragged tensors."
677+
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. It doesn't have ragged support. But it does support masked tensors. So the easiest way to use a ragged tensor in a Keras model is to convert the ragged tensor to a dense tensor, using `.to_tensor()` and then using Keras's builtin masking:"
678678
]
679679
},
680680
{
681681
"cell_type": "code",
682682
"execution_count": null,
683683
"metadata": {
684-
"id": "pHls7hQVJlk5"
684+
"id": "ucYf2sSzTvQo"
685685
},
686686
"outputs": [],
687687
"source": [
@@ -691,26 +691,77 @@
691691
" 'She turned me into a newt.',\n",
692692
" 'A newt?',\n",
693693
" 'Well, I got better.'])\n",
694-
"is_question = tf.constant([True, False, True, False])\n",
695-
"\n",
694+
"is_question = tf.constant([True, False, True, False])"
695+
]
696+
},
697+
{
698+
"cell_type": "code",
699+
"execution_count": null,
700+
"metadata": {
701+
"id": "MGYKmizJTw8B"
702+
},
703+
"outputs": [],
704+
"source": [
696705
"# Preprocess the input strings.\n",
697706
"hash_buckets = 1000\n",
698707
"words = tf.strings.split(sentences, ' ')\n",
699708
"hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)\n",
700-
"\n",
709+
"hashed_words.to_list()"
710+
]
711+
},
712+
{
713+
"cell_type": "code",
714+
"execution_count": null,
715+
"metadata": {
716+
"id": "7FTujwOlUT8J"
717+
},
718+
"outputs": [],
719+
"source": [
720+
"hashed_words.to_tensor()"
721+
]
722+
},
723+
{
724+
"cell_type": "code",
725+
"execution_count": null,
726+
"metadata": {
727+
"id": "vzWudaESUBOZ"
728+
},
729+
"outputs": [],
730+
"source": [
731+
"tf.keras.Input?"
732+
]
733+
},
734+
{
735+
"cell_type": "code",
736+
"execution_count": null,
737+
"metadata": {
738+
"id": "pHls7hQVJlk5"
739+
},
740+
"outputs": [],
741+
"source": [
701742
"# Build the Keras model.\n",
702743
"keras_model = tf.keras.Sequential([\n",
703-
" tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),\n",
704-
" tf.keras.layers.Embedding(hash_buckets, 16),\n",
705-
" tf.keras.layers.LSTM(32, use_bias=False),\n",
744+
" tf.keras.layers.Embedding(hash_buckets, 16, mask_zero=True),\n",
745+
" tf.keras.layers.LSTM(32, return_sequences=True, use_bias=False),\n",
746+
" tf.keras.layers.GlobalAveragePooling1D(),\n",
706747
" tf.keras.layers.Dense(32),\n",
707748
" tf.keras.layers.Activation(tf.nn.relu),\n",
708749
" tf.keras.layers.Dense(1)\n",
709750
"])\n",
710751
"\n",
711752
"keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')\n",
712-
"keras_model.fit(hashed_words, is_question, epochs=5)\n",
713-
"print(keras_model.predict(hashed_words))"
753+
"keras_model.fit(hashed_words.to_tensor(), is_question, epochs=5)\n"
754+
]
755+
},
756+
{
757+
"cell_type": "code",
758+
"execution_count": null,
759+
"metadata": {
760+
"id": "1IAjjmdTU9OU"
761+
},
762+
"outputs": [],
763+
"source": [
764+
"print(keras_model.predict(hashed_words.to_tensor()))"
714765
]
715766
},
716767
{
@@ -799,7 +850,7 @@
799850
"source": [
800851
"### Datasets\n",
801852
"\n",
802-
"[tf.data](https://www.tensorflow.org/guide/data) is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements, in which each element consists of one or more components. "
853+
"[tf.data](https://www.tensorflow.org/guide/data) is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements, in which each element consists of one or more components."
803854
]
804855
},
805856
{
@@ -1078,9 +1129,11 @@
10781129
"import tempfile\n",
10791130
"\n",
10801131
"keras_module_path = tempfile.mkdtemp()\n",
1081-
"tf.saved_model.save(keras_model, keras_module_path)\n",
1082-
"imported_model = tf.saved_model.load(keras_module_path)\n",
1083-
"imported_model(hashed_words)"
1132+
"keras_model.save(keras_module_path+\"/my_model.keras\")\n",
1133+
"\n",
1134+
"imported_model = tf.keras.models.load_model(keras_module_path+\"/my_model.keras\")\n",
1135+
"\n",
1136+
"imported_model(hashed_words.to_tensor())"
10841137
]
10851138
},
10861139
{
@@ -2125,7 +2178,6 @@
21252178
],
21262179
"metadata": {
21272180
"colab": {
2128-
"collapsed_sections": [],
21292181
"name": "ragged_tensor.ipynb",
21302182
"toc_visible": true
21312183
},

0 commit comments

Comments
 (0)