 90 |  90 | },
 91 |  91 | "outputs": [],
 92 |  92 | "source": [
 93 |     | - "import tensorflow as tf"
    |  93 | + "import tensorflow as tf\n",
    |  94 | + "\n",
    |  95 | + "import matplotlib.pyplot as plt\n",
    |  96 | + "\n",
    |  97 | + "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']"
 94 |  98 | ]
 95 |  99 | },
 96 | 100 | {

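For context on the new `colors` line: `plt.rcParams['axes.prop_cycle'].by_key()['color']` exposes matplotlib's default color cycle as a plain list, and the plotting cells added later in this diff index into it (`colors[0]`, `colors[1]`) so each tracked quantity and its dashed true-value line share a color. A minimal sketch of what the expression yields (the exact values depend on the active matplotlib style):

```python
import matplotlib.pyplot as plt

# The default property cycle, exposed as a dict of lists keyed by property name.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# With the default style this is a list of hex strings such as '#1f77b4'
# and '#ff7f0e'; colors[0] and colors[1] are reused in the plots below so a
# series and its dashed "true value" line get matching colors.
print(colors[:2])
```
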
142 | 146 | "TRUE_W = 3.0\n",
143 | 147 | "TRUE_B = 2.0\n",
144 | 148 | "\n",
145 |     | - "NUM_EXAMPLES = 1000\n",
    | 149 | + "NUM_EXAMPLES = 201\n",
146 | 150 | "\n",
147 | 151 | "# A vector of random x values\n",
148 |     | - "x = tf.random.normal(shape=[NUM_EXAMPLES])\n",
    | 152 | + "x = tf.linspace(-2,2, NUM_EXAMPLES)\n",
    | 153 | + "x = tf.cast(x, tf.float32)\n",
    | 154 | + "\n",
    | 155 | + "def f(x):\n",
    | 156 | + "  return x * TRUE_W + TRUE_B\n",
149 | 157 | "\n",
150 | 158 | "# Generate some noise\n",
151 | 159 | "noise = tf.random.normal(shape=[NUM_EXAMPLES])\n",
152 | 160 | "\n",
153 | 161 | "# Calculate y\n",
154 |     | - "y = x * TRUE_W + TRUE_B + noise"
    | 162 | + "y = f(x) + noise"
155 | 163 | ]
156 | 164 | },
157 | 165 | {

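Assembled out of the JSON source strings, the updated data-generation cell amounts to the sketch below (assuming TensorFlow 2.x and matplotlib; float endpoints are passed to `tf.linspace` here so the result is float32 directly, where the notebook uses integer endpoints and casts afterwards):

```python
import tensorflow as tf
import matplotlib.pyplot as plt

# The slope and offset the model is supposed to recover.
TRUE_W = 3.0
TRUE_B = 2.0

NUM_EXAMPLES = 201

# Evenly spaced inputs on [-2, 2] replace the earlier random-normal draw.
x = tf.linspace(-2.0, 2.0, NUM_EXAMPLES)

def f(x):
  return x * TRUE_W + TRUE_B

# Gaussian noise on top of the true line.
noise = tf.random.normal(shape=[NUM_EXAMPLES])
y = f(x) + noise

# Quick look at the noisy data, in the style the later plotting cells use.
plt.plot(x, y, '.')
plt.show()
```

Using a fixed, evenly spaced grid means the ground-truth and prediction curves can be drawn with `plt.plot` as clean lines, which is what the later plotting changes in this diff rely on.
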
163 | 171 | "outputs": [],
164 | 172 | "source": [
165 | 173 | "# Plot all the data\n",
166 |     | - "import matplotlib.pyplot as plt\n",
167 |     | - "\n",
168 |     | - "plt.scatter(x, y, c=\"b\")\n",
    | 174 | + "plt.plot(x, y, '.')\n",
169 | 175 | "plt.show()"
170 | 176 | ]
171 | 177 | },

271 | 277 | },
272 | 278 | "outputs": [],
273 | 279 | "source": [
274 |     | - "plt.scatter(x, y, c=\"b\")\n",
275 |     | - "plt.scatter(x, model(x), c=\"r\")\n",
    | 280 | + "plt.plot(x, y, '.', label=\"Data\")\n",
    | 281 | + "plt.plot(x, f(x), label=\"Ground truth\")\n",
    | 282 | + "plt.plot(x, model(x), label=\"Predictions\")\n",
    | 283 | + "plt.legend()\n",
276 | 284 | "plt.show()\n",
277 | 285 | "\n",
278 | 286 | "print(\"Current loss: %1.6f\" % loss(y, model(x)).numpy())"

341 | 349 | "model = MyModel()\n",
342 | 350 | "\n",
343 | 351 | "# Collect the history of W-values and b-values to plot later\n",
344 |     | - "Ws, bs = [], []\n",
    | 352 | + "weights = []\n",
    | 353 | + "biases = []\n",
345 | 354 | "epochs = range(10)\n",
346 | 355 | "\n",
347 | 356 | "# Define a training loop\n",
    | 357 | + "def report(model, loss):\n",
    | 358 | + "  return f\"W = {model.w.numpy():1.2f}, b = {model.b.numpy():1.2f}, loss={loss:2.5f}\"\n",
    | 359 | + "\n",
    | 360 | + "\n",
348 | 361 | "def training_loop(model, x, y):\n",
349 | 362 | "\n",
350 | 363 | "  for epoch in epochs:\n",
351 | 364 | "    # Update the model with the single giant batch\n",
352 | 365 | "    train(model, x, y, learning_rate=0.1)\n",
353 | 366 | "\n",
354 | 367 | "    # Track this before I update\n",
355 |     | - "    Ws.append(model.w.numpy())\n",
356 |     | - "    bs.append(model.b.numpy())\n",
    | 368 | + "    weights.append(model.w.numpy())\n",
    | 369 | + "    biases.append(model.b.numpy())\n",
357 | 370 | "    current_loss = loss(y, model(x))\n",
358 | 371 | "\n",
359 |     | - "    print(\"Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f\" %\n",
360 |     | - "          (epoch, Ws[-1], bs[-1], current_loss))\n"
    | 372 | + "    print(f\"Epoch {epoch:2d}:\")\n",
    | 373 | + "    print(\"    \", report(model, current_loss))"
    | 374 | + ]
    | 375 | + },
    | 376 | + {
    | 377 | + "cell_type": "markdown",
    | 378 | + "metadata": {
    | 379 | + "id": "8dKKLU4KkQEq"
    | 380 | + },
    | 381 | + "source": [
    | 382 | + "Do the training"
361 | 383 | ]
362 | 384 | },
363 | 385 | {

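The new training-loop cell depends on `MyModel`, `loss`, and `train` from earlier cells that this diff does not touch. The sketch below pairs the added code with minimal stand-ins for those three pieces so it runs on its own; the stand-ins are assumptions based on the rest of the notebook, not part of this commit, and the loss is passed through `float()` before formatting so the f-string works regardless of tensor type:

```python
import tensorflow as tf

# Data as set up earlier in the notebook.
TRUE_W, TRUE_B = 3.0, 2.0
x = tf.linspace(-2.0, 2.0, 201)
y = x * TRUE_W + TRUE_B + tf.random.normal(shape=[201])

# --- Stand-ins for pieces defined in earlier, unchanged cells (assumed) ---
class MyModel(tf.Module):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.w = tf.Variable(5.0)
    self.b = tf.Variable(0.0)

  def __call__(self, x):
    return self.w * x + self.b

def loss(target_y, predicted_y):
  # Mean squared error.
  return tf.reduce_mean(tf.square(target_y - predicted_y))

def train(model, x, y, learning_rate):
  # One gradient-descent step on the full batch.
  with tf.GradientTape() as tape:
    current_loss = loss(y, model(x))
  dw, db = tape.gradient(current_loss, [model.w, model.b])
  model.w.assign_sub(learning_rate * dw)
  model.b.assign_sub(learning_rate * db)

# --- The training-loop cell as it reads after this hunk ---
model = MyModel()

# Collect the history of W-values and b-values to plot later.
weights = []
biases = []
epochs = range(10)

def report(model, loss):
  return f"W = {model.w.numpy():1.2f}, b = {model.b.numpy():1.2f}, loss={loss:2.5f}"

def training_loop(model, x, y):
  for epoch in epochs:
    # Update the model with the single giant batch.
    train(model, x, y, learning_rate=0.1)

    # Track this before I update.
    weights.append(model.w.numpy())
    biases.append(model.b.numpy())
    current_loss = loss(y, model(x))

    print(f"Epoch {epoch:2d}:")
    print("    ", report(model, float(current_loss)))

training_loop(model, x, y)
```
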
368 | 390 | },
369 | 391 | "outputs": [],
370 | 392 | "source": [
371 |     | - "print(\"Starting: W=%1.2f b=%1.2f, loss=%2.5f\" %\n",
372 |     | - "      (model.w, model.b, loss(y, model(x))))\n",
    | 393 | + "current_loss = loss(y, model(x))\n",
373 | 394 | "\n",
374 |     | - "# Do the training\n",
375 |     | - "training_loop(model, x, y)\n",
    | 395 | + "print(f\"Starting:\")\n",
    | 396 | + "print(\"    \", report(model, current_loss))\n",
376 | 397 | "\n",
377 |     | - "# Plot it\n",
378 |     | - "plt.plot(epochs, Ws, \"r\",\n",
379 |     | - "         epochs, bs, \"b\")\n",
    | 398 | + "training_loop(model, x, y)"
    | 399 | + ]
    | 400 | + },
    | 401 | + {
    | 402 | + "cell_type": "markdown",
    | 403 | + "metadata": {
    | 404 | + "id": "JPJgimg8kSA4"
    | 405 | + },
    | 406 | + "source": [
    | 407 | + "Plot the evolution of the weights over time:"
    | 408 | + ]
    | 409 | + },
    | 410 | + {
    | 411 | + "cell_type": "code",
    | 412 | + "execution_count": null,
    | 413 | + "metadata": {
    | 414 | + "id": "ND1fQw8sbTNr"
    | 415 | + },
    | 416 | + "outputs": [],
    | 417 | + "source": [
    | 418 | + "plt.plot(epochs, weights, label='Weights', color=colors[0])\n",
    | 419 | + "plt.plot(epochs, [TRUE_W] * len(epochs), '--',\n",
    | 420 | + "         label = \"True weight\", color=colors[0])\n",
380 | 421 | "\n",
381 |     | - "plt.plot([TRUE_W] * len(epochs), \"r--\",\n",
382 |     | - "         [TRUE_B] * len(epochs), \"b--\")\n",
    | 422 | + "plt.plot(epochs, biases, label='bias', color=colors[1])\n",
    | 423 | + "plt.plot(epochs, [TRUE_B] * len(epochs), \"--\",\n",
    | 424 | + "         label=\"True bias\", color=colors[1])\n",
383 | 425 | "\n",
384 |     | - "plt.legend([\"W\", \"b\", \"True W\", \"True b\"])\n",
385 |     | - "plt.show()\n"
    | 426 | + "plt.legend()\n",
    | 427 | + "plt.show()"
    | 428 | + ]
    | 429 | + },
    | 430 | + {
    | 431 | + "cell_type": "markdown",
    | 432 | + "metadata": {
    | 433 | + "id": "zhlwj1ojkcUP"
    | 434 | + },
    | 435 | + "source": [
    | 436 | + "Visualize how the trained model performs"
386 | 437 | ]
387 | 438 | },
388 | 439 | {

393 | 444 | },
394 | 445 | "outputs": [],
395 | 446 | "source": [
396 |     | - "# Visualize how the trained model performs\n",
397 |     | - "plt.scatter(x, y, c=\"b\")\n",
398 |     | - "plt.scatter(x, model(x), c=\"r\")\n",
    | 447 | + "plt.plot(x, y, '.', label=\"Data\")\n",
    | 448 | + "plt.plot(x, f(x), label=\"Ground truth\")\n",
    | 449 | + "plt.plot(x, model(x), label=\"Predictions\")\n",
    | 450 | + "plt.legend()\n",
399 | 451 | "plt.show()\n",
400 | 452 | "\n",
401 | 453 | "print(\"Current loss: %1.6f\" % loss(model(x), y).numpy())"

531 | 583 | "colab": {
532 | 584 | "collapsed_sections": [
533 | 585 | "5rmpybwysXGV",
534 |     | - "iKD__8kFCKNt",
535 |     | - "vPnIVuaSJwWz"
    | 586 | + "iKD__8kFCKNt"
536 | 587 | ],
537 | 588 | "name": "basic_training_loops.ipynb",
538 | 589 | "toc_visible": true