91 | 91 | },
92 | 92 | "outputs": [],
93 | 93 | "source": [
94 |    | - "# Update TensorFlow, as this notebook requires version 2.9 or later\n",
95 |    | - "!pip install -q -U tensorflow>=2.9.0\n",
96 | 94 | "import tensorflow as tf"
97 | 95 | ]
98 | 96 | },

362 | 360 | "#### Rules of tracing\n",
363 | 361 | "#### Rules of tracing\n",
364 | 362 | "\n",
365 |     | - "When called, a `Function` matches the call arguments to existing `ConcreteFunction`s using `tf.types.experimental.TraceType` of each argument. If a matching `ConcreteFunction` is found, the call is dispatched to it. If no match is found, a new `ConcreteFunction` is traced. \n",
    | 363 | + "When called, a `Function` matches the call arguments to existing `ConcreteFunction`s using `tf.types.experimental.TraceType` of each argument. If a matching `ConcreteFunction` is found, the call is dispatched to it. If no match is found, a new `ConcreteFunction` is traced.\n",
366 | 364 | "\n",
367 | 365 | "If multiple matches are found, the most specific signature is chosen. Matching is done by [subtyping](https://en.wikipedia.org/wiki/Subtyping), much like normal function calls in C++ or Java, for instance. For example, `TensorShape([1, 2])` is a subtype of `TensorShape([None, None])`, and so a call to the tf.function with `TensorShape([1, 2])` can be dispatched to the `ConcreteFunction` produced with `TensorShape([None, None])`, but if a `ConcreteFunction` with `TensorShape([1, None])` also exists, then it will be prioritized since it is more specific.\n",
368 | 366 | "\n",

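To make the subtyping rule concrete, here is a minimal sketch (not part of the notebook; it assumes TF 2.x and uses `experimental_get_tracing_count` to observe retracing). A `Function` traced for `TensorShape([None])` handles any 1-D tensor without retracing, because each concrete 1-D shape is a subtype of the relaxed one:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def double(x):
  return x * 2

# Shapes [2] and [3] are both subtypes of TensorShape([None]), so every call
# dispatches to the single ConcreteFunction traced for the relaxed signature.
print(double(tf.constant([1, 2])))
print(double(tf.constant([7, 8, 9])))
print(double.experimental_get_tracing_count())  # expected: 1
```
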
422 | 420 | "\n",
423 | 421 | "print(next_collatz(tf.constant([1, 2])))\n",
424 | 422 | "# You specified a 1-D tensor in the input signature, so this should fail.\n",
425 |     | - "with assert_raises(ValueError):\n",
    | 423 | + "with assert_raises(TypeError):\n",
426 | 424 | "  next_collatz(tf.constant([[1, 2], [3, 4]]))\n",
427 | 425 | "\n",
428 | 426 | "# You specified an int32 dtype in the input signature, so this should fail.\n",
429 |     | - "with assert_raises(ValueError):\n",
    | 427 | + "with assert_raises(TypeError):\n",
430 | 428 | "  next_collatz(tf.constant([1.0, 2.0]))\n"
431 | 429 | ]
432 | 430 | },

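For context on the `ValueError` to `TypeError` change above, a self-contained sketch of the asserted behavior (the function name `identity` is illustrative; the exception type has varied across TF releases, with recent versions raising `TypeError` when an argument cannot be bound to the declared `input_signature`):

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def identity(x):
  return x

try:
  identity(tf.constant([[1, 2], [3, 4]]))  # rank 2 does not match the 1-D spec
except TypeError as e:
  print("Rejected:", type(e).__name__)
```
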
560 | 558 | "  flavor = tf.constant([3, 4])\n",
561 | 559 | "\n",
562 | 560 | "# As described in the above rules, a generic TraceType for `Apple` and `Mango`\n",
563 |     | - "# is generated (and a corresponding ConcreteFunction is traced) but it fails to \n",
564 |     | - "# match the second function call since the first pair of Apple() and Mango() \n",
    | 561 | + "# is generated (and a corresponding ConcreteFunction is traced) but it fails to\n",
    | 562 | + "# match the second function call since the first pair of Apple() and Mango()\n",
565 | 563 | "# have gone out of scope by then and been deleted.\n",
566 | 564 | "get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function\n",
567 | 565 | "get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again\n",

591 | 589 | "\n",
592 | 590 | "  def __eq__(self, other):\n",
593 | 591 | "    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n",
594 |     | - "  \n",
    | 592 | + "\n",
595 | 593 | "  def __hash__(self):\n",
596 | 594 | "    return hash(self.fruit_type)\n",
597 | 595 | "\n",

970 | 968 | "id": "JeD2U-yrbfVb"
971 | 969 | },
972 | 970 | "source": [
973 |     | - "When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus ` tf.data.Dataset.from_tensors`. The former will keep the data in Python and fetch it via `tf.py_function` which can have performance implications, whereas the latter will bundle a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.\n",
    | 971 | + "When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensor_slices`. The former will keep the data in Python and fetch it via `tf.py_function` which can have performance implications, whereas the latter will bundle a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.\n",
974 | 972 | "\n",
975 | 973 | "Reading data from files via `TFRecordDataset`, `CsvDataset`, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the [`tf.data`: Build TensorFlow input pipelines](../../guide/data) guide."
976 | 974 | ]

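A minimal sketch contrasting the two wrapping strategies described above (the variable names are illustrative): `from_generator` keeps the NumPy array on the Python side and pulls elements through `tf.py_function`, while `from_tensor_slices` embeds a copy of the array in the graph as a constant.

```python
import numpy as np
import tensorflow as tf

data = np.arange(10, dtype=np.int64)

# Data stays in Python; each element is fetched via tf.py_function at runtime.
ds_generator = tf.data.Dataset.from_generator(
    lambda: iter(data),
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int64))

# A copy of `data` is embedded in the graph as one large constant node.
ds_slices = tf.data.Dataset.from_tensor_slices(data)

print(list(ds_generator.take(3).as_numpy_iterator()))  # [0, 1, 2]
print(list(ds_slices.take(3).as_numpy_iterator()))     # [0, 1, 2]
```
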
1608 | 1606 | "new_model = SimpleModel()\n",
1609 | 1607 | "evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n",
1610 | 1608 | "# Don't pass in `new_model`, `Function` already captured its state during tracing.\n",
1611 |      | - "print(evaluate_no_bias(x)) "
     | 1609 | + "print(evaluate_no_bias(x))"
1612 | 1610 | ]
1613 | 1611 | },
1614 | 1612 | {

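A self-contained sketch of the capture behavior noted in the comment above (the `ScalarModel` and `scale` names are illustrative, not from the notebook): the non-tensor `model` argument is baked into the `ConcreteFunction` at tracing time, so later calls pass only the tensor input.

```python
import tensorflow as tf

class ScalarModel(tf.Module):
  def __init__(self):
    self.weight = tf.Variable(2.0)

def scale(model, x):
  return model.weight * x

model = ScalarModel()
x = tf.constant(3.0)

# `model` is captured during tracing; the ConcreteFunction takes only tensors.
scale_fn = tf.function(scale).get_concrete_function(model, x)
print(scale_fn(x))  # tf.Tensor(6.0, shape=(), dtype=float32)
```
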
1752 | 1750 | "source": [
1753 | 1751 | "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n",
1754 | 1752 | "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n",
1755 |      | - " \n",
     | 1753 | + "\n",
1756 | 1754 | "@tf.function\n",
1757 | 1755 | "def train_step(w, x, y, optimizer):\n",
1758 | 1756 | "  with tf.GradientTape() as tape:\n",

1802 | 1800 | "y = tf.constant([2.])\n",
1803 | 1801 | "\n",
1804 | 1802 | "# Make a new Function and ConcreteFunction for each optimizer.\n",
1805 |      | - "train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)\n",
1806 |      | - "train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)\n",
     | 1803 | + "train_step_1 = tf.function(train_step)\n",
     | 1804 | + "train_step_2 = tf.function(train_step)\n",
1807 | 1805 | "for i in range(10):\n",
1808 | 1806 | "  if i % 2 == 0:\n",
1809 |      | - "    train_step_1(w, x, y) # `opt1` is not used as a parameter. \n",
     | 1807 | + "    train_step_1(w, x, y, opt1)\n",
1810 | 1808 | "  else:\n",
1811 |      | - "    train_step_2(w, x, y) # `opt2` is not used as a parameter."
     | 1809 | + "    train_step_2(w, x, y, opt2)"
1812 | 1810 | ]
1813 | 1811 | },
1814 | 1812 | {

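As a sanity check on the updated pattern, a self-contained sketch that counts traces (the loss inside `train_step` is illustrative, and the function is left undecorated here so it can be wrapped twice): because each `Function` only ever sees one optimizer object, each traces exactly once and is then reused.

```python
import tensorflow as tf

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)

def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.abs(w * x - y))
  gradients = tape.gradient(loss, [w])
  optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# One Function per optimizer, so neither retraces inside the loop.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)

print(train_step_1.experimental_get_tracing_count())  # expected: 1
print(train_step_2.experimental_get_tracing_count())  # expected: 1
```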