Skip to content

Commit b37ef65

Browse files
Merge pull request #2128 from metric-space:word2vec-redundant-ops
PiperOrigin-RevId: 484066361
2 parents f384bf4 + 13b6887 commit b37ef65

File tree

1 file changed

+6
-13
lines changed

1 file changed

+6
-13
lines changed

site/en/tutorials/text/word2vec.ipynb

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -526,19 +526,15 @@
526526
},
527527
"outputs": [],
528528
"source": [
529-
"# Add a dimension so you can use concatenation (in the next step).\n",
530-
"negative_sampling_candidates = tf.expand_dims(negative_sampling_candidates, 1)\n",
529+
"# Reduce a dimension so you can use concatenation (in the next step).\n",
530+
"squeezed_context_class = tf.squeeze(context_class, 1)\n",
531531
"\n",
532532
"# Concatenate a positive context word with negative sampled words.\n",
533-
"context = tf.concat([context_class, negative_sampling_candidates], 0)\n",
533+
"context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)\n",
534534
"\n",
535535
"# Label the first context word as `1` (positive) followed by `num_ns` `0`s (negative).\n",
536536
"label = tf.constant([1] + [0]*num_ns, dtype=\"int64\")\n",
537-
"\n",
538-
"# Reshape the target to shape `(1,)` and context and label to `(num_ns+1,)`.\n",
539-
"target = tf.squeeze(target_word)\n",
540-
"context = tf.squeeze(context)\n",
541-
"label = tf.squeeze(label)"
537+
"target = target_word\n"
542538
]
543539
},
544540
{
@@ -751,10 +747,7 @@
751747
" name=\"negative_sampling\")\n",
752748
"\n",
753749
" # Build context and label vectors (for one target word)\n",
754-
" negative_sampling_candidates = tf.expand_dims(\n",
755-
" negative_sampling_candidates, 1)\n",
756-
"\n",
757-
" context = tf.concat([context_class, negative_sampling_candidates], 0)\n",
750+
" context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)\n",
758751
" label = tf.constant([1] + [0]*num_ns, dtype=\"int64\")\n",
759752
"\n",
760753
" # Append each element from the training example to global lists.\n",
@@ -1053,7 +1046,7 @@
10531046
" seed=SEED)\n",
10541047
"\n",
10551048
"targets = np.array(targets)\n",
1056-
"contexts = np.array(contexts)[:,:,0]\n",
1049+
"contexts = np.array(contexts)\n",
10571050
"labels = np.array(labels)\n",
10581051
"\n",
10591052
"print('\\n')\n",

0 commit comments

Comments
 (0)