|
111 | 111 | "import os\n",
|
112 | 112 | "from matplotlib import pyplot as plt\n",
|
113 | 113 | "from jax.experimental import jax2tf\n",
|
114 |
| - "from threading import Lock # only used in the visualization utility\n", |
| 114 | + "from threading import Lock # Only used in the visualization utility.\n", |
115 | 115 | "from functools import partial"
|
116 | 116 | ]
|
117 | 117 | },
|
|
123 | 123 | },
|
124 | 124 | "outputs": [],
|
125 | 125 | "source": [
|
126 |
| - "# Needed for TF and JAX to coexist in GPU memory\n", |
| 126 | + "# Needed for TensorFlow and JAX to coexist in GPU memory.\n", |
127 | 127 | "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = \"false\"\n",
|
128 | 128 | "gpus = tf.config.list_physical_devices('GPU')\n",
|
129 | 129 | "if gpus:\n",
|
130 | 130 | " try:\n",
|
131 | 131 | " for gpu in gpus:\n",
|
132 | 132 | " tf.config.experimental.set_memory_growth(gpu, True)\n",
|
133 | 133 | " except RuntimeError as e:\n",
|
134 |
| - " # Memory growth must be set before GPUs have been initialized\n", |
| 134 | + " # Memory growth must be set before GPUs have been initialized.\n", |
135 | 135 | " print(e)"
|
136 | 136 | ]
|
137 | 137 | },
|
|
148 | 148 | "\n",
|
149 | 149 | "plt.rcParams[\"figure.figsize\"] = (20,8)\n",
|
150 | 150 | "\n",
|
151 |
| - "# utility to display training and validation curves\n", |
| 151 | + "# The utility for displaying training and validation curves.\n", |
152 | 152 | "def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):\n",
|
153 | 153 | "\n",
|
154 | 154 | " ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)\n",
|
155 | 155 | "\n",
|
156 |
| - " # Losses\n", |
| 156 | + " # The losses.\n", |
157 | 157 | " ax = plt.subplot(121)\n",
|
158 | 158 | " if loss is not None:\n",
|
159 | 159 | " x = np.arange(len(loss)) / steps_per_epochs #* epochs\n",
|
|
172 | 172 | " ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)\n",
|
173 | 173 | " ax.legend(['avg train', 'eval'])\n",
|
174 | 174 | "\n",
|
175 |
| - " #Accuracy\n", |
| 175 | + " # The accuracy.\n", |
176 | 176 | " ax = plt.subplot(122)\n",
|
177 | 177 | " ax.set_title('Eval Accuracy')\n",
|
178 | 178 | " ax.set_ylabel('accuracy')\n",
|
|
197 | 197 | " :param msg: the message displayed in the header of the progress bar\n",
|
198 | 198 | " \"\"\"\n",
|
199 | 199 | " self.maxi = maxi\n",
|
200 |
| - " self.p = self.__start_progress(maxi)() # () to get the iterator from the generator\n", |
| 200 | + " self.p = self.__start_progress(maxi)() # `()`: to get the iterator from the generator.\n", |
201 | 201 | " self.header_printed = False\n",
|
202 | 202 | " self.msg = msg\n",
|
203 | 203 | " self.size = size\n",
|
|
232 | 232 | " d -= dx\n",
|
233 | 233 | " d += dy\n",
|
234 | 234 | " yield k\n",
|
235 |
| - " # keep yielding the last result if too many steps\n", |
| 235 | + " # Keep yielding the last result if there are too many steps.\n", |
236 | 236 | " while True:\n",
|
237 | 237 | " yield k\n",
|
238 | 238 | "\n",
|
|
295 | 295 | },
|
296 | 296 | "outputs": [],
|
297 | 297 | "source": [
|
298 |
| - "# Training hyperparams\n", |
| 298 | + "# Training hyperparameters.\n", |
299 | 299 | "JAX_EPOCHS = 3\n",
|
300 | 300 | "TF_EPOCHS = 7\n",
|
301 | 301 | "STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE\n",
|
302 | 302 | "LEARNING_RATE = 0.01\n",
|
303 | 303 | "LEARNING_RATE_EXP_DECAY = 0.6\n",
|
304 | 304 | "\n",
|
305 |
| - "# Learning Rate schedule for JAX\n", |
| 305 | + "# The learning rate schedule for JAX (with Optax).\n", |
306 | 306 | "jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)\n",
|
307 | 307 | "\n",
|
308 |
| - "# Learning Rate schedule for TF\n", |
| 308 | + "# THe learning rate schedule for TensorFlow.\n", |
309 | 309 | "tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)"
|
310 | 310 | ]
|
311 | 311 | },
|
|
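Since the two schedules above are meant to be equivalent, a quick sanity check can confirm it. The following sketch is not part of the notebook; it only reuses the variables defined in the cell above (`STEPS_PER_EPOCH`, `jlr_decay`, `tflr_decay`):

```python
# Hypothetical check: both schedules are staircase exponential decays with the
# same parameters, so they should agree at every step count.
for step in (0, STEPS_PER_EPOCH - 1, STEPS_PER_EPOCH, 3 * STEPS_PER_EPOCH):
    print(step, float(jlr_decay(step)), float(tflr_decay(step)))
```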
341 | 341 | " #x = flax.linen.log_softmax(x)\n",
|
342 | 342 | " return x\n",
|
343 | 343 | "\n",
|
344 |
| - " # JAX differentiation requires a function f(params, other_state, data, labels) -> loss (as a single number)\n", |
345 |
| - " # jax.grad will differentiate it against the fist argument.\n", |
346 |
| - " # The user must split trainable and non-trainable variables into \"params\" and \"other_state\"\n", |
347 |
| - " # Must pass a different RNG Key each time for dropout mask to be different\n", |
| 344 | + " # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).\n", |
| 345 | + " # `jax.grad` will differentiate it against the fist argument.\n", |
| 346 | + " # The user must split trainable and non-trainable variables into `params` and `other_state`.\n", |
| 347 | + " # Must pass a different RNG key each time for the dropout mask to be different.\n", |
348 | 348 | " def loss(self, params, other_state, rng, data, labels, train):\n",
|
349 | 349 | " logits, batch_stats = self.apply({'params': params, **other_state},\n",
|
350 | 350 | " data,\n",
|
351 | 351 | " mutable=['batch_stats'],\n",
|
352 | 352 | " rngs={'dropout': rng},\n",
|
353 | 353 | " train=train)\n",
|
354 |
| - " # loss averaged across batch dimension\n", |
| 354 | + " # The loss averaged across the batch dimension.\n", |
355 | 355 | " loss = optax.softmax_cross_entropy(logits, labels).mean()\n",
|
356 | 356 | " return loss, batch_stats\n",
|
357 | 357 | "\n",
|
|
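As a sketch of how the comments above translate into code (this is not in the diff; it assumes the notebook's `model`, `state`, `rng` and a NumPy batch `data`, `labels`), the loss would typically be differentiated with `jax.value_and_grad`, using `has_aux=True` because the loss also returns the mutated batch statistics:

```python
# Minimal sketch: gradients flow only into the first positional argument (params).
other_state, params = state.pop('params')   # Same trainable / non-trainable split the notebook uses.
grad_fn = jax.value_and_grad(model.loss, has_aux=True)
(loss_value, batch_stats), grads = grad_fn(params, other_state, rng, data, labels, train=True)
```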
374 | 374 | "id": "7Cr0FRNFtHN4"
|
375 | 375 | },
|
376 | 376 | "source": [
|
377 |
| - "## Write the train_step" |
| 377 | + "## Write the training step function" |
378 | 378 | ]
|
379 | 379 | },
|
380 | 380 | {
|
|
385 | 385 | },
|
386 | 386 | "outputs": [],
|
387 | 387 | "source": [
|
388 |
| - "# Training step\n", |
| 388 | + "# The training step.\n", |
389 | 389 | "@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model\n",
|
390 | 390 | "def train_step(model, state, optimizer_state, rng, data, labels):\n",
|
391 | 391 | "\n",
|
|
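The body of `train_step` lies outside the changed lines shown here. As a rough, hypothetical sketch of the usual Optax pattern such a step follows (the notebook's actual implementation may differ), it extends the gradient computation sketched after the loss function with an optimizer update; `optimizer` is the Optax optimizer created in the next section, and the `jax.jit` decoration is omitted for readability:

```python
def train_step_sketch(model, state, optimizer_state, rng, data, labels):
    rng, dropout_rng = jax.random.split(rng)        # Fresh key so each dropout mask differs.
    other_state, params = state.pop('params')
    grad_fn = jax.value_and_grad(model.loss, has_aux=True)
    (loss, batch_stats), grads = grad_fn(params, other_state, dropout_rng, data, labels, train=True)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
    new_params = optax.apply_updates(params, updates)
    # Re-assemble the Flax state; assumes 'batch_stats' is the only non-trainable collection.
    new_state = {'params': new_params, **batch_stats}
    return new_state, optimizer_state, rng, loss
```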
423 | 423 | " rng = jax.random.PRNGKey(0)\n",
|
424 | 424 | " for epoch in range(epochs):\n",
|
425 | 425 | "\n",
|
426 |
| - " # this is where the learning rate schedule state is stored in the optimizer state\n", |
| 426 | + " # This is where the learning rate schedule state is stored in the optimizer state.\n", |
427 | 427 | " optimizer_step = optimizer_state[1].count\n",
|
428 | 428 | "\n",
|
429 |
| - " # run an epoch of training\n", |
| 429 | + " # Run an epoch of training.\n", |
430 | 430 | " for step, (data, labels) in enumerate(train_data):\n",
|
431 | 431 | " p.step(reset=(step==0))\n",
|
432 | 432 | " state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())\n",
|
433 | 433 | " losses.append(loss)\n",
|
434 | 434 | " avg_loss = np.mean(losses[-step:])\n",
|
435 | 435 | " avg_losses.append(avg_loss)\n",
|
436 | 436 | "\n",
|
437 |
| - " # run one epoch of evals (10,000 test images in a single batch)\n", |
| 437 | + " # Run one epoch of evals (10,000 test images in a single batch).\n", |
438 | 438 | " other_state, params = state.pop('params')\n",
|
439 | 439 | " # Gotcha: must discard modified batch_stats here\n",
|
440 | 440 | " eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)\n",
|
|
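To make the comment about the optimizer state concrete: the step counter the notebook reads via `optimizer_state[1].count` can also be fed back into the schedule to recover the current learning rate. This sketch only reuses constructs already present in the notebook (`optimizer_state`, `jlr_decay`):

```python
# Sketch: recover the current step and learning rate from the Optax optimizer state.
current_step = int(optimizer_state[1].count)
print("step:", current_step, "learning rate:", float(jlr_decay(current_step)))
```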
453 | 453 | "id": "DGB3W5g0Wt1H"
|
454 | 454 | },
|
455 | 455 | "source": [
|
456 |
| - "## Create the model, and optimizer (with Optax)" |
| 456 | + "## Create the model and the optimizer (with Optax)" |
457 | 457 | ]
|
458 | 458 | },
|
459 | 459 | {
|
|
464 | 464 | },
|
465 | 465 | "outputs": [],
|
466 | 466 | "source": [
|
467 |
| - "# Model\n", |
| 467 | + "# The model.\n", |
468 | 468 | "model = ConvModel()\n",
|
469 | 469 | "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n",
|
470 | 470 | "\n",
|
471 |
| - "# Optimizer\n", |
| 471 | + "# The optimizer.\n", |
472 | 472 | "optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule\n",
|
473 | 473 | "optimizer_state = optimizer.init(state['params'])\n",
|
474 | 474 | "\n",
|
|
531 | 531 | "model = ConvModel()\n",
|
532 | 532 | "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n",
|
533 | 533 | "\n",
|
534 |
| - "# Optimizer\n", |
| 534 | + "# The optimizer.\n", |
535 | 535 | "optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule\n",
|
536 | 536 | "optimizer_state = optimizer.init(state['params'])\n",
|
537 | 537 | "\n",
|
|
567 | 567 | },
|
568 | 568 | "source": [
|
569 | 569 | "## Save just enough for inference\n",
|
570 |
| - "If your goal is deploy your JAX model (so you can run inference using `model.predict()`), simply exporting it to [SavedModel](https://www.tensorflow.org/guide/saved_model) is sufficient. This section demonstrates how to accomplish that." |
| 570 | + "\n", |
| 571 | + "If your goal is to deploy your JAX model (so you can run inference using `model.predict()`), simply exporting it to [SavedModel](https://www.tensorflow.org/guide/saved_model) is sufficient. This section demonstrates how to accomplish that." |
571 | 572 | ]
|
572 | 573 | },
|
573 | 574 | {
|
|
578 | 579 | },
|
579 | 580 | "outputs": [],
|
580 | 581 | "source": [
|
581 |
| - "# test data with different batch size to test polymorphic shapes\n", |
582 |
| - "x,y = next(iter(train_data.unbatch().batch(13)))\n", |
| 582 | + "# Test data with a different batch size to test polymorphic shapes.\n", |
| 583 | + "x, y = next(iter(train_data.unbatch().batch(13)))\n", |
583 | 584 | "\n",
|
584 | 585 | "m = tf.Module()\n",
|
585 |
| - "# wrap JAX state in tf.Variable (needed when calling converted JAX function\n", |
| 586 | + "# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function.\n", |
586 | 587 | "state_vars = tf.nest.map_structure(tf.Variable, state)\n",
|
587 |
| - "# keep the wrapped state as flat list (needed in TF fine-tuning)\n", |
| 588 | + "# Keep the wrapped state as flat list (needed in TensorFlow fine-tuning).\n", |
588 | 589 | "m.vars = tf.nest.flatten(state_vars)\n",
|
589 |
| - "# convert the desired JAX function (model.predict)\n", |
| 590 | + "# Convert the desired JAX function (`model.predict`).\n", |
590 | 591 | "predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\"])\n",
|
591 |
| - "# wrap converted function in tf.function with correct TensorSpec (necessary for dynamic shapes to work)\n", |
| 592 | + "# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).\n", |
592 | 593 | "@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])\n",
|
593 | 594 | "def predict(data):\n",
|
594 | 595 | " return predict_fn(state_vars, data)\n",
|
|
604 | 605 | },
|
605 | 606 | "outputs": [],
|
606 | 607 | "source": [
|
607 |
| - "# test the converted function\n", |
608 |
| - "print(\"converted function predictions:\", np.argmax(m.predict(x).numpy(), axis=-1))\n", |
609 |
| - "#reload the model\n", |
| 608 | + "# Test the converted function.\n", |
| 609 | + "print(\"Converted function predictions:\", np.argmax(m.predict(x).numpy(), axis=-1))\n", |
| 610 | + "# Reload the model.\n", |
610 | 611 | "reloaded_model = tf.saved_model.load(\"./\")\n",
|
611 |
| - "# test the reloaded converted function (should be same result)\n", |
612 |
| - "print(\"reloaded function predictions:\", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))" |
| 612 | + "# Test the reloaded converted function (the result should be the same).\n", |
| 613 | + "print(\"Reloaded function predictions:\", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))" |
613 | 614 | ]
|
614 | 615 | },
|
615 | 616 | {
|
|
725 | 726 | },
|
726 | 727 | "outputs": [],
|
727 | 728 | "source": [
|
728 |
| - "# instantiate the model\n", |
| 729 | + "# Instantiate the model.\n", |
729 | 730 | "tf_model = TFModel(state, model)\n",
|
730 | 731 | "\n",
|
731 |
| - "# save\n", |
| 732 | + "# Save the model.\n", |
732 | 733 | "tf.saved_model.save(tf_model, \"./\")"
|
733 | 734 | ]
|
734 | 735 | },
|
|
751 | 752 | "source": [
|
752 | 753 | "reloaded_model = tf.saved_model.load(\"./\")\n",
|
753 | 754 | "\n",
|
754 |
| - "# test if it works and that the batch size is indeed variable\n", |
| 755 | + "# Test if it works and that the batch size is indeed variable.\n", |
755 | 756 | "x,y = next(iter(train_data.unbatch().batch(13)))\n",
|
756 | 757 | "print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))\n",
|
757 | 758 | "x,y = next(iter(train_data.unbatch().batch(20)))\n",
|
|
780 | 781 | "source": [
|
781 | 782 | "optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)\n",
|
782 | 783 | "\n",
|
783 |
| - "# set the iteration step for the LR to resume from where it left off in JAX\n", |
| 784 | + "# Set the iteration step for the learning rate to resume from where it left off in JAX.\n", |
784 | 785 | "optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)\n",
|
785 | 786 | "\n",
|
786 | 787 | "p = Progress(STEPS_PER_EPOCH)\n",
|
787 | 788 | "\n",
|
788 | 789 | "for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):\n",
|
789 | 790 | "\n",
|
790 |
| - " # this is where the learning rate schedule state is stored in the optimizer state\n", |
| 791 | + " # This is where the learning rate schedule state is stored in the optimizer state.\n", |
791 | 792 | " optimizer_step = optimizer.iterations\n",
|
792 | 793 | "\n",
|
793 | 794 | " for step, (data, labels) in enumerate(train_data):\n",
|
|
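As another hypothetical check (using only variables from the cells above), the Keras schedule evaluated at the assigned iteration count should return the same learning rate the Optax schedule had reached when JAX training stopped:

```python
# Sketch: verify the TensorFlow schedule resumes where the JAX schedule left off.
resumed_step = len(eval_losses) * STEPS_PER_EPOCH   # One eval_losses entry per completed JAX epoch.
print("TensorFlow resumes at LR:", float(tflr_decay(resumed_step)))
print("JAX stopped at LR:", float(jlr_decay(resumed_step)))
```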