Skip to content

Commit 4d512c2

Browse files
Change layer.get_weights()[0] to layer.embeddings in the warmstart_embedding_matrix tutorial.
This prevents an error when the vocab is large and the embedding tensor gets sharded. PiperOrigin-RevId: 542046545
1 parent 82c1d10 commit 4d512c2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

site/en/tutorials/text/warmstart_embedding_matrix.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@
525525
"outputs": [],
526526
"source": [
527527
"embedding_weights_base = (\n",
528-
" model.get_layer(\"text_input\").get_layer(\"embedding\").get_weights()[0]\n",
528+
" model.get_layer(\"text_input\").get_layer(\"embedding\").embeddings\n",
529529
")\n",
530530
"vocab_base = vectorize_layer.get_vocabulary()"
531531
]
@@ -681,7 +681,7 @@
681681
"\n",
682682
"# Verify the shape of updated weights\n",
683683
"# The new weights shape should reflect the new vocabulary size\n",
684-
"text_input_new.get_layer(\"embedding\").get_weights()[0].shape"
684+
"text_input_new.get_layer(\"embedding\").embeddings.shape"
685685
]
686686
},
687687
{

0 commit comments

Comments
 (0)