|
74 | 74 | "\n",
|
75 | 75 | "DTensor provides a global programming model that allows developers to compose applications that operate on Tensors globally while managing the distribution across devices internally. DTensor distributes the program and tensors according to the sharding directives through a procedure called *[Single program, multiple data (SPMD)](https://en.wikipedia.org/wiki/SPMD) expansion*.\n",
|
76 | 76 | "\n",
|
77 |
| - "By decoupling the application from sharding directives, DTensor enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics. \n", |
| 77 | + "By decoupling the application from sharding directives, DTensor enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.\n", |
78 | 78 | "\n",
|
79 | 79 | "This guide introduces DTensor concepts for distributed computing, and how DTensor integrates with TensorFlow. To see a demo of using DTensor in model training, see [Distributed training with DTensor](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial) tutorial."
|
80 | 80 | ]
|
|
130 | 130 | " tf.config.set_logical_device_configuration(phy_devices[0], [\n",
|
131 | 131 | " tf.config.LogicalDeviceConfiguration(),\n",
|
132 | 132 | " ] * ncpu)\n",
|
133 |
| - " \n", |
| 133 | + "\n", |
134 | 134 | "configure_virtual_cpus(6)\n",
|
135 | 135 | "DEVICES = [f'CPU:{i}' for i in range(6)]\n",
|
136 | 136 | "\n",
|
|
147 | 147 | "\n",
|
148 | 148 | "DTensor introduces two concepts: `dtensor.Mesh` and `dtensor.Layout`. They are abstractions to model the sharding of tensors across topologically related devices.\n",
|
149 | 149 | "\n",
|
150 |
| - "- `Mesh` defines the device list for computation. \n", |
151 |
| - "- `Layout` defines how to shard the Tensor dimension on a `Mesh`. " |
| 150 | + "- `Mesh` defines the device list for computation.\n", |
| 151 | + "- `Layout` defines how to shard the axes of a Tensor on a `Mesh`." |
152 | 152 | ]
|
153 | 153 | },
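As a quick, hedged illustration of these two concepts (a sketch only; it assumes the six virtual CPUs and the `DEVICES` list configured above, and the names `mesh_1d`, `mesh_2d`, and `example_layout` are introduced here just for illustration):

```python
# A 1-D mesh: all six virtual CPUs along a single mesh dimension named 'x'.
mesh_1d = dtensor.create_mesh([("x", 6)], devices=DEVICES)

# A 2-D mesh: 3 devices along 'x' and 2 devices along 'y'.
mesh_2d = dtensor.create_mesh([("x", 3), ("y", 2)], devices=DEVICES)

# A rank-2 Layout on the 2-D mesh: the first tensor axis is sharded along the
# 'x' mesh dimension, the second axis is left unsharded (replicated).
example_layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh_2d)
```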
|
154 | 154 | {
|
|
223 | 223 | "\n",
|
224 | 224 | "**`Layout`** specifies how a tensor is distributed, or sharded, on a `Mesh`. In DTensor, the placement specification is on a per-axis bases. An axis of a `Layout` can be either `sharded` or `unsharded` (replicated) along a mesh dimension.\n",
|
225 | 225 | "\n",
|
226 |
| - "Note: In order to avoid confusions between `Mesh` and `Layout`, the term *dimension* is always associated with `Mesh`, and the term *axis* with `Tensor` and `Layout` in this guide. \n", |
| 226 | + "Note: In order to avoid confusions between `Mesh` and `Layout`, the term *dimension* is always associated with `Mesh`, and the term *axis* with `Tensor` and `Layout` in this guide.\n", |
227 | 227 | "\n",
|
228 | 228 | "The rank of a `Layout` and the number of dimensions of a `Mesh` do not need to match. The `unsharded` axes of a `Layout` do not need to be associated to a mesh dimension, and `unsharded` mesh dimensions do not need to be associated with a `layout` axis.\n",
|
229 | 229 | "\n",
|
|
298 | 298 | "\n",
|
299 | 299 | "DTensor supports both single-client and multi-client applications. The colab Python kernel is an example of a single client DTensor application, where there is a single Python process.\n",
|
300 | 300 | "\n",
|
301 |
| - "In a multi-client DTensor application, multiple Python processes collectively perform as a coherent application. The Cartisian grid of a `Mesh` in a multi-client DTensor application can span across devices regardless of whether they are attached locally to the current client or attached remotely to another client. The set of all devices used by a `Mesh` are called the *global device list*. \n", |
| 301 | + "In a multi-client DTensor application, multiple Python processes collectively perform as a coherent application. The Cartisian grid of a `Mesh` in a multi-client DTensor application can span across devices regardless of whether they are attached locally to the current client or attached remotely to another client. The set of all devices used by a `Mesh` are called the *global device list*.\n", |
302 | 302 | "\n",
|
303 |
| - "The creation of a `Mesh` in a multi-client DTensor application is a collective operation where the *global device list* is identicial for all of the participating clients, and the creation of the `Mesh` serves as a global barrier. \n", |
| 303 | + "The creation of a `Mesh` in a multi-client DTensor application is a collective operation where the *global device list* is identicial for all of the participating clients, and the creation of the `Mesh` serves as a global barrier.\n", |
304 | 304 | "\n",
|
305 | 305 | "During `Mesh` creation, each client provides its *local device list* together with the expected *global device list*. DTensor validates that both lists are consistent. Please refer to the API documentation for `dtensor.create_mesh` and `dtensor.create_distributed_mesh`\n",
|
306 | 306 | " for more information on multi-client mesh creation and the *global device list*.\n",
|
|
331 | 331 | "source": [
|
332 | 332 | "def dtensor_from_array(arr, layout, shape=None, dtype=None):\n",
|
333 | 333 | "  \"\"\"Construct a DTensor from something that looks like an array or Tensor.\n",
|
334 |
| - " \n", |
| 334 | + "\n", |
335 | 335 | "  This function is convenient for quickly sketching DTensors from a known,\n",
|
336 | 336 | " unsharded data object in a single-client environment. This is not the\n",
|
337 |
| - " most efficient way of creating a DTensor, but it will do for this \n", |
| 337 | + " most efficient way of creating a DTensor, but it will do for this\n", |
338 | 338 | " tutorial.\n",
|
339 | 339 | " \"\"\"\n",
|
340 | 340 | " if shape is not None or dtype is not None:\n",
|
341 | 341 | " arr = tf.constant(arr, shape=shape, dtype=dtype)\n",
|
342 |
| - " \n", |
| 342 | + "\n", |
343 | 343 | " # replicate the input to the mesh\n",
|
344 |
| - " a = dtensor.copy_to_mesh(arr, \n", |
| 344 | + " a = dtensor.copy_to_mesh(arr,\n", |
345 | 345 | " layout=dtensor.Layout.replicated(layout.mesh, rank=layout.rank))\n",
|
346 | 346 | "  # shard the copy to the desired layout\n",
|
347 | 347 | " return dtensor.relayout(a, layout=layout)"
|
|
356 | 356 | "### Anatomy of a DTensor\n",
|
357 | 357 | "\n",
|
358 | 358 | "A DTensor is a `tf.Tensor` object, but augumented with the `Layout` annotation that defines its sharding behavior. A DTensor consists of the following:\n",
|
359 |
| - " \n", |
| 359 | + "\n", |
360 | 360 | " - Global tensor meta-data, including the global shape and dtype of the tensor.\n",
|
361 | 361 | " - A `Layout`, which defines the `Mesh` the `Tensor` belongs to, and how the `Tensor` is sharded onto the `Mesh`.\n",
|
362 |
| - " - A list of **component tensors**, one item per local device in the `Mesh`. \n", |
363 |
| - " \n", |
| 362 | + " - A list of **component tensors**, one item per local device in the `Mesh`.\n", |
| 363 | + "\n", |
364 | 364 | "With `dtensor_from_array`, you can create your first DTensor, `my_first_dtensor`, and examine its contents."
|
365 | 365 | ]
|
366 | 366 | },
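The sketch below inspects those three pieces. It assumes the 1-D `mesh_1d` (mesh dimension `'x'` of size 6) introduced in the earlier sketch, and uses the `dtensor_from_array` helper defined above:

```python
my_first_dtensor = dtensor_from_array(
    tf.range(6), dtensor.Layout(["x"], mesh_1d))

# 1. Global meta-data: the global shape and dtype.
print(my_first_dtensor.shape, my_first_dtensor.dtype)

# 2. The Layout: which Mesh the DTensor lives on and how it is sharded.
print(dtensor.fetch_layout(my_first_dtensor))

# 3. The component tensors: one per local device in the Mesh.
for component_tensor in dtensor.unpack(my_first_dtensor):
  print(component_tensor)
```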
|
|
450 | 450 | "source": [
|
451 | 451 | "The inverse operation of `dtensor.unpack` is `dtensor.pack`. Component tensors can be packed back into a DTensor.\n",
|
452 | 452 | "\n",
|
453 |
| - "The components must have the same rank and dtype, which will be the rank and dtype of the returned DTensor. However there is no strict requirement on the device placement of component tensors as inputs of `dtensor.unpack`: the function will automatically copy the component tensors to their respective corresponding devices. \n" |
| 453 | + "The components must have the same rank and dtype, which will be the rank and dtype of the returned DTensor. However there is no strict requirement on the device placement of component tensors as inputs of `dtensor.unpack`: the function will automatically copy the component tensors to their respective corresponding devices.\n", |
| 454 | + "\n" |
454 | 455 | ]
|
455 | 456 | },
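A minimal round trip, as a sketch that again assumes the 1-D `mesh_1d` from the earlier sketch and the `dtensor_from_array` helper:

```python
d = dtensor_from_array(tf.range(6), dtensor.Layout(["x"], mesh_1d))

# unpack: DTensor -> list of component tensors, one per local device.
components = dtensor.unpack(d)

# pack: component tensors -> DTensor, given the target layout. The components
# are copied to their corresponding devices if they are not already there.
repacked = dtensor.pack(components, dtensor.fetch_layout(d))
```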
|
456 | 457 | {
|
|
504 | 505 | "Create a 3x2 rank-2 DTensor, sharding its first axis along the `'x'` mesh dimension, and its second axis along the `'y'` mesh dimension.\n",
|
505 | 506 | "\n",
|
506 | 507 | "- Because the tensor shape equals to the mesh dimension along all of the sharded axes, each device receives a single element of the DTensor.\n",
|
507 |
| - "- The rank of the component tensor is always the same as the rank of the global shape. DTensor adopts this convention as a simple way to preserve information for locating the relation between a component tensor and the global DTensor. " |
| 508 | + "- The rank of the component tensor is always the same as the rank of the global shape. DTensor adopts this convention as a simple way to preserve information for locating the relation between a component tensor and the global DTensor." |
508 | 509 | ]
|
509 | 510 | },
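A sketch of that construction (the guide's own cell does the equivalent), assuming the 3x2 `mesh_2d` from the earlier sketch:

```python
fully_sharded_dtensor = dtensor_from_array(
    tf.reshape(tf.range(6), (3, 2)),
    dtensor.Layout(["x", "y"], mesh_2d))

for component_tensor in dtensor.unpack(fully_sharded_dtensor):
  # Each component is a 1x1 slice: rank 2, the same rank as the global shape.
  print(component_tensor)
```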
|
510 | 511 | {
|
|
567 | 568 | "DTensor allows a `Layout` to be a hybrid, sharded along some axes, but replicated along others.\n",
|
568 | 569 | "\n",
|
569 | 570 | "For example, you can shard the same 3x2 rank-2 DTensor in the following way:\n",
|
570 |
| - " \n", |
| 571 | + "\n", |
571 | 572 | " - 1st axis sharded along the `'x'` mesh dimension.\n",
|
572 | 573 | " - 2nd axis replicated along the `'y'` mesh dimension.\n",
|
573 |
| - " \n", |
| 574 | + "\n", |
574 | 575 | "To achieve this sharding scheme, you just need to replace the sharding spec of the 2nd axis from `'y'` to `dtensor.UNSHARDED`, to indicate your intention of replicating along the 2nd axis. The layout object will look like `Layout(['x', dtensor.UNSHARDED], mesh)`."
|
575 | 576 | ]
|
576 | 577 | },
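A sketch of the hybrid layout, again assuming the 3x2 `mesh_2d` from the earlier sketch:

```python
hybrid_sharded_dtensor = dtensor_from_array(
    tf.reshape(tf.range(6), (3, 2)),
    dtensor.Layout(["x", dtensor.UNSHARDED], mesh_2d))

for component_tensor in dtensor.unpack(hybrid_sharded_dtensor):
  # Each component is a 1x2 row of the global tensor; the same row appears on
  # both devices along the replicated 'y' mesh dimension.
  print(component_tensor)
```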
|
|
610 | 611 | "source": [
|
611 | 612 | "#### Tensor.numpy() and sharded DTensor\n",
|
612 | 613 | "\n",
|
613 |
| - "Be aware that calling the `.numpy()` method on a sharded DTensor raises an error. The rationale for erroring is to protect against unintended gathering of data from multiple computing devices to the host CPU device backing the returned numpy array. " |
| 614 | + "Be aware that calling the `.numpy()` method on a sharded DTensor raises an error. The rationale for erroring is to protect against unintended gathering of data from multiple computing devices to the host CPU device backing the returned numpy array." |
614 | 615 | ]
|
615 | 616 | },
|
616 | 617 | {
|
|
624 | 625 | "print(fully_replicated_dtensor.numpy())\n",
|
625 | 626 | "\n",
|
626 | 627 | "try:\n",
|
627 |
| - " fully_sharded_dtensor.numpy() \n", |
| 628 | + " fully_sharded_dtensor.numpy()\n", |
628 | 629 | "except tf.errors.UnimplementedError:\n",
|
629 | 630 | " print(\"got an error as expected for fully_sharded_dtensor\")\n",
|
630 | 631 | "\n",
|
631 | 632 | "try:\n",
|
632 |
| - " hybrid_sharded_dtensor.numpy() \n", |
| 633 | + " hybrid_sharded_dtensor.numpy()\n", |
633 | 634 | "except tf.errors.UnimplementedError:\n",
|
634 | 635 | " print(\"got an error as expected for hybrid_sharded_dtensor\")"
|
635 | 636 | ]
|
|
650 | 651 | " - Rewriting TensorFlow Ops on the global DTensor with equivalent TensorFlow Ops on the component tensors, inserting collective and communication Ops when necessary\n",
|
651 | 652 | " - Lowering backend-neutral TensorFlow Ops to backend-specific TensorFlow Ops.\n",
|
652 | 653 | "\n",
|
653 |
| - "The final result is that **DTensor is a drop-in replacement for Tensor**. \n", |
| 654 | + "The final result is that **DTensor is a drop-in replacement for Tensor**.\n", |
654 | 655 | "\n",
|
655 | 656 | "Note: DTensor is still an experimental API which means you will be exploring and pushing the boundaries and limits of the DTensor programming model.\n",
|
656 | 657 | "\n",
|
657 |
| - "There are 2 ways of triggering DTensor execution: \n", |
| 658 | + "There are 2 ways of triggering DTensor execution:\n", |
658 | 659 | " - DTensor as operands of a Python function, e.g. `tf.matmul(a, b)` will run through DTensor if `a`, `b`, or both are DTensors.\n",
|
659 | 660 | " - Requesting the result of a Python function to be a DTensor, e.g. `dtensor.call_with_layout(tf.ones, layout, shape=(3, 2))` will run through DTensor because we requested the output of `tf.ones` to be sharded according to a `layout`."
|
660 | 661 | ]
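A compact sketch of both trigger modes, assuming the 3x2 `mesh_2d` and the `dtensor_from_array` helper from the earlier sketches; the variable names `a`, `b`, `c`, and `ones` are illustrative:

```python
replicated_layout = dtensor.Layout.replicated(mesh_2d, rank=2)

# Mode 1: DTensor operands. tf.matmul runs through DTensor because the inputs
# carry layouts; the result is also a DTensor.
a = dtensor_from_array([[1., 2., 3.], [4., 5., 6.]], replicated_layout)
b = dtensor_from_array([[6., 5.], [4., 3.], [2., 1.]], replicated_layout)
c = tf.matmul(a, b)

# Mode 2: requested output layout. tf.ones itself knows nothing about DTensor,
# but call_with_layout turns its output into a DTensor with the given layout.
ones = dtensor.call_with_layout(
    tf.ones, dtensor.Layout(["x", "y"], mesh_2d), shape=(3, 2))
```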
|
|
678 | 679 | "source": [
|
679 | 680 | "#### Fully replicated input and output\n",
|
680 | 681 | "\n",
|
681 |
| - "In this case, the DTensors are fully replicated. On each of the devices of the `Mesh`, \n", |
| 682 | + "In this case, the DTensors are fully replicated. On each of the devices of the `Mesh`,\n", |
682 | 683 | " - the component tensor for operand `a` is `[[1, 2, 3], [4, 5, 6]]` (2x3)\n",
|
683 | 684 | " - the component tensor for operand `b` is `[[6, 5], [4, 3], [2, 1]]` (3x2)\n",
|
684 |
| - " - the computation consists of a single `MatMul` of `(2x3, 3x2) -> 2x2`, \n", |
| 685 | + " - the computation consists of a single `MatMul` of `(2x3, 3x2) -> 2x2`,\n", |
685 | 686 | " - the component tensor for result `c` is `[[20, 14], [56,41]]` (2x2)\n",
|
686 | 687 | "\n",
|
687 | 688 | "Total number of floating point mul operations is `6 device * 4 result * 3 mul = 72`."
|
|
792 | 793 | "\n",
|
793 | 794 | "What about Python functions that do not take operands, but returns a Tensor result that can be sharded? Examples of such functions are\n",
|
794 | 795 | "\n",
|
795 |
| - " - `tf.ones`, `tf.zeros`, `tf.random.stateless_normal`, \n", |
| 796 | + " - `tf.ones`, `tf.zeros`, `tf.random.stateless_normal`.\n", |
796 | 797 | "\n",
|
797 | 798 | "For these Python functions, DTensor provides `dtensor.call_with_layout` which eagelry executes a Python function with DTensor, and ensures that the returned Tensor is a DTensor with the requested `Layout`."
|
798 | 799 | ]
|
|
885 | 886 | "outputs": [],
|
886 | 887 | "source": [
|
887 | 888 | "ones = dtensor.call_with_layout(\n",
|
888 |
| - " tf.function(tf.random.stateless_normal), \n", |
889 |
| - " dtensor.Layout(['x', 'y'], mesh), \n", |
890 |
| - " shape=(6, 4), \n", |
| 889 | + " tf.function(tf.random.stateless_normal),\n", |
| 890 | + " dtensor.Layout(['x', 'y'], mesh),\n", |
| 891 | + " shape=(6, 4),\n", |
891 | 892 | " seed=(1, 1))\n",
|
892 | 893 | "print(ones)"
|
893 | 894 | ]
|
|
910 | 911 | "outputs": [],
|
911 | 912 | "source": [
|
912 | 913 | "ones = dtensor.call_with_layout(\n",
|
913 |
| - " tf.function(tf.ones), \n", |
914 |
| - " dtensor.Layout(['x', 'y'], mesh), \n", |
| 914 | + " tf.function(tf.ones),\n", |
| 915 | + " dtensor.Layout(['x', 'y'], mesh),\n", |
915 | 916 | " shape=(6, 4))\n",
|
916 | 917 | "print(ones)"
|
917 | 918 | ]
|
|
925 | 926 | "### From `tf.Variable` to `dtensor.DVariable`\n",
|
926 | 927 | "\n",
|
927 | 928 | "In Tensorflow, `tf.Variable` is the holder for a mutable `Tensor` value.\n",
|
928 |
| - "With DTensor, the corresponding variable semantics is provided by `dtensor.DVariable`. \n", |
| 929 | + "With DTensor, the corresponding variable semantics is provided by `dtensor.DVariable`.\n", |
929 | 930 | "\n",
|
930 | 931 | "The reason a new type `DVariable` was introduced for DTensor variable is because DVariables have an additional requirement that the layout cannot change from its initial value."
|
931 | 932 | ]
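A minimal, self-contained sketch of creating a `DVariable`, assuming the 3x2 `mesh_2d` from the earlier sketch; the `[64, 32]` shape and the fully replicated layout are arbitrary choices for illustration:

```python
replicated_layout = dtensor.Layout.replicated(mesh_2d, rank=2)

v = dtensor.DVariable(
    initial_value=dtensor.call_with_layout(
        tf.function(tf.zeros),
        layout=replicated_layout,
        shape=[64, 32]))

# The variable keeps this layout for its whole lifetime; unlike a plain
# tf.Variable's value, the layout cannot be changed after initialization.
print(v.shape)
```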
|
|
945 | 946 | " initial_value=dtensor.call_with_layout(\n",
|
946 | 947 | " tf.function(tf.random.stateless_normal),\n",
|
947 | 948 | " layout=layout,\n",
|
948 |
| - " shape=tf.TensorShape([64, 32]), \n", |
| 949 | + " shape=tf.TensorShape([64, 32]),\n", |
949 | 950 | " seed=[1, 1],\n",
|
950 | 951 | " dtype=tf.float32))\n",
|
951 | 952 | "\n",
|
|