|
526 | 526 | },
|
527 | 527 | "outputs": [],
|
528 | 528 | "source": [
|
529 |
| - "# Add a dimension so you can use concatenation (in the next step).\n", |
530 |
| - "negative_sampling_candidates = tf.expand_dims(negative_sampling_candidates, 1)\n", |
| 529 | + "# Reduce a dimension so you can use concatenation (in the next step).\n", |
| 530 | + "squeezed_context_class = tf.squeeze(context_class, 1)\n", |
531 | 531 | "\n",
|
532 | 532 | "# Concatenate a positive context word with negative sampled words.\n",
|
533 |
| - "context = tf.concat([context_class, negative_sampling_candidates], 0)\n", |
| 533 | + "context = tf.concat([squeezed_context_class, negative_sampling_candidates], 0)\n", |
534 | 534 | "\n",
|
535 | 535 | "# Label the first context word as `1` (positive) followed by `num_ns` `0`s (negative).\n",
|
536 | 536 | "label = tf.constant([1] + [0]*num_ns, dtype=\"int64\")\n",
|
537 |
| - "\n", |
538 |
| - "# Reshape the target to shape `(1,)` and context and label to `(num_ns+1,)`.\n", |
539 |
| - "target = tf.squeeze(target_word)\n", |
540 |
| - "context = tf.squeeze(context)\n", |
541 |
| - "label = tf.squeeze(label)" |
| 537 | + "target = target_word\n" |
542 | 538 | ]
|
543 | 539 | },
|
544 | 540 | {
|
|
751 | 747 | " name=\"negative_sampling\")\n",
|
752 | 748 | "\n",
|
753 | 749 | " # Build context and label vectors (for one target word)\n",
|
754 |
| - " negative_sampling_candidates = tf.expand_dims(\n", |
755 |
| - " negative_sampling_candidates, 1)\n", |
756 |
| - "\n", |
757 |
| - " context = tf.concat([context_class, negative_sampling_candidates], 0)\n", |
| 750 | + " context = tf.concat([tf.squeeze(context_class,1), negative_sampling_candidates], 0)\n", |
758 | 751 | " label = tf.constant([1] + [0]*num_ns, dtype=\"int64\")\n",
|
759 | 752 | "\n",
|
760 | 753 | " # Append each element from the training example to global lists.\n",
|
|
1053 | 1046 | " seed=SEED)\n",
|
1054 | 1047 | "\n",
|
1055 | 1048 | "targets = np.array(targets)\n",
|
1056 |
| - "contexts = np.array(contexts)[:,:,0]\n", |
| 1049 | + "contexts = np.array(contexts)\n", |
1057 | 1050 | "labels = np.array(labels)\n",
|
1058 | 1051 | "\n",
|
1059 | 1052 | "print('\\n')\n",
|
|
0 commit comments