Commit fde5650

Merge pull request #2244 from 8bitmp3:jax2tf_update
PiperOrigin-RevId: 547937608
2 parents 810e2e5 + fa487fd commit fde5650

1 file changed (+43, -42)

site/en/guide/jax2tf.ipynb

Lines changed: 43 additions & 42 deletions
@@ -111,7 +111,7 @@
 "import os\n",
 "from matplotlib import pyplot as plt\n",
 "from jax.experimental import jax2tf\n",
-"from threading import Lock # only used in the visualization utility\n",
+"from threading import Lock # Only used in the visualization utility.\n",
 "from functools import partial"
 ]
 },
@@ -123,15 +123,15 @@
 },
 "outputs": [],
 "source": [
-"# Needed for TF and JAX to coexist in GPU memory\n",
+"# Needed for TensorFlow and JAX to coexist in GPU memory.\n",
 "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = \"false\"\n",
 "gpus = tf.config.list_physical_devices('GPU')\n",
 "if gpus:\n",
 " try:\n",
 " for gpu in gpus:\n",
 " tf.config.experimental.set_memory_growth(gpu, True)\n",
 " except RuntimeError as e:\n",
-" # Memory growth must be set before GPUs have been initialized\n",
+" # Memory growth must be set before GPUs have been initialized.\n",
 " print(e)"
 ]
 },
@@ -148,12 +148,12 @@
 "\n",
 "plt.rcParams[\"figure.figsize\"] = (20,8)\n",
 "\n",
-"# utility to display training and validation curves\n",
+"# The utility for displaying training and validation curves.\n",
 "def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):\n",
 "\n",
 " ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)\n",
 "\n",
-" # Losses\n",
+" # The losses.\n",
 " ax = plt.subplot(121)\n",
 " if loss is not None:\n",
 " x = np.arange(len(loss)) / steps_per_epochs #* epochs\n",
@@ -172,7 +172,7 @@
 " ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)\n",
 " ax.legend(['avg train', 'eval'])\n",
 "\n",
-" #Accuracy\n",
+" # The accuracy.\n",
 " ax = plt.subplot(122)\n",
 " ax.set_title('Eval Accuracy')\n",
 " ax.set_ylabel('accuracy')\n",
@@ -197,7 +197,7 @@
 " :param msg: the message displayed in the header of the progress bar\n",
 " \"\"\"\n",
 " self.maxi = maxi\n",
-" self.p = self.__start_progress(maxi)() # () to get the iterator from the generator\n",
+" self.p = self.__start_progress(maxi)() # `()`: to get the iterator from the generator.\n",
 " self.header_printed = False\n",
 " self.msg = msg\n",
 " self.size = size\n",
@@ -232,7 +232,7 @@
 " d -= dx\n",
 " d += dy\n",
 " yield k\n",
-" # keep yielding the last result if too many steps\n",
+" # Keep yielding the last result if there are too many steps.\n",
 " while True:\n",
 " yield k\n",
 "\n",
@@ -295,17 +295,17 @@
 },
 "outputs": [],
 "source": [
-"# Training hyperparams\n",
+"# Training hyperparameters.\n",
 "JAX_EPOCHS = 3\n",
 "TF_EPOCHS = 7\n",
 "STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE\n",
 "LEARNING_RATE = 0.01\n",
 "LEARNING_RATE_EXP_DECAY = 0.6\n",
 "\n",
-"# Learning Rate schedule for JAX\n",
+"# The learning rate schedule for JAX (with Optax).\n",
 "jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)\n",
 "\n",
-"# Learning Rate schedule for TF\n",
+"# The learning rate schedule for TensorFlow.\n",
 "tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)"
 ]
 },
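For reference, the two schedules in this cell are meant to produce the same staircase decay. A minimal standalone sketch (not from the notebook; STEPS_PER_EPOCH is a placeholder value here because BATCH_SIZE and train_labels are defined elsewhere) that evaluates both at a few step counts:

import optax
import tensorflow as tf

LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6
STEPS_PER_EPOCH = 100  # Placeholder value for illustration only.

jlr_decay = optax.exponential_decay(
    LEARNING_RATE, transition_steps=STEPS_PER_EPOCH,
    decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)
tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH,
    decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

for step in (0, 100, 250):
  # Both should evaluate to LEARNING_RATE * LEARNING_RATE_EXP_DECAY ** (step // STEPS_PER_EPOCH).
  print(step, float(jlr_decay(step)), float(tflr_decay(step)))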
@@ -341,17 +341,17 @@
 " #x = flax.linen.log_softmax(x)\n",
 " return x\n",
 "\n",
-" # JAX differentiation requires a function f(params, other_state, data, labels) -> loss (as a single number)\n",
-" # jax.grad will differentiate it against the fist argument.\n",
-" # The user must split trainable and non-trainable variables into \"params\" and \"other_state\"\n",
-" # Must pass a different RNG Key each time for dropout mask to be different\n",
+" # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).\n",
+" # `jax.grad` will differentiate it against the first argument.\n",
+" # The user must split trainable and non-trainable variables into `params` and `other_state`.\n",
+" # Must pass a different RNG key each time for the dropout mask to be different.\n",
 " def loss(self, params, other_state, rng, data, labels, train):\n",
 " logits, batch_stats = self.apply({'params': params, **other_state},\n",
 " data,\n",
 " mutable=['batch_stats'],\n",
 " rngs={'dropout': rng},\n",
 " train=train)\n",
-" # loss averaged across batch dimension\n",
+" # The loss averaged across the batch dimension.\n",
 " loss = optax.softmax_cross_entropy(logits, labels).mean()\n",
 " return loss, batch_stats\n",
 "\n",
@@ -374,7 +374,7 @@
 "id": "7Cr0FRNFtHN4"
 },
 "source": [
-"## Write the train_step"
+"## Write the training step function"
 ]
 },
 {
@@ -385,7 +385,7 @@
 },
 "outputs": [],
 "source": [
-"# Training step\n",
+"# The training step.\n",
 "@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model\n",
 "def train_step(model, state, optimizer_state, rng, data, labels):\n",
 "\n",
@@ -423,18 +423,18 @@
 " rng = jax.random.PRNGKey(0)\n",
 " for epoch in range(epochs):\n",
 "\n",
-" # this is where the learning rate schedule state is stored in the optimizer state\n",
+" # This is where the learning rate schedule state is stored in the optimizer state.\n",
 " optimizer_step = optimizer_state[1].count\n",
 "\n",
-" # run an epoch of training\n",
+" # Run an epoch of training.\n",
 " for step, (data, labels) in enumerate(train_data):\n",
 " p.step(reset=(step==0))\n",
 " state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())\n",
 " losses.append(loss)\n",
 " avg_loss = np.mean(losses[-step:])\n",
 " avg_losses.append(avg_loss)\n",
 "\n",
-" # run one epoch of evals (10,000 test images in a single batch)\n",
+" # Run one epoch of evals (10,000 test images in a single batch).\n",
 " other_state, params = state.pop('params')\n",
 " # Gotcha: must discard modified batch_stats here\n",
 " eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)\n",
@@ -453,7 +453,7 @@
 "id": "DGB3W5g0Wt1H"
 },
 "source": [
-"## Create the model, and optimizer (with Optax)"
+"## Create the model and the optimizer (with Optax)"
 ]
 },
 {
@@ -464,11 +464,11 @@
 },
 "outputs": [],
 "source": [
-"# Model\n",
+"# The model.\n",
 "model = ConvModel()\n",
 "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n",
 "\n",
-"# Optimizer\n",
+"# The optimizer.\n",
 "optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule\n",
 "optimizer_state = optimizer.init(state['params'])\n",
 "\n",
@@ -531,7 +531,7 @@
 "model = ConvModel()\n",
 "state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for \"dropout\"\n",
 "\n",
-"# Optimizer\n",
+"# The optimizer.\n",
 "optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule\n",
 "optimizer_state = optimizer.init(state['params'])\n",
 "\n",
@@ -567,7 +567,8 @@
 },
 "source": [
 "## Save just enough for inference\n",
-"If your goal is deploy your JAX model (so you can run inference using `model.predict()`), simply exporting it to [SavedModel](https://www.tensorflow.org/guide/saved_model) is sufficient. This section demonstrates how to accomplish that."
+"\n",
+"If your goal is to deploy your JAX model (so you can run inference using `model.predict()`), simply exporting it to [SavedModel](https://www.tensorflow.org/guide/saved_model) is sufficient. This section demonstrates how to accomplish that."
 ]
 },
 {
@@ -578,17 +579,17 @@
 },
 "outputs": [],
 "source": [
-"# test data with different batch size to test polymorphic shapes\n",
-"x,y = next(iter(train_data.unbatch().batch(13)))\n",
+"# Test data with a different batch size to test polymorphic shapes.\n",
+"x, y = next(iter(train_data.unbatch().batch(13)))\n",
 "\n",
 "m = tf.Module()\n",
-"# wrap JAX state in tf.Variable (needed when calling converted JAX function\n",
+"# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function).\n",
 "state_vars = tf.nest.map_structure(tf.Variable, state)\n",
-"# keep the wrapped state as flat list (needed in TF fine-tuning)\n",
+"# Keep the wrapped state as a flat list (needed in TensorFlow fine-tuning).\n",
 "m.vars = tf.nest.flatten(state_vars)\n",
-"# convert the desired JAX function (model.predict)\n",
+"# Convert the desired JAX function (`model.predict`).\n",
 "predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=[\"...\", \"(b, 28, 28, 1)\"])\n",
-"# wrap converted function in tf.function with correct TensorSpec (necessary for dynamic shapes to work)\n",
+"# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).\n",
 "@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])\n",
 "def predict(data):\n",
 " return predict_fn(state_vars, data)\n",
@@ -604,12 +605,12 @@
 },
 "outputs": [],
 "source": [
-"# test the converted function\n",
-"print(\"converted function predictions:\", np.argmax(m.predict(x).numpy(), axis=-1))\n",
-"#reload the model\n",
+"# Test the converted function.\n",
+"print(\"Converted function predictions:\", np.argmax(m.predict(x).numpy(), axis=-1))\n",
+"# Reload the model.\n",
 "reloaded_model = tf.saved_model.load(\"./\")\n",
-"# test the reloaded converted function (should be same result)\n",
-"print(\"reloaded function predictions:\", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))"
+"# Test the reloaded converted function (the result should be the same).\n",
+"print(\"Reloaded function predictions:\", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))"
 ]
 },
 {
@@ -725,10 +726,10 @@
 },
 "outputs": [],
 "source": [
-"# instantiate the model\n",
+"# Instantiate the model.\n",
 "tf_model = TFModel(state, model)\n",
 "\n",
-"# save\n",
+"# Save the model.\n",
 "tf.saved_model.save(tf_model, \"./\")"
 ]
 },
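The save call in this hunk relies on the standard tf.Module / SavedModel round trip; a minimal self-contained sketch of that round trip (the ToyModule class and the /tmp path are hypothetical, not from the notebook):

import tensorflow as tf

class ToyModule(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 2), dtype=tf.float32)])
  def predict(self, x):
    # A trivial stand-in for the converted JAX predict function.
    return 2.0 * x

tf.saved_model.save(ToyModule(), '/tmp/toy_saved_model')
reloaded = tf.saved_model.load('/tmp/toy_saved_model')
print(reloaded.predict(tf.ones((3, 2))))  # The predict signature survives the round trip.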
@@ -751,7 +752,7 @@
 "source": [
 "reloaded_model = tf.saved_model.load(\"./\")\n",
 "\n",
-"# test if it works and that the batch size is indeed variable\n",
+"# Test if it works and that the batch size is indeed variable.\n",
 "x,y = next(iter(train_data.unbatch().batch(13)))\n",
 "print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))\n",
 "x,y = next(iter(train_data.unbatch().batch(20)))\n",
@@ -780,14 +781,14 @@
 "source": [
 "optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)\n",
 "\n",
-"# set the iteration step for the LR to resume from where it left off in JAX\n",
+"# Set the iteration step for the learning rate to resume from where it left off in JAX.\n",
 "optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)\n",
 "\n",
 "p = Progress(STEPS_PER_EPOCH)\n",
 "\n",
 "for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):\n",
 "\n",
-" # this is where the learning rate schedule state is stored in the optimizer state\n",
+" # This is where the learning rate schedule state is stored in the optimizer state.\n",
 " optimizer_step = optimizer.iterations\n",
 "\n",
 " for step, (data, labels) in enumerate(train_data):\n",
