Skip to content

Commit f2c7ee7

Browse files
Documentation Fix - Ragged Tensor
1 parent ec3a1b3 commit f2c7ee7

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

site/en/guide/ragged_tensor.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@
674674
"source": [
675675
"### Keras\n",
676676
"\n",
677-
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. Ragged tensors may be passed as inputs to a Keras model by setting `ragged=True` on `tf.keras.Input` or `tf.keras.layers.InputLayer`. Ragged tensors may also be passed between Keras layers, and returned by Keras models. The following example shows a toy LSTM model that is trained using ragged tensors."
677+
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. Ragged tensors can be passed as inputs to a Keras model by using ragged tensors between Keras layers, and returning ragged tensors by Keras models. The following example shows a toy LSTM model that is trained using ragged tensors:"
678678
]
679679
},
680680
{
@@ -700,17 +700,17 @@
700700
"\n",
701701
"# Build the Keras model.\n",
702702
"keras_model = tf.keras.Sequential([\n",
703-
" tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),\n",
704-
" tf.keras.layers.Embedding(hash_buckets, 16),\n",
705-
" tf.keras.layers.LSTM(32, use_bias=False),\n",
703+
" tf.keras.layers.Embedding(hash_buckets, 16, input_length=hashed_words.shape[1]),\n",
704+
" tf.keras.layers.LSTM(32, return_sequences=True, use_bias=False),\n",
705+
" tf.keras.layers.Flatten(),\n",
706706
" tf.keras.layers.Dense(32),\n",
707707
" tf.keras.layers.Activation(tf.nn.relu),\n",
708708
" tf.keras.layers.Dense(1)\n",
709709
"])\n",
710710
"\n",
711711
"keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')\n",
712712
"keras_model.fit(hashed_words, is_question, epochs=5)\n",
713-
"print(keras_model.predict(hashed_words))"
713+
"print(keras_model.predict(hashed_words))\n"
714714
]
715715
},
716716
{

0 commit comments

Comments
 (0)