|
572 | 572 | "# can have significant performance benefits.\n",
|
573 | 573 | "\n",
|
574 | 574 | "class FruitTraceType(tf.types.experimental.TraceType):\n",
|
575 |
| - " def __init__(self, fruit_type):\n", |
576 |
| - " self.fruit_type = fruit_type\n", |
| 575 | + " def __init__(self, fruit):\n", |
| 576 | + " self.fruit_type = type(fruit)\n", |
| 577 | + " self.fruit_value = fruit\n", |
577 | 578 | "\n",
|
578 | 579 | " def is_subtype_of(self, other):\n",
|
| 580 | + " # True if self subtypes `other` and `other`'s type matches FruitTraceType.\n", |
579 | 581 | " return (type(other) is FruitTraceType and\n",
|
580 | 582 | " self.fruit_type is other.fruit_type)\n",
|
581 | 583 | "\n",
|
582 | 584 | " def most_specific_common_supertype(self, others):\n",
|
| 585 | + " # `self` is the specific common supertype if all input types match it.\n", |
583 | 586 | " return self if all(self == other for other in others) else None\n",
|
584 | 587 | "\n",
|
| 588 | + " def placeholder_value(self, placeholder_context=None):\n", |
| 589 | + " # Use the fruit itself instead of the type for correct tracing.\n", |
| 590 | + " return self.fruit_value\n", |
| 591 | + "\n", |
585 | 592 | " def __eq__(self, other):\n",
|
586 | 593 | " return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n",
|
587 | 594 | " \n",
|
|
591 | 598 | "class FruitWithTraceType:\n",
|
592 | 599 | "\n",
|
593 | 600 | " def __tf_tracing_type__(self, context):\n",
|
594 |
| - " return FruitTraceType(type(self))\n", |
| 601 | + " return FruitTraceType(self)\n", |
595 | 602 | "\n",
|
596 | 603 | "class AppleWithTraceType(FruitWithTraceType):\n",
|
597 | 604 | " flavor = tf.constant([1, 2])\n",
|
|
1831 | 1838 | ],
|
1832 | 1839 | "metadata": {
|
1833 | 1840 | "colab": {
|
1834 |
| - "collapsed_sections": [], |
1835 | 1841 | "name": "function.ipynb",
|
| 1842 | + "provenance": [], |
1836 | 1843 | "toc_visible": true
|
1837 | 1844 | },
|
1838 | 1845 | "kernelspec": {
|
|
0 commit comments