|
145 | 145 | {
|
146 | 146 | "cell_type": "markdown",
|
147 | 147 | "metadata": {
|
148 |
| - "id": "usyRSlIRl3r2" |
| 148 | + "id": "XKakQBI9-FLb" |
149 | 149 | },
|
150 | 150 | "source": [
|
151 |
| - "### Single forward pass validation \n", |
152 |
| - "\n", |
153 |
| - "Single forward pass validation, including checkpoint loading, is covered in a different [colab](./validate_correctness.ipynb)." |
| 151 | + "## Setup" |
154 | 152 | ]
|
155 | 153 | },
|
156 | 154 | {
|
157 | 155 | "cell_type": "code",
|
158 | 156 | "execution_count": null,
|
159 | 157 | "metadata": {
|
160 |
| - "id": "HVBQbsZeVL_V" |
| 158 | + "id": "sopP--i7-LaF" |
161 | 159 | },
|
162 | 160 | "outputs": [],
|
163 | 161 | "source": [
|
164 |
| - "import sys\n", |
165 |
| - "import unittest\n", |
166 |
| - "import numpy as np\n", |
167 |
| - "\n", |
168 |
| - "import tensorflow as tf\n", |
169 |
| - "import tensorflow.compat.v1 as v1" |
| 162 | + "!pip uninstall -y -q tensorflow" |
| 163 | + ] |
| 164 | + }, |
| 165 | + { |
| 166 | + "cell_type": "code", |
| 167 | + "execution_count": null, |
| 168 | + "metadata": { |
| 169 | + "id": "i1ghHyXl-Oqd" |
| 170 | + }, |
| 171 | + "outputs": [], |
| 172 | + "source": [ |
| 173 | + "# Install tf-nightly as the DeterministicRandomTestTool is only available in\n", |
| 174 | + "# Tensorflow 2.8\n", |
| 175 | + "!pip install -q tf-nightly" |
170 | 176 | ]
|
171 | 177 | },
|
172 | 178 | {
|
173 | 179 | "cell_type": "markdown",
|
174 | 180 | "metadata": {
|
175 |
| - "id": "m5XhfpDky1Kw" |
| 181 | + "id": "usyRSlIRl3r2" |
176 | 182 | },
|
177 | 183 | "source": [
|
178 |
| - "Define a context manager to control random number generation." |
| 184 | + "### Single forward pass validation \n", |
| 185 | + "\n", |
| 186 | + "Single forward pass validation, including checkpoint loading, is covered in a different [colab](./validate_correctness.ipynb)." |
179 | 187 | ]
|
180 | 188 | },
|
181 | 189 | {
|
182 | 190 | "cell_type": "code",
|
183 | 191 | "execution_count": null,
|
184 | 192 | "metadata": {
|
185 |
| - "id": "DXEUeSUvkeEV" |
| 193 | + "id": "HVBQbsZeVL_V" |
186 | 194 | },
|
187 | 195 | "outputs": [],
|
188 | 196 | "source": [
|
189 |
| - "seed_implementation = sys.modules[tf.compat.v1.get_seed.__module__]\n", |
190 |
| - "\n", |
191 |
| - "class DeterministicTestTool(object):\n", |
192 |
| - " def __init__(self, seed: int = 42, mode='constant'):\n", |
193 |
| - " \"\"\"Set mode to 'constant' or 'num_random_ops'. Defaults to 'constant'.\"\"\"\n", |
194 |
| - " if mode not in {'constant', 'num_random_ops'}:\n", |
195 |
| - " raise ValueError(\"Mode arg must be 'constant' or 'num_random_ops'. \" +\n", |
196 |
| - " \"Got: {}\".format(mode))\n", |
197 |
| - "\n", |
198 |
| - " self._mode = mode\n", |
199 |
| - " self._seed = seed\n", |
200 |
| - " self.operation_seed = 0\n", |
201 |
| - " self._observed_seeds = set()\n", |
202 |
| - "\n", |
203 |
| - " def scope(self):\n", |
204 |
| - " tf.random.set_seed(self._seed)\n", |
205 |
| - "\n", |
206 |
| - " def _get_seed(_):\n", |
207 |
| - " \"\"\"Wraps TF get_seed to make deterministic random generation easier.\n", |
208 |
| - "\n", |
209 |
| - " This makes a variable's initialization (and calls that involve random\n", |
210 |
| - " number generation) depend only on how many random number generations\n", |
211 |
| - " were used in the scope so far, rather than on how many unrelated\n", |
212 |
| - " operations the graph contains.\n", |
213 |
| - "\n", |
214 |
| - " Returns:\n", |
215 |
| - " Random seed tuple.\n", |
216 |
| - " \"\"\"\n", |
217 |
| - " op_seed = self.operation_seed\n", |
218 |
| - " if self._mode == \"constant\":\n", |
219 |
| - " tf.random.set_seed(op_seed)\n", |
220 |
| - " else:\n", |
221 |
| - " if op_seed in self._observed_seeds:\n", |
222 |
| - " raise ValueError(\n", |
223 |
| - " 'This `DeterministicTestTool` object is trying to re-use the ' +\n", |
224 |
| - " 'already-used operation seed {}. '.format(op_seed) +\n", |
225 |
| - " 'It cannot guarantee random numbers will match between eager ' +\n", |
226 |
| - " 'and sessions when an operation seed is reused. ' +\n", |
227 |
| - " 'You most likely set ' +\n", |
228 |
| - " '`operation_seed` explicitly but used a value that caused the ' +\n", |
229 |
| - " 'naturally-incrementing operation seed sequences to overlap ' +\n", |
230 |
| - " 'with an already-used seed.')\n", |
231 |
| - "\n", |
232 |
| - " self._observed_seeds.add(op_seed)\n", |
233 |
| - " self.operation_seed += 1\n", |
234 |
| - "\n", |
235 |
| - " return (self._seed, op_seed)\n", |
236 |
| - "\n", |
237 |
| - " # mock.patch internal symbols to modify the behavior of TF APIs relying on them\n", |
| 197 | + "import sys\n", |
| 198 | + "import unittest\n", |
| 199 | + "import numpy as np\n", |
238 | 200 | "\n",
|
239 |
| - " return unittest.mock.patch.object(seed_implementation, 'get_seed', wraps=_get_seed)" |
| 201 | + "import tensorflow as tf\n", |
| 202 | + "import tensorflow.compat.v1 as v1" |
240 | 203 | ]
|
241 | 204 | },
|
242 | 205 | {
|
|
362 | 325 | " model_tf1.logs[key].append(logs[key])"
|
363 | 326 | ]
|
364 | 327 | },
|
| 328 | + { |
| 329 | + "cell_type": "markdown", |
| 330 | + "metadata": { |
| 331 | + "id": "kki9yILSKS7f" |
| 332 | + }, |
| 333 | + "source": [ |
| 334 | + "The following [`v1.keras.utils.DeterministicRandomTestTool`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/keras/utils/DeterministicRandomTestTool) class provides a context manager `scope()` that can make stateful random operations use the same seed across both TF1 graphs/sessions and eager execution,\n", |
| 335 | + "\n", |
| 336 | + "The tool provides two testing modes: \n", |
| 337 | + "1. `constant` which uses the same seed for every single operation no matter how many times it has been called and,\n", |
| 338 | + "2. `num_random_ops` which uses the number of previously-observed stateful random operations as the operation seed.\n", |
| 339 | + "\n", |
| 340 | + "This applies both to the stateful random operations used for creating and initializing variables, and to the stateful random operations used in computation (such as for dropout layers)." |
| 341 | + ] |
| 342 | + }, |
| 343 | + { |
| 344 | + "cell_type": "code", |
| 345 | + "execution_count": null, |
| 346 | + "metadata": { |
| 347 | + "id": "X6Y3RWMoKOl8" |
| 348 | + }, |
| 349 | + "outputs": [], |
| 350 | + "source": [ |
| 351 | + "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')" |
| 352 | + ] |
| 353 | + }, |
365 | 354 | {
|
366 | 355 | "cell_type": "markdown",
|
367 | 356 | "metadata": {
|
|
379 | 368 | },
|
380 | 369 | "outputs": [],
|
381 | 370 | "source": [
|
382 |
| - "random_tool = DeterministicTestTool(mode='num_random_ops')\n", |
383 | 371 | "with random_tool.scope():\n",
|
384 | 372 | " graph = tf.Graph()\n",
|
385 | 373 | " with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:\n",
|
|
483 | 471 | },
|
484 | 472 | "outputs": [],
|
485 | 473 | "source": [
|
486 |
| - "random_tool = DeterministicTestTool(mode='num_random_ops')\n", |
| 474 | + "random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", |
487 | 475 | "with random_tool.scope():\n",
|
488 | 476 | " model_tf2 = SimpleModel(params)\n",
|
489 | 477 | " for step in range(step_num):\n",
|
|
562 | 550 | " step_num = 100\n",
|
563 | 551 | "\n",
|
564 | 552 | " # setup TF 1 model\n",
|
565 |
| - " random_tool = DeterministicTestTool(mode='num_random_ops')\n", |
| 553 | + " random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", |
566 | 554 | " with random_tool.scope():\n",
|
567 | 555 | " # run TF1.x code in graph mode with context management\n",
|
568 | 556 | " graph = tf.Graph()\n",
|
|
583 | 571 | " self.model_tf1.update_logs(logs)\n",
|
584 | 572 | "\n",
|
585 | 573 | " # setup TF2 model\n",
|
586 |
| - " random_tool = DeterministicTestTool(mode='num_random_ops')\n", |
| 574 | + " random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')\n", |
587 | 575 | " with random_tool.scope():\n",
|
588 | 576 | " self.model_tf2 = SimpleModel(params)\n",
|
589 | 577 | " for step in range(step_num):\n",
|
|
811 | 799 | "colab": {
|
812 | 800 | "collapsed_sections": [],
|
813 | 801 | "name": "migration_debugging.ipynb",
|
814 |
| - "provenance": [], |
815 | 802 | "toc_visible": true
|
816 | 803 | },
|
817 | 804 | "kernelspec": {
|
|
0 commit comments