From ca914452566c811b8bce4eec350d88480d874937 Mon Sep 17 00:00:00 2001 From: Zeev Melumian Date: Fri, 18 Jul 2025 22:42:25 +0300 Subject: [PATCH 001/133] Add support for callable in torchax.interop.JittableModule.functional_call in the first parameter (#9451) Co-authored-by: zmelumian --- torchax/test/test_jittable_module.py | 19 +++++++++++++++++++ torchax/torchax/interop.py | 14 ++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/torchax/test/test_jittable_module.py b/torchax/test/test_jittable_module.py index 52bbbbcf7ab7..3a44450cbe8f 100644 --- a/torchax/test/test_jittable_module.py +++ b/torchax/test/test_jittable_module.py @@ -34,6 +34,25 @@ def test_isinstance_does_not_mix(self): assert isinstance(JittableMoreAwesomeModel, EvenMoreAwesomeModel) assert not isinstance(JittableMoreAwesomeModel, MyAwesomeModel) + def test_functional_call_callable(self): + + def outer_function(model, x): + return x + 1 + + model = MyAwesomeModel() + jittable_module = interop.JittableModule(model) + + # Check if the jittable module can be called like a function + input_tensor = torch.randn(1, 3, 224, 224) + expected_output = input_tensor + 1 + + output = jittable_module.functional_call(outer_function, + jittable_module.params, + jittable_module.buffers, + input_tensor) + + assert torch.equal(output, expected_output) + if __name__ == '__main__': unittest.main() diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index 8d2b01bc6398..bceb008ef4b0 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -90,7 +90,7 @@ def __class__(self): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def functional_call(self, method_name, params, buffers, *args, **kwargs): + def functional_call(self, method_or_name, params, buffers, *args, **kwargs): kwargs = kwargs or {} params_copy = copy.copy(params) params_copy.update(buffers) @@ -98,8 +98,18 @@ def functional_call(self, method_name, params, buffers, *args, **kwargs): for k, v in self._extra_dumped_weights.items(): for new_key in v: params_copy[new_key] = params_copy[k] + + if isinstance(method_or_name, str): + method = getattr(self._model, method_or_name) + else: + if not callable(method_or_name): + raise TypeError( + f"method_or_name should be a callable or a string, got {type(method_or_name)}" + ) + method = method_or_name + args = (self._model,) + args with torch_stateless._reparametrize_module(self._model, params_copy): - res = getattr(self._model, method_name)(*args, **kwargs) + res = method(*args, **kwargs) return res def jittable_call(self, method_name: str, *args, **kwargs): From 86a99d7589cd23296f03bccb8b229b7c0bb7df5b Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 18 Jul 2025 12:46:17 -0700 Subject: [PATCH 002/133] Update README.md to reflect supported python versions (#9484) Co-authored-by: qihqi --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9d20948f35f8..b9c1f4a55544 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ Note: Builds are available for Python 3.8 to 3.11; please use one of the support pip install torch==2.7.0 'torch_xla[tpu]==2.7.0' ``` - +**As of 07/16/2025 and starting from Pytorch/XLA 2.8 release, PyTorch/XLA will +provide nightly and release wheels for Python 3.11 to 3.13** To install PyTorch/XLA nightly build in a new TPU VM: ```sh From f3c7907fb01ad2d7b6d903eba64f794b74513cb3 Mon Sep 17 00:00:00 2001 From: qihqi Date: Fri, 18 Jul 2025 13:12:27 -0700 Subject: [PATCH 003/133] Remove support for one-process-per-device style of distributed. (#9490) --- torchax/examples/mnist_tpu.ipynb | 647 ------------------------ torchax/examples/train_gpt/train_ddp.py | 140 ----- torchax/test_dist/README.md | 4 - torchax/test_dist/__init__.py | 0 torchax/test_dist/test_distributed.py | 154 ------ torchax/torchax/__init__.py | 1 - torchax/torchax/distributed.py | 241 --------- 7 files changed, 1187 deletions(-) delete mode 100644 torchax/examples/mnist_tpu.ipynb delete mode 100644 torchax/examples/train_gpt/train_ddp.py delete mode 100644 torchax/test_dist/README.md delete mode 100644 torchax/test_dist/__init__.py delete mode 100644 torchax/test_dist/test_distributed.py delete mode 100644 torchax/torchax/distributed.py diff --git a/torchax/examples/mnist_tpu.ipynb b/torchax/examples/mnist_tpu.ipynb deleted file mode 100644 index 8ffcc9ec27d0..000000000000 --- a/torchax/examples/mnist_tpu.ipynb +++ /dev/null @@ -1,647 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tHNudaKYx4Ci", - "outputId": "d72d15a1-483f-4820-de3c-0ef8905cb1ed" - }, - "outputs": [], - "source": [ - "# Uncomment and run these if you haven't already installed `torchax`\n", - "#!pip uninstall -y tensorflow\n", - "#!pip install tpu-info 'torchax[tpu] @ git+https://github.com/pytorch/xla.git#subdirectory=experimental/torchax' -f https://storage.googleapis.com/libtpu-releases/index.html\n", - "#!pip install torchvision" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Distributed training with `torchax`\n", - "\n", - "This Notebook demonstrates how to perform distributed training using `torchax`, which allows you to run PyTorch models with JAX." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset and model setup\n", - "\n", - "Below, we download and preprocess the MNIST dataset and instantiate a simple neural network to use as an example. The details here aren't important here. You can follow the same steps below for any PyTorch model and dataset.\n", - "\n", - "A couple of important notes about this section:\n", - "\n", - "- When we're loading data, the batch will be split across all local devices.\n", - "- `model` remains on the CPU device. We'll move it to the TPU in the next step." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "dbNWnxtizF-Z" - }, - "outputs": [], - "source": [ - "import torch\n", - "import torchvision\n", - "import torchvision.transforms as transforms\n", - "\n", - "train_dataset = torchvision.datasets.MNIST(\n", - " root='./data',\n", - " train=True,\n", - " download=True,\n", - " transform=transforms.Compose(\n", - " [transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))]))\n", - "test_dataset = torchvision.datasets.MNIST(\n", - " root='./data',\n", - " train=False,\n", - " download=True,\n", - " transform=transforms.Compose(\n", - " [transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))]))\n", - "\n", - "train_loader = torch.utils.data.DataLoader(\n", - " train_dataset,\n", - " batch_size=128,\n", - " drop_last=True,\n", - " shuffle=True)\n", - "test_loader = torch.utils.data.DataLoader(\n", - " test_dataset,\n", - " batch_size=128,\n", - " drop_last=True,\n", - " shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "_p2gxDdv6RYo" - }, - "outputs": [], - "source": [ - "import torch.nn as nn\n", - "\n", - "model = nn.Sequential(\n", - " nn.Flatten(),\n", - " nn.Linear(784, 512),\n", - " nn.ReLU(),\n", - " nn.Linear(512, 512),\n", - " nn.ReLU(),\n", - " nn.Linear(512, 512),\n", - " nn.ReLU(),\n", - " nn.Linear(512, 10)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Replicating the model across devices\n", - "\n", - "Most TPU configurations include multiple TPU cores per host. For example, a v4-8 TPU has 4 chips total. We can use `tpu-info` to see how many devices are available on this host." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[3mTPU Chips \u001b[0m\n", - "┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mDevice \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mType \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mCores\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPID \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━┩\n", - "│ /dev/accel0 │ TPU v4 chip │ 1 │ None │\n", - "│ /dev/accel1 │ TPU v4 chip │ 1 │ None │\n", - "│ /dev/accel2 │ TPU v4 chip │ 1 │ None │\n", - "│ /dev/accel3 │ TPU v4 chip │ 1 │ None │\n", - "└─────────────┴─────────────┴───────┴──────┘\n", - "Libtpu metrics unavailable. Did you start a workload with `TPU_RUNTIME_METRICS_PORTS=8431,8432,8433,8434`?\n" - ] - } - ], - "source": [ - "!tpu-info" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`torchax` uses JAX as a backend, so we can use JAX to double-check the device count. Don't worry -- we won't have to directly use JAX to run the model." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "M1wGEXY4yRvG", - "outputId": "4bea9105-062d-45d6-bd37-d47e9d06cad6" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "4" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "\n", - "# The TPU core count will vary depending on your environment.\n", - "jax.device_count()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The device count above should match the output of `tpu-info` (4 devices in the case of a v4-8).\n", - "\n", - "In this example, we'll use `torchax`'s custom `DistributedDataParallel` implementation to replicate the model parameters across all available TPU devices and split input data between each core." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Y9uhN5Om0f25" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/wcromar/tx2/.venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:270: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "import torchax\n", - "\n", - "ddp_model = torchax.distributed.DistributedDataParallel(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can dig into the underlying JAX array to see that there's an identical copy of the parameter tensor on each TPU device:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "example_param = next(ddp_model.parameters())" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Shard(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=0, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", - " 0.00225713]\n", - " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", - " 0.01084254]\n", - " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", - " -0.02596978]\n", - " ...\n", - " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", - " -0.02620777]\n", - " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", - " -0.02689848]\n", - " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", - " 0.00852821]]),\n", - " Shard(device=TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=1, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", - " 0.00225713]\n", - " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", - " 0.01084254]\n", - " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", - " -0.02596978]\n", - " ...\n", - " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", - " -0.02620777]\n", - " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", - " -0.02689848]\n", - " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", - " 0.00852821]]),\n", - " Shard(device=TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=2, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", - " 0.00225713]\n", - " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", - " 0.01084254]\n", - " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", - " -0.02596978]\n", - " ...\n", - " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", - " -0.02620777]\n", - " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", - " -0.02689848]\n", - " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", - " 0.00852821]]),\n", - " Shard(device=TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=3, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", - " 0.00225713]\n", - " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", - " 0.01084254]\n", - " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", - " -0.02596978]\n", - " ...\n", - " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", - " -0.02620777]\n", - " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", - " -0.02689848]\n", - " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", - " 0.00852821]])]\n" - ] - } - ], - "source": [ - "import pprint\n", - "pprint.pprint(example_param._elem.addressable_shards)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The replicated tensor still behaves as a plain PyTorch tensor, however:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor( [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", - " 0.00225713]\n", - " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", - " 0.01084254]\n", - " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", - " -0.02596978]\n", - " ...\n", - " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", - " -0.02620777]\n", - " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", - " -0.02689848]\n", - " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", - " 0.00852821]])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "example_param" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Sharding inputs\n", - "\n", - "Unlike the model parameters, we want to send a different shard of the input data to each device. We'll take one batch of images as an example:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([128, 1, 28, 28])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "example_images, _ = next(iter(train_loader))\n", - "example_images.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Sharding the input batch across devices does not change the overall size of the tensor:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(128, 1, 28, 28)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sharded_example_images = ddp_model.shard_input(example_images)\n", - "sharded_example_images.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If we dig into the underlying JAX array, we can see that the input has been split (into quarters in this case) across the batch dimension:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[(32, 1, 28, 28), (32, 1, 28, 28), (32, 1, 28, 28), (32, 1, 28, 28)]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "[s.data.shape for s in sharded_example_images._elem.addressable_shards]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Putting it all together\n", - "\n", - "`torchax` allows us to seamlessly shard and replicate tensors across devices, while still maintaining a singular view of that tensor through PyTorch. With some minor changes, we can adapt the conventional PyTorch training loop to use the TPU.\n", - "\n", - "Note that we do not have to spawn any child processes. Although each parameter and input is represented by one tensor, that tensor is already distributed across multiple devices." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The loss function and optimizer stay the same:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "E5QjcpuY1hx5" - }, - "outputs": [], - "source": [ - "loss_fn = torch.nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "JAX gets significantly better performance when compiled, normally through `jax.jit`. `torchax`'s DDP implementation contains a utility `jit_step` that can be used to compile a training step. Note that for this to work, the training step must be separated out into a function. Otherwise, the actual contents are the same as they would be for eager CPU or GPU." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "AojhVVzx0ZEG" - }, - "outputs": [], - "source": [ - "@ddp_model.jit_step\n", - "def train_step(sharded_inputs, sharded_labels):\n", - " optimizer.zero_grad()\n", - " outputs = ddp_model(sharded_inputs)\n", - " loss = loss_fn(outputs, sharded_labels)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, let's quickly run training for several epochs and check the validation results:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "QhO7V7JR2l8A" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " batch 0 loss: 2.3075523376464844\n", - " batch 100 loss: 2.3029651641845703\n", - " batch 200 loss: 2.2921366691589355\n", - " batch 300 loss: 2.2877070903778076\n", - " batch 400 loss: 2.274242401123047\n", - "Epoch 1\n", - " batch 0 loss: 2.2708349227905273\n", - " batch 100 loss: 2.269294261932373\n", - " batch 200 loss: 2.2480335235595703\n", - " batch 300 loss: 2.243983268737793\n", - " batch 400 loss: 2.2470455169677734\n", - "Epoch 2\n", - " batch 0 loss: 2.234013557434082\n", - " batch 100 loss: 2.2184624671936035\n", - " batch 200 loss: 2.2029666900634766\n", - " batch 300 loss: 2.198725461959839\n", - " batch 400 loss: 2.1829864978790283\n", - "Epoch 3\n", - " batch 0 loss: 2.1811957359313965\n", - " batch 100 loss: 2.1297898292541504\n", - " batch 200 loss: 2.1378531455993652\n", - " batch 300 loss: 2.0720174312591553\n", - " batch 400 loss: 2.0413732528686523\n", - "Epoch 4\n", - " batch 0 loss: 2.046309471130371\n", - " batch 100 loss: 1.9817270040512085\n", - " batch 200 loss: 1.9381718635559082\n", - " batch 300 loss: 1.847656011581421\n", - " batch 400 loss: 1.808678388595581\n", - "Epoch 5\n", - " batch 0 loss: 1.7617125511169434\n", - " batch 100 loss: 1.768508791923523\n", - " batch 200 loss: 1.6427236795425415\n", - " batch 300 loss: 1.6908036470413208\n", - " batch 400 loss: 1.538255214691162\n", - "Epoch 6\n", - " batch 0 loss: 1.4774806499481201\n", - " batch 100 loss: 1.4533928632736206\n", - " batch 200 loss: 1.2804057598114014\n", - " batch 300 loss: 1.2498115301132202\n", - " batch 400 loss: 1.116618275642395\n", - "Epoch 7\n", - " batch 0 loss: 1.1049035787582397\n", - " batch 100 loss: 1.0565766096115112\n", - " batch 200 loss: 1.0216108560562134\n", - " batch 300 loss: 0.9548335671424866\n", - " batch 400 loss: 0.8766275644302368\n", - "Epoch 8\n", - " batch 0 loss: 0.7384852766990662\n", - " batch 100 loss: 0.8499367237091064\n", - " batch 200 loss: 0.8409233689308167\n", - " batch 300 loss: 0.7746399641036987\n", - " batch 400 loss: 0.8063997030258179\n", - "Epoch 9\n", - " batch 0 loss: 0.7310354709625244\n", - " batch 100 loss: 0.825514018535614\n", - " batch 200 loss: 0.6718677878379822\n", - " batch 300 loss: 0.7210809588432312\n", - " batch 400 loss: 0.7002769708633423\n" - ] - } - ], - "source": [ - "for epoch in range(10):\n", - " running_loss = 0\n", - "\n", - " print('Epoch', epoch)\n", - " for i, data in enumerate(train_loader):\n", - " inputs, labels = data\n", - " # Distribute the batch across all TPU cores\n", - " sharded_inputs, sharded_labels = ddp_model.shard_input(inputs), ddp_model.shard_input(labels)\n", - " loss = train_step(sharded_inputs, sharded_labels)\n", - "\n", - " if i % 100 == 0:\n", - " print(' batch {} loss: {}'.format(i, loss.item()))\n", - " running_loss = 0." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation loss 0.6315549612045288\n" - ] - } - ], - "source": [ - "@ddp_model.jit_step\n", - "def eval_step(sharded_vinputs, sharded_vlabels):\n", - " voutputs = ddp_model(sharded_vinputs)\n", - " vloss = loss_fn(voutputs, sharded_vlabels)\n", - " return vloss\n", - "\n", - "ddp_model.eval()\n", - "running_vloss = 0.\n", - "\n", - "# Disable gradient computation and reduce memory consumption.\n", - "with torch.no_grad():\n", - " for i, vdata in enumerate(test_loader):\n", - " vinputs, vlabels = vdata\n", - " sharded_vinputs, sharded_vlabels = ddp_model.shard_input(vinputs), ddp_model.shard_input(vlabels)\n", - " vloss = eval_step(sharded_vinputs, sharded_vlabels)\n", - " running_vloss += vloss\n", - "\n", - "avg_vloss = running_vloss / (i + 1)\n", - "print('Validation loss', avg_vloss.item())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Conclusion\n", - "\n", - "With some minor changes to your training loop, `torchax` allows you to distribute a model across multiple devices and run a compiled version with JAX. All of the data you interact with directly is still a `torch` tensor, and JAX handles all of the distributed details in the background.\n", - "\n", - "`torchax` (and especially training) is still under heavy development. To learn more about the project and its current status, see https://github.com/pytorch/xla/tree/master/experimental/torchax" - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "gpuType": "V28", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/torchax/examples/train_gpt/train_ddp.py b/torchax/examples/train_gpt/train_ddp.py deleted file mode 100644 index d3b68d7230d7..000000000000 --- a/torchax/examples/train_gpt/train_ddp.py +++ /dev/null @@ -1,140 +0,0 @@ -"""WIP example using `minGPT` with DistributedDataParallel on both CPU and JAX. - -Required `mingpt` package for model definition (see requirements.txt). Some -hyperparameters and training configuration borrowed from nanoGPT: -https://github.com/karpathy/nanoGPT - -Example command (single host): -torchrun --standalone xla/experimental/torchax/examples/train_gpt/train_ddp.py - -Tested on a TPU v4-8 -""" - -import datetime -import jax -import torch -import torch.utils.data -import torch.utils.data.distributed -import torch.distributed as dist -import torch.optim as optim -import torchax -from tqdm import tqdm -from mingpt.model import GPT -from datasets import load_dataset -import tiktoken -import pathlib -import torch.utils._pytree as torch_pytree - - -def _checkpoint(jax_model, path: pathlib.Path): - torch.save( - torch_pytree.tree_map_only( - torchax.tensor.Tensor, - torchax.tensor.Tensor.torch, - jax_model.state_dict(), - ), - path, - ) - - -def main(): - dist.init_process_group(backend="gloo") - dataset_name = "Skylion007/openwebtext" - dataset = load_dataset(dataset_name, split="train", trust_remote_code=True) - - enc = tiktoken.get_encoding("gpt2") - - def tokenize(ex): - """Tokenize each example and append the end-of-text token.""" - ids = enc.encode_ordinary(ex["text"]) - ids.append(enc.eot_token) - return {"ids": ids} - - dataset = dataset.map(tokenize, num_proc=16) - - def group_texts(exs): - """Group batches of tokens into `block_size` chunks.""" - cat = torch.cat([torch.tensor(ex) for ex in exs["ids"]]) - total_len = cat.size()[0] - num_chunks = total_len // 1025 - split = torch.split(cat[:num_chunks * 1025], 1025) - xs = [ex[:-1] for ex in split] - ys = [ex[1:] for ex in split] - return {"x": xs, "y": ys} - - dataset = dataset.map( - group_texts, batched=True, remove_columns=["text", "ids"], num_proc=16) - dataset.shard(dist.get_world_size(), dist.get_rank()) - env = torchax.default_env() - - print(jax.device_count(), "devices") - - torch.manual_seed(0) - per_device_batch_size = 8 - local_batch_size = jax.local_device_count() * per_device_batch_size - global_batch_size = jax.device_count() * per_device_batch_size - dataloader = torch.utils.data.DataLoader( - dataset.with_format("torch"), batch_size=local_batch_size, drop_last=True) - - # Create model and wrap with DDP - def create_model(): - torch.manual_seed(0) - model_config = GPT.get_default_config() - model_config.model_type = "gpt2" - model_config.vocab_size = enc.n_vocab - model_config.block_size = 1024 - # TODO: use bf16 when erroneous type promotions are fixed - return GPT(model_config) # .to(dtype=torch.bfloat16) - - checkpoint_subdir = pathlib.Path( - "checkpoints") / datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_subdir.mkdir(parents=True) - jax_model = torchax.distributed.DistributedDataParallel(create_model(), env) - - # TODO: LR scheduler - jax_optimizer = optim.SGD(jax_model.parameters(), lr=6e-4, weight_decay=0.1) - - # Contents of `step_fn` can be inlined if using eager - @jax_model.jit_step - def step_fn(jax_data, jax_target): - jax_optimizer.zero_grad() - jax_output, jax_loss = jax_model(jax_data, jax_target) - jax_loss.backward() - torch.nn.utils.clip_grad_norm_(jax_model.parameters(), 1.0) - jax_optimizer.step() - - return jax_output, jax_loss - - tokens_per_batch = global_batch_size * 1024 - - for epoch in range(1): - print("epoch", epoch) - for i, batch in enumerate( - tqdm(dataloader, unit="tok", unit_scale=tokens_per_batch)): - data, target = batch["x"], batch["y"] - jax_data, jax_target = env.j2t_iso( - (jax_model.shard_input(data), jax_model.shard_input(target))) - jax_output, jax_loss = step_fn(jax_data, jax_target) - - if i % 1000 == 0: - _checkpoint(jax_model, checkpoint_subdir / "gpt2_124m_{epoch}_{i}.ckpt") - print("step", i, jax_loss.item()) - - with torch.no_grad(): - with env: - inp = enc.encode("This GPT-2 example is") - input_jax = torch.tensor([inp], dtype=torch.long) - # TODO: need to access underlying module for methods - jax_generated = jax_model._module.generate( - jax_model.replicate_input(input_jax), - 100, - do_sample=False, - ) - - print("input sequence:", inp, enc.decode(inp)) - print(jax_generated) - print("predicted (JAX):", enc.decode(jax_generated.numpy().tolist())) - - -if __name__ == "__main__": - main() diff --git a/torchax/test_dist/README.md b/torchax/test_dist/README.md deleted file mode 100644 index f2e0cd36908d..000000000000 --- a/torchax/test_dist/README.md +++ /dev/null @@ -1,4 +0,0 @@ -This directory contains multi-accelerator tests that cannot be distributed with -`pytest-xdist`. - -TODO: merge these into `tests/` diff --git a/torchax/test_dist/__init__.py b/torchax/test_dist/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/test_dist/test_distributed.py b/torchax/test_dist/test_distributed.py deleted file mode 100644 index 59b966f0ef61..000000000000 --- a/torchax/test_dist/test_distributed.py +++ /dev/null @@ -1,154 +0,0 @@ -import os -import jax -import numpy as np - -import pytest -import torch -import torch.distributed._functional_collectives -import torch.distributed as dist -import torchax -import torchax.distributed - -# Dummy group name to use with functional collectives. Ignored by -# implementations. -# TODO(wcromar): do something useful with group name -GROUP_NAME = "process_group" - -torchax.enable_globally() - - -@pytest.fixture(scope="module") -def multi_cpu(): - # TODO(wcromar): support other device counts - assert (jax.device_count() == 4 - ), "Set XLA_FLAGS=--xla_force_host_platform_device_count=4 if on CPU" - - yield jax.device_count() - - -@pytest.fixture() -def process_group(): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - dist.init_process_group(backend="jax", init_method="jax://") - # HACK: our default process group has world size 1, regardless of actual - # device count. Only put rank 0 so PyTorch doesn't complain about non-existent - # ranks. Our lowerings ignore this list, so this ends up being fine. - # TODO(wcromar): Figure out if there's a cleaner way - group_ranks = [0] - yield group_ranks - dist.destroy_process_group() - - -def test_all_gather_tensor(multi_cpu, process_group): - device_count = multi_cpu - - def f(index: torchax.tensor.Tensor): - with torchax.default_env(): - output = torch.zeros_like(index).expand(device_count) - dist.all_gather_into_tensor(output, index) - return output - - res = torchax.distributed.spawn(f) - - expected_tensors = [[0, 1, 2, 3] for _ in range(device_count)] - np.testing.assert_equal([r.numpy() for r in res], expected_tensors) - - -def test_all_gather_tensor_func(multi_cpu, process_group): - device_count = multi_cpu - group_ranks = process_group - - def f(index: torchax.tensor.Tensor): - return torch.distributed._functional_collectives.all_gather_tensor( - index, 0, group_ranks) - - res = torchax.distributed.spawn(f) - - expected_tensors = [[0, 1, 2, 3] for _ in range(device_count)] - np.testing.assert_equal([r.numpy() for r in res], expected_tensors) - - -@pytest.mark.parametrize( - ("op", "expected"), - [ - (dist.ReduceOp.SUM, sum(range(4))), - (dist.ReduceOp.AVG, sum(range(4)) // 4), - (dist.ReduceOp.MIN, 0), - (dist.ReduceOp.MAX, 3), - ], -) -def test_all_reduce(op, expected, multi_cpu, process_group): - device_count = multi_cpu - - def f(index): - with torchax.default_env(): - dist.all_reduce(index, op) - return index - - res = torchax.distributed.spawn(f) - - expected_tensors = [expected for _ in range(device_count)] - np.testing.assert_equal(res.numpy(), expected_tensors) - - -@pytest.mark.parametrize( - ("op", "expected"), - [ - ("sum", sum(range(4))), - ("avg", sum(range(4)) / 4), - ("min", 0), - ("max", 3), - ], -) -def test_all_reduce_func(op, expected, multi_cpu): - device_count = multi_cpu - - def f(index): - return torch.distributed._functional_collectives.all_reduce( - index, op, GROUP_NAME) - - res = torchax.distributed.spawn(f) - - expected_tensors = [expected for _ in range(device_count)] - np.testing.assert_equal(res.numpy(), expected_tensors) - - -@pytest.mark.parametrize( - ("rank", "expected"), - [ - (0, 0), - (2, 2), - ], -) -def test_broadcast(rank, expected, multi_cpu, process_group): - device_count = multi_cpu - - def f(index): - dist.broadcast(index, rank) - return index - - res = torchax.distributed.spawn(f) - - expected_tensors = [expected for _ in range(device_count)] - np.testing.assert_equal(res.numpy(), expected_tensors) - - -@pytest.mark.parametrize( - ("rank", "expected"), - [ - (0, 0), - (2, 2), - ], -) -def test_broadcast_func(rank, expected, multi_cpu): - device_count = multi_cpu - - def f(index): - return torch.distributed._functional_collectives.broadcast( - index, rank, GROUP_NAME) - - res = torchax.distributed.spawn(f) - - expected_tensors = [expected for _ in range(device_count)] - np.testing.assert_equal(res.numpy(), expected_tensors) diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index e078db6c83ed..d08ba31071c5 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -6,7 +6,6 @@ import torch from torch.utils import _pytree as pytree from torchax import tensor -from torchax import distributed # noqa: F401 from contextlib import contextmanager __version__ = "0.0.5" diff --git a/torchax/torchax/distributed.py b/torchax/torchax/distributed.py deleted file mode 100644 index eb12f4eb2d56..000000000000 --- a/torchax/torchax/distributed.py +++ /dev/null @@ -1,241 +0,0 @@ -"""`torch.distributed` backend implemented with JAX collective ops. - -EXPERIMENTAL: This module is still highly experimental, and it may be removed -before any stable release. - -Note: JAX collective ops require that axis names be defined in `pmap` or -`shmap`. The distributed backend only supports one axis, named `torch_dist`. -This name is defined by our mirror implementation of `spawn`. -""" - -import datetime -import functools -import logging -import os -from typing import List, Optional, Union - -import jax -import numpy as np -import torch -import torch.distributed as dist -import torch.distributed._functional_collectives -from torch._C._distributed_c10d import ProcessGroup # type: ignore -import torch.distributed -import torchax -from jax.sharding import NamedSharding -from jax.sharding import Mesh, PartitionSpec as P -from jax.experimental import mesh_utils -import torch.utils._pytree as torch_pytree -from torchax import interop - - -class ProcessGroupJax(ProcessGroup): - """Distributed backend implemented with JAX.""" - - def __init__(self, prefix_store, rank, size, timeout): - super().__init__(rank, size) - self._group_name = None - - def getBackendName(self): - return "jax" - - # TODO(wcromar): why doesn't default group name setter work? - # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152 - def _set_group_name(self, name: str) -> None: - self._group_name = name - - @property - def group_name(self): - assert self._group_name - return self._group_name - - @staticmethod - def _work( - tensors: Union[torch.Tensor, List[torch.Tensor], - List[List[torch.Tensor]]], - ) -> dist.Work: - fut = torch.futures.Future() - fut.set_result(tensors) - return torch._C._distributed_c10d._create_work_from_future(fut) - - def _allgather_base( - self, - output: torch.Tensor, - input: torch.Tensor, - opts=..., - ) -> dist.Work: - assert isinstance(input, torchax.tensor.Tensor) - assert isinstance(output, torchax.tensor.Tensor) - torch.distributed._functional_collectives.all_gather_tensor_inplace( - output, input, group=self) - return self._work(output) - - def allreduce( - self, - tensors: List[torch.Tensor], - opts: dist.AllreduceOptions = ..., - ) -> dist.Work: - assert len(tensors) == 1 - assert isinstance(tensors[0], torchax.tensor.Tensor) - torch.distributed._functional_collectives.all_reduce_inplace( - tensors[0], - torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ - opts.reduceOp.op], - self, - ) - - return self._work(tensors) - - def broadcast( - self, - tensors: List[torch.Tensor], - opts: dist.BroadcastOptions = ..., - ) -> dist.Work: - assert len(tensors) == 1 - assert isinstance(tensors[0], torchax.tensor.Tensor) - tensors[0].copy_( - torch.distributed._functional_collectives.broadcast( - tensors[0], opts.rootRank, group=self)) - - return self._work(tensors) - - -dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"]) - - -def jax_rendezvous_handler(url: str, - timeout: datetime.timedelta = ..., - **kwargs): - """Initialize distributed store with JAX process IDs. - - Requires `$MASTER_ADDR` and `$MASTER_PORT`. - """ - # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU - # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part - # of their public Python API - master_ip = os.environ["MASTER_ADDR"] - master_port = int(os.environ["MASTER_PORT"]) - # TODO(wcromar): Use `torchrun`'s store if available - store = dist.TCPStore( - master_ip, - master_port, - jax.process_count(), - is_master=jax.process_index() == 0, - ) - - yield (store, jax.process_index(), jax.process_count()) - - -dist.register_rendezvous_handler("jax", jax_rendezvous_handler) - - -def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None): - """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined. - `f` is expected to take the replica index as a positional argument, similar - to `torch.multiprocessing.spawn`. - Note: `spawn` does not actually create parallel processes. - """ - env = env or torchax.default_env() - - def jax_wrapper(index, jax_args): - index, args = env.j2t_iso([index, jax_args]) - torch_outputs = f(index, *args) - return env.t2j_iso(torch_outputs) - - jax_outputs = jax.pmap( - jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()), - env.t2j_iso(args)) - return env.j2t_iso(jax_outputs) - - -class DistributedDataParallel(torch.nn.Module): - """Re-implementation of DistributedDataParallel using JAX SPMD. - - Splits inputs along batch dimension (assumed to be 0) across all devices in - JAX runtime, including remote devices. Each process should load a distinct - shard of the input data using e.g. DistributedSampler. Each process' shard - is then further split among the addressable devices (e.g. local TPU chips) - by `shard_input`. - - Note: since parameters are replicated across addressable devices, inputs - must also be SPMD sharded using `shard_input` or `replicate_input`. - - Example usage: - - ``` - jax_model = torchax.distributed.DistributedDataParallel(create_model()) - for data, dataloader: - jax_data = jax_model.shard_input(data) - jax_output = jax_model(jax_data) - ``` - """ - - def __init__( - self, - module: torch.nn.Module, - env: Optional[torchax.tensor.Environment] = None, - **kwargs, - ): - if kwargs: - logging.warning(f"Unsupported kwargs {kwargs}") - - super().__init__() - self._env = env or torchax.default_env() - self._mesh = Mesh( - mesh_utils.create_device_mesh((jax.device_count(),)), - axis_names=("batch",), - ) - replicated_state = torch_pytree.tree_map_only( - torch.Tensor, - lambda t: self._env.j2t_iso( - jax.device_put( - self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))), - module.state_dict(), - ) - # TODO: broadcast - module.load_state_dict(replicated_state, assign=True) - self._module = module - - def shard_input(self, inp): - per_process_batch_size = inp.shape[0] # assumes batch dim is 0 - per_replica_batch_size = per_process_batch_size // jax.local_device_count() - per_replica_batches = torch.chunk(inp, jax.local_device_count()) - global_batch_size = per_replica_batch_size * jax.device_count() - global_batch_shape = (global_batch_size,) + inp.shape[1:] - - sharding = NamedSharding(self._mesh, P("batch")) - return self._env.j2t_iso( - jax.make_array_from_single_device_arrays( - global_batch_shape, - NamedSharding(self._mesh, P("batch")), - arrays=[ - jax.device_put(self._env.to_xla(batch)._elem, device) for batch, - device in zip(per_replica_batches, sharding.addressable_devices) - ], - )) - - def replicate_input(self, inp): - return self._env.j2t_iso( - jax.device_put(inp._elem, NamedSharding(self._mesh, P()))) - - def jit_step(self, func): - - @functools.partial( - interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0}) - def _jit_fn(states, args): - self.load_state_dict(states) - outputs = func(*args) - return self.state_dict(), outputs - - @functools.wraps(func) - def inner(*args): - jax_states = self.state_dict() - new_states, outputs = _jit_fn(jax_states, args) - self.load_state_dict(new_states) - return outputs - - return inner - - def forward(self, *args): - with self._env: - return self._module(*args) From 95ba754a217965b3087e93fb6149e036a49c04ee Mon Sep 17 00:00:00 2001 From: qihqi Date: Fri, 18 Jul 2025 14:01:13 -0700 Subject: [PATCH 004/133] Allow mixed tensor type math if one of them is a scalar (#9453) --- torchax/test/test_ops.py | 2 +- torchax/torchax/config.py | 5 +++++ torchax/torchax/tensor.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py index 3e47ba21daf5..444f80753e1a 100644 --- a/torchax/test/test_ops.py +++ b/torchax/test/test_ops.py @@ -186,7 +186,7 @@ def setUp(self): self.env = torchax.default_env() torchax.enable_accuracy_mode() #self.env.config.debug_accuracy_for_each_op = True - self.env.config.debug_print_each_op = True + self.env.config.debug_print_each_op = False torch.manual_seed(0) self.old_var = self.env.config.use_torch_native_for_cpu_tensor self.env.config.use_torch_native_for_cpu_tensor = False diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py index 9370625e85cb..0ba3a2a11e5c 100644 --- a/torchax/torchax/config.py +++ b/torchax/torchax/config.py @@ -10,6 +10,11 @@ class Configuration: use_int32_for_index: bool = False + # normally, math between CPU torch.Tensor with torchax.Tensor is not + # allowed. However, if that torch.Tensor happens to be scalar, then we + # can use scalar * tensor math to handle it + allow_mixed_math_with_scalar_tensor: bool = True + # If true, we will convert Views into torchax.Tensors eagerly force_materialize_views: bool = False diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index f7ba6867c547..205a041e9a09 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -639,6 +639,10 @@ def t2j_iso(self, torchtensors): """ def to_jax(x): + if self.config.allow_mixed_math_with_scalar_tensor and not isinstance( + x, Tensor): + if x.squeeze().ndim == 0: + return x.item() if isinstance( x, torch.distributed._functional_collectives.AsyncCollectiveTensor): x = x.wait() From 55b7d02196045f60e21702dfbee812dacbebe4cf Mon Sep 17 00:00:00 2001 From: Carlomus <48855305+Carlomus@users.noreply.github.com> Date: Sun, 20 Jul 2025 19:14:06 +0100 Subject: [PATCH 005/133] Fix nested stableHLO composite regions (#9385) --- test/stablehlo/test_composite.py | 16 +++---- .../runtime/stablehlo_composite_helper.cpp | 42 ++++++++++++++++++- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 8fe211475ba1..b2c188aba944 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -147,10 +147,10 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) def test_composite_builder_sdpa_pattern(self): @@ -175,10 +175,10 @@ def forward(self, x, y): stablehlo = self.run_func_get_stablehlo(M(), input_args) self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) def test_composite_builder_export_sdpa_pattern(self): @@ -208,10 +208,10 @@ def forward(self, x, y): stablehlo = stablehlo_gm.get_stablehlo_text() self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) if has_tf_package(): self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) @@ -240,10 +240,10 @@ def forward(self, x, y): stablehlo = stablehlo_gm.get_stablehlo_text() self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2) self.assertTrue( - '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}' + '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}' in stablehlo) self.assertTrue( - '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}' + '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0}' in stablehlo) if has_tf_package(): self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp index 101b36908555..c743f8527285 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cpp @@ -120,8 +120,30 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { std::unordered_map> boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); - for (const auto& [unused, ops] : boundary_output_ops_map) { - if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + struct BoundaryGroup { + std::string key; + llvm::SmallVector ops; + size_t last_order; + }; + + llvm::SmallVector groups; + groups.reserve(boundary_output_ops_map.size()); + + for (auto& kv : boundary_output_ops_map) { + size_t last_ord = 0; + for (mlir::Operation* op : kv.second) { + if (op != nullptr) last_ord = std::max(last_ord, op_order_map.at(op)); + } + groups.push_back({kv.first, kv.second, last_ord}); + } + + llvm::sort(groups, [](const BoundaryGroup& a, const BoundaryGroup& b) { + return a.last_order < b.last_order; + }); + + for (auto& grp : groups) { + op_order_map = BuildOpOrderMap(func_op); + if (mlir::failed(BuildStableHLOComposite(grp.ops, op_order_map))) { func_op.emitError() << "failed to build composite."; return signalPassFailure(); } @@ -321,6 +343,22 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } } + llvm::DenseSet wrapper_set(output_ops.begin(), + output_ops.end()); + + for (mlir::Operation* mark : output_ops) + if (mark->use_empty()) mark->erase(); + + for (mlir::Operation* op : llvm::reverse(impl_ops)) { + if (wrapper_set.contains(op) || !op->use_empty()) continue; + + bool pure_or_composite = mlir::wouldOpBeTriviallyDead(op) || + llvm::isa(op) || + llvm::isa(op); + + if (pure_or_composite) op->erase(); + } + if (!mlir::sortTopologically(composite_op->getBlock())) { composite_op->emitError() << "The graph is not acyclic after BuildStableHLOCompositePass pass."; From 26def0f2c2516c376984d31701f9ab15b4e09a2b Mon Sep 17 00:00:00 2001 From: qihqi Date: Sun, 20 Jul 2025 12:19:50 -0700 Subject: [PATCH 006/133] Misc fixes: (#9491) --- torchax/test/test_interop.py | 12 ++++++++++++ torchax/torchax/__init__.py | 7 ++++--- torchax/torchax/device_module.py | 7 +++++++ torchax/torchax/interop.py | 7 ++++--- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/torchax/test/test_interop.py b/torchax/test/test_interop.py index d3f4ced3a149..5854e2c4ac38 100644 --- a/torchax/test/test_interop.py +++ b/torchax/test/test_interop.py @@ -5,6 +5,7 @@ from torchax import interop, jax_device import torchax import jax +import jax.numpy as jnp def is_tpu_available(): @@ -171,6 +172,17 @@ def test_to_jax_device(self): self.assertEqual(d.jax_device.platform, "cpu") self.assertEqual(d.device.type, "jax") + def test_torch_jax_view_dtype(self): + dtype = torch.float32 + self.assertEqual(interop.jax_view(dtype), jnp.float32.dtype) + self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype) + dtype = torch.bfloat16 + self.assertEqual(interop.jax_view(dtype), jnp.bfloat16.dtype) + self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype) + dtype = torch.int32 + self.assertEqual(interop.jax_view(dtype), jnp.int32.dtype) + self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype) + if __name__ == '__main__': unittest.main() diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index d08ba31071c5..e0cc69dae364 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -49,10 +49,11 @@ def extract_jax(mod: torch.nn.Module, env=None): states = env.t2j_copy(states) #@jax.jit - def jax_func(states, inputs): - (states, inputs) = env.j2t_iso((states, inputs)) + def jax_func(states, args, kwargs=None): + (states, args, kwargs) = env.j2t_iso((states, args, kwargs)) with env: - res = torch.func.functional_call(mod, states, inputs, tie_weights=False) + res = torch.func.functional_call( + mod, states, args, kwargs, tie_weights=False) return env.t2j_iso(res) return states, jax_func diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py index 20fceaf06b43..be028cfcc21d 100644 --- a/torchax/torchax/device_module.py +++ b/torchax/torchax/device_module.py @@ -1,3 +1,6 @@ +import torch + + def _is_in_bad_fork(): return False @@ -24,3 +27,7 @@ def is_available(): def current_device(): return 0 + + +def get_amp_supported_dtype(): + return [torch.float16, torch.bfloat16] diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index bceb008ef4b0..0d0ea655ac5c 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -11,6 +11,7 @@ from jax.experimental.shard_map import shard_map from torchax import tensor from torchax import util +from torchax.ops import mappings import torchax from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable @@ -183,8 +184,8 @@ def _torch_view(t: JaxValue) -> TorchValue: if isinstance(t, jax.Array): # TODO return tensor.Tensor(t, torchax.default_env()) - if isinstance(t, type(jnp.int32)): - return tensor.t2j_type(t) + if isinstance(t, jnp.dtype): + return mappings.j2t_dtype(t) if callable(t): # t is a JaxCallable return functools.partial(call_jax, t) # regular types are not changed @@ -201,7 +202,7 @@ def _jax_view(t: TorchValue) -> JaxValue: assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t) return t.jax() if isinstance(t, type(torch.int32)): - return tensor.t2j_dtype(t) + return mappings.t2j_dtype(t) # torch.nn.Module needs special handling if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable From e82631ec9d9b6e11c7f5e6478fe4e8834fd88c78 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 21 Jul 2025 18:40:02 +0200 Subject: [PATCH 007/133] Fix python 3.11 cuda wheel link in the readme (#9493) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b9c1f4a55544..446453705c3d 100644 --- a/README.md +++ b/README.md @@ -261,7 +261,7 @@ GPU release builds and GPU/TPU nightly builds are available in our public GCS bu | Version | Cloud GPU VM Wheels | | --- | ----------- | | 2.7 (CUDA 12.6 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.6/torch_xla-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.7 (CUDA 12.6 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` | +| 2.7 (CUDA 12.6 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.6/torch_xla-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl` | | nightly (Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp311-cp311-linux_x86_64.whl` | | nightly (Python 3.12) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp312-cp312-linux_x86_64.whl` | | nightly (Python 3.13) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp312-cp312-linux_x86_64.whl` | From 31c4c2f239180d834f53998982a0693c79dbc4ec Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 23 Jul 2025 12:02:07 -0700 Subject: [PATCH 008/133] [Bugfix] fix ragged attention kernel auto-tuning table key (#9497) --- .../pallas_kernels/ragged_paged_attention_v2.py | 17 +++++++++++------ torchax/torchax/interop.py | 7 +++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 5985649c63d8..61a6411b31f0 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -25,6 +25,7 @@ from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp +import logging DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) # The page size is too small. We only have 32 SREGs in TC. If the pages @@ -1421,12 +1422,12 @@ def simplify_key(key): return ( jnp.dtype(q_dtype).name, jnp.dtype(kv_dtype).name, - next_power_of_2(num_q_heads_per_blk), - next_power_of_2(num_kv_heads_per_blk), + num_q_heads_per_blk, + num_kv_heads_per_blk, (head_dim + 127) // 128 * 128, next_power_of_2(page_size), next_power_of_2(max_num_batched_tokens), - next_power_of_2(page_size * pages_per_seq), + page_size * pages_per_seq, ) @@ -1472,7 +1473,7 @@ def get_tuned_block_sizes( max_num_batched_tokens, pages_per_seq, ) - key = simplify_key(key) + simplified_key = simplify_key(key) device_name = get_device_name() # Default block sizes. @@ -1500,8 +1501,12 @@ def compute_actual_vmem_bytes(num_kv_pages_per_blk): # OOM in vmem bkv, bq = (32, 32) elif device_name in TUNED_BLOCK_SIZES: - if key in TUNED_BLOCK_SIZES[device_name]: - bkv, bq = TUNED_BLOCK_SIZES[device_name][key] + if simplified_key in TUNED_BLOCK_SIZES[device_name]: + bkv, bq = TUNED_BLOCK_SIZES[device_name][simplified_key] + else: + logging.warning( + f"simplified_key({simplified_key}) is not in ragged attention kernel's tuning table!, the key before simpilification is {key}" + ) return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index 0d0ea655ac5c..a87efe9dfe74 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -239,8 +239,7 @@ def j2t_autograd(fn, call_jax=call_jax): @wraps(fn) def inner(*args, **kwargs): - from jax.tree_util import tree_flatten, tree_unflatten - from jax.util import safe_zip + from jax.tree_util import tree_flatten class JaxFun(torch.autograd.Function): @@ -275,8 +274,8 @@ def backward(ctx, *grad_out): # The subsequent gradients correspond to flat_inputs. # We need to put a None for inputs that did not require gradients. final_grads = [None] - for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:], - input_grads_structured): + for needs_grad, grad in zip( + ctx.needs_input_grad[1:], input_grads_structured, strict=True): final_grads.append(grad if needs_grad else None) return tuple(final_grads) From 299a16b54e9a73d6bcf5267f1a7c5f193db14c0f Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 24 Jul 2025 14:08:34 -0300 Subject: [PATCH 009/133] Error Handling: refactor `ComputationClient::TransferFromDevice` to propagate status. (#9429) --- test/cpp/cpp_test_util.cpp | 5 ++- test/cpp/test_replication.cpp | 5 ++- test/test_operations.py | 15 +++++++++ torch_xla/csrc/init_python_bindings.cpp | 4 +-- torch_xla/csrc/runtime/BUILD | 1 + torch_xla/csrc/runtime/computation_client.h | 2 +- .../csrc/runtime/ifrt_computation_client.cpp | 10 +++--- .../csrc/runtime/ifrt_computation_client.h | 2 +- .../runtime/ifrt_computation_client_test.cpp | 2 +- .../csrc/runtime/pjrt_computation_client.cpp | 31 +++++++------------ .../csrc/runtime/pjrt_computation_client.h | 2 +- .../runtime/pjrt_computation_client_test.cpp | 2 +- torch_xla/csrc/tensor_util.cpp | 5 +-- 13 files changed, 47 insertions(+), 39 deletions(-) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 03efa4207191..afe573101ebc 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -303,9 +303,8 @@ std::vector Execute( std::vector Fetch( absl::Span device_data) { - std::vector literals = - torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice( - device_data); + std::vector literals = GetValueOrThrow( + runtime::GetComputationClientOrDie()->TransferFromDevice(device_data)); std::vector tensors; for (auto& literal : literals) { tensors.push_back(MakeTensorFromXlaLiteral( diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 88175c2fdbb7..b565dc44cd08 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -79,9 +79,8 @@ void TestSingleReplication( counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { - std::vector literals = - torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice( - results[i]); + std::vector literals = GetValueOrThrow( + runtime::GetComputationClientOrDie()->TransferFromDevice(results[i])); ASSERT_EQ(literals.size(), 1); // The result must be the original tensor value, multiplied by the number of diff --git a/test/test_operations.py b/test/test_operations.py index f037ad4b8cb2..68aa0b6c2c82 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -88,6 +88,11 @@ def skipIfFunctionalizationDisabled(reason): return _skipIfFunctionalization(value=True, reason=reason) +def onlyOnCPU(fn): + accelerator = os.environ.get("PJRT_DEVICE").lower() + return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CUDA required")(fn) + + def onlyOnCUDA(fn): accelerator = os.environ.get("PJRT_DEVICE").lower() return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) @@ -2458,6 +2463,16 @@ def test_add_broadcast_error(self): torch.add(a, b) torch_xla.sync() + @onlyOnCPU + def test_construct_large_tensor_raises_error(self): + with self.assertRaisesRegex(RuntimeError, + r"Out of memory allocating \d+ bytes"): + # When eager-mode is enabled, OOM is triggered here. + a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device()) + b = a.sum() + # OOM is raised when we try to bring data from the device. + b.cpu() + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ce55969d693b..5b62d95efd57 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1229,9 +1229,9 @@ class PyLoweringContext { lowering_ctx.GetParametersData(); // Fetch this parameter data - std::vector literals = + std::vector literals = GetValueOrThrow( runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(device_data)); + UnwrapXlaData(device_data))); // Create a mapping from paramater id to the tensor data std::unordered_map results; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index d98329718906..c4760783f4d0 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -123,6 +123,7 @@ cc_library( ":tf_logging", ":xla_coordinator", "//torch_xla/csrc:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index c7603c8932af..c2f9389a4a0a 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -318,7 +318,7 @@ class ComputationClient { // Note: `TransferFromDevice` call will block until the `DataPtrs` are ready // if they were created by `TransferToDevice` or `Execute*`. Calling this from // python while holding the GIL can cause deadlocks! - virtual std::vector TransferFromDevice( + virtual absl::StatusOr> TransferFromDevice( absl::Span handles) = 0; virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index a463f79a226f..f5a6af1b267c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -436,8 +436,8 @@ std::shared_ptr IfrtComputationClient::GetPjRtBuffer( XLA_ERROR() << __FUNCTION__ << " not implemented"; } -std::vector IfrtComputationClient::TransferFromDevice( - absl::Span handles) { +absl::StatusOr> +IfrtComputationClient::TransferFromDevice(absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -455,9 +455,9 @@ std::vector IfrtComputationClient::TransferFromDevice( auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape())); std::vector byte_strides(literal.shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), - absl::MakeSpan(byte_strides))); - XLA_CHECK_OK( + XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides( + literal.shape(), absl::MakeSpan(byte_strides))); + XLA_RETURN_IF_ERROR( replicated_array ->CopyToHostBuffer(literal.untyped_data(), byte_strides, xla::ifrt::ArrayCopySemantics::kAlwaysCopy) diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 9c21d7a8d7fb..46b6343dc10a 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -62,7 +62,7 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - std::vector TransferFromDevice( + absl::StatusOr> TransferFromDevice( absl::Span handles) override; std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp index 7a4741fc1bc5..eb39f9b2e23f 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp @@ -70,7 +70,7 @@ TEST(PjRtComputationClientTest, Init) { // Copy the output from device back to host and assert correctness.. ASSERT_EQ(results.size(), 1); - auto result_literals = client->TransferFromDevice(results); + auto result_literals = GetValueOrThrow(client->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 8239da35846d..dd4950d87f5e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -4,6 +4,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" @@ -508,8 +509,8 @@ std::shared_ptr PjRtComputationClient::GetPjRtBuffer( } } -std::vector PjRtComputationClient::TransferFromDevice( - absl::Span handles) { +absl::StatusOr> +PjRtComputationClient::TransferFromDevice(absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -522,21 +523,17 @@ std::vector PjRtComputationClient::TransferFromDevice( // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. std::shared_ptr pjrt_data = ReplicateShardedData(handle); - XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; - XLA_CHECK(pjrt_data->buffer != nullptr) + ABSL_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; + ABSL_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; - xla::Literal& literal = - literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); + xla::Literal& literal = literals.emplace_back( + xla::Literal(host_output_shape(pjrt_data->buffer.get()))); futures.push_back(pjrt_data->buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); } - for (auto& future : futures) { - absl::Status status = future.Await(); - XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" - << __FUNCTION__; - } + XLA_RETURN_IF_ERROR(xla::JoinFutures(futures).Await()); InboundDataMetric()->AddSample(total_size); return literals; @@ -773,10 +770,8 @@ PjRtComputationClient::ExecuteComputation( std::optional> returned_future; std::vector> results = - pjrt_computation.executable - ->ExecuteSharded(buffers, pjrt_device, execute_options, - returned_future) - .value(); + GetValueOrThrow(pjrt_computation.executable->ExecuteSharded( + buffers, pjrt_device, execute_options, returned_future)); returned_future->OnReady(std::move( [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable { @@ -878,10 +873,8 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_execute", tsl::profiler::TraceMeLevel::kInfo); - results = pjrt_computation.executable - ->Execute(std::move(argument_handles), execute_options, - returned_futures) - .value(); + results = GetValueOrThrow(pjrt_computation.executable->Execute( + std::move(argument_handles), execute_options, returned_futures)); (*returned_futures)[0].OnReady( std::move([timed, op_tracker = std::move(op_tracker)]( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index b7c61e2ec74c..3a6b4478f722 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -65,7 +65,7 @@ class PjRtComputationClient : public ComputationClient { absl::Span handles, absl::Span shardings) override; - std::vector TransferFromDevice( + absl::StatusOr> TransferFromDevice( absl::Span handles) override; std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp index 3398e61a2782..0fe2b2a70fcb 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp @@ -120,7 +120,7 @@ TEST_F(PjRtComputationClientTest, Init) { // Copy the output from device back to host and assert correctness. ASSERT_EQ(results.size(), 1); - auto result_literals = client_->TransferFromDevice(results); + auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 0a7f184cda77..e2cd3a025f59 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -24,6 +24,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" @@ -909,8 +910,8 @@ std::vector ReleaseGilAndTransferData( save = PyEval_SaveThread(); } std::vector literals = - runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(xla_data)); + GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice( + UnwrapXlaData(xla_data))); if (save) { PyEval_RestoreThread(save); } From ca471983897928a2e7aadeaba805cb90ded58070 Mon Sep 17 00:00:00 2001 From: aws-cph Date: Thu, 24 Jul 2025 10:38:50 -0700 Subject: [PATCH 010/133] Implement XLAShardedTensor._spec and test (#9488) --- test/neuron/run_tests.sh | 11 +- test/run_tests.sh | 1 + test/spmd/test_xla_dtensor_spec_conversion.py | 235 ++++++++++++++++++ test/spmd/test_xla_sharding.py | 4 - test/tpu/run_tests.sh | 1 + .../distributed/spmd/xla_sharded_tensor.py | 80 +++++- torch_xla/distributed/spmd/xla_sharding.py | 40 ++- 7 files changed, 358 insertions(+), 14 deletions(-) create mode 100644 test/spmd/test_xla_dtensor_spec_conversion.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 22e778945f9c..f7671cc3d827 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -56,6 +56,14 @@ function run_test { PJRT_DEVICE=NEURON NEURON_NUM_DEVICES=1 run_coverage "$@" } +function run_test_multi_device { + if ! test_is_selected "$1"; then + return + fi + echo "Running in PjRt runtime: $@" + PJRT_DEVICE=NEURON run_coverage "$@" +} + function run_test_without_functionalization { if ! test_is_selected "$1"; then return @@ -246,7 +254,8 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration.py" #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" - run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index 93f4cb33c061..b2cc8f751d2c 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -254,6 +254,7 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py" diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py new file mode 100644 index 000000000000..81cb8a4aa2e4 --- /dev/null +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -0,0 +1,235 @@ +import os +import sys + +import torch +from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate + +import torch_xla +import torch_xla.runtime as xr +from torch_xla.distributed.spmd import XLAShardedTensor +from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor + +import unittest +import test_xla_sharding_base + + +class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_sample_test_case(self): + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(world_size)) + big_tensor = torch.randn(100000, 88) + my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) + + assert my_dtensor._spec.mesh.device_type == mesh.device_type + assert my_dtensor._spec.placements == (Shard(0),) + + def test_xla_to_dtensor_spec_conversion(self): + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(device_count))) + + # Test different sharding patterns + test_cases = [ + (torch.randn(100, 50), [Shard(0)]), + (torch.randn(100, 50), [Shard(1)]), + (torch.randn(100, 50, 25), [Shard(0)]), + (torch.randn(100, 50), [Replicate()]), + ] + + for tensor, placements in test_cases: + xla_tensor = distribute_tensor(tensor, mesh, placements) + spec = xla_tensor._spec + + assert spec is not None + assert spec.mesh.device_type == "xla" + assert spec.tensor_meta.shape == tensor.shape + assert spec.tensor_meta.dtype == tensor.dtype + assert len(spec.placements) >= 1 + assert spec.placements == tuple(placements) + + def test_mesh_conversion(self): + device_count = xr.global_runtime_device_count() + original_mesh = DeviceMesh("xla", list(range(device_count))) + tensor = torch.randn(50, 50) + xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)]) + + converted_spec = xla_tensor._spec + + assert converted_spec.mesh.device_type == "xla" + assert converted_spec.mesh.size() == device_count + # assert on mesh dimensions + assert converted_spec.mesh.shape == original_mesh.shape + + def test_spec_caching(self): + """Test that _spec property caches results + """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(device_count))) + tensor = torch.randn(100, 100) + xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + spec1 = xla_tensor._spec + + assert xla_tensor._cached_spec is not None + assert xla_tensor._cached_spec is spec1 + + spec2 = xla_tensor._spec + assert spec1 is spec2 + + def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): + """Helper to create tensor and mesh for testing""" + device_count = xr.global_runtime_device_count() + if device_count < max(mesh_shape): + self.skipTest( + f"Need at least {max(mesh_shape)} devices, got {device_count}") + + mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape)) + tensor = torch.randn(*tensor_shape) + return distribute_tensor(tensor, mesh, placements), mesh + + def test_multi_dim_sharding_spec(self): + """Test _spec for multi-dimensional sharding""" + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for 2D mesh") + + mesh_shape = (2, device_count // 2) + xla_tensor, mesh = self._create_test_tensor_and_mesh( + (100, 50), mesh_shape, [Shard(0), Shard(1)]) + spec = xla_tensor._spec + + assert len(spec.placements) == 2 + assert spec.mesh.ndim == 2 + + def test_mixed_placement_spec(self): + """Test _spec for tensors with mixed shard/replicate placements""" + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for 2D mesh") + + mesh_shape = (2, device_count // 2) + xla_tensor, mesh = self._create_test_tensor_and_mesh( + (100, 50), mesh_shape, [Shard(0), Replicate()]) + spec = xla_tensor._spec + + assert len(spec.placements) == 2 + assert isinstance(spec.placements[0], Shard) + assert isinstance(spec.placements[1], Replicate) + + def test_sharding_info_acquisition(self): + """Test that non-XLAShardedTensor can acquire sharding information + + Tests case of 'elem is not an XLAShardedTensor but there exists + sharding information we want to acquire' + """ + + device_count = xr.global_runtime_device_count() + mesh_shape = (device_count,) + partition_spec = (0, None) + + regular_tensor = torch.randn(100, 50).to('xla') + + sharded_tensor = wrap_as_sharded_tensor( + regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Verify the tensor acquired the sharding information + assert isinstance(sharded_tensor, XLAShardedTensor) + assert sharded_tensor.mesh_shape == mesh_shape + assert sharded_tensor.partition_spec == partition_spec + + def test_resharding_logic(self): + """ + Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + # Initial sharding + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + # Create tensor and verify resharding + tensor = torch.randn(100, 50).to('xla') + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + + # Verify resharding worked and cache was invalidated + assert resharded_tensor.mesh_shape == new_mesh_shape + assert resharded_tensor.partition_spec == new_partition_spec + assert resharded_tensor._spec is not initial_spec + + def test_spec_invalidation_on_resharding(self): + """Tests cases where the cached spec may become outdated. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + tensor = torch.randn(100, 50).to('xla') + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + assert sharded_tensor._cached_spec is not None + + # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=initial_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.mesh.shape == new_mesh_shape + + initial_spec = resharded_tensor._spec + resharded_tensor = wrap_as_sharded_tensor( + resharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.placements[1].dim == 1 + + def test_auto_wrapped_tensor_spec_failure(self): + """Test that auto-wrapped tensors fail when accessing _spec property. + + Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ + but don't yet have access to the sharding propagation done through open xla, + causing ._spec to fail. + """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(device_count)) + tensor = torch.randn(4, 4) + sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + auto_wrapped = sharded_tensor + sharded_tensor + + with self.assertRaises(ValueError): + _ = auto_wrapped._spec + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7b1be7574a1f..48b760f6e3f0 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1162,10 +1162,6 @@ def test_mark_shard_scalar(self): self.assertIsInstance(shard.indices, type(Ellipsis)) self.assertEqual(shard.replica_id, i) - # It looks like mesh_shape attribute is never implemented. - with self.assertRaises(AttributeError): - xt.mesh_shape - def test_global_mesh(self): expected_mesh = self._get_mesh((1, self.n_devices)) xs.set_global_mesh(expected_mesh) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 1f6f5249b93b..24f18d3bdcda 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -61,6 +61,7 @@ run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v run_test "$_TEST_DIR/test_autocast.py" diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index aedfd6a801e3..a20d530f3faa 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -6,6 +6,11 @@ from typing import List, Tuple, Iterator, Union import contextlib import collections +import torch_xla.runtime as xr +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Shard, Replicate +from torch.utils._pytree import tree_map_only @dataclass @@ -91,10 +96,15 @@ class XLAShardedTensor(torch.Tensor): # >> assert len(input.shape) == len(partition_spec) partition_spec: Tuple[int, None] - __slots__ = ['global_tensor'] + __slots__ = ['global_tensor', 'mesh_shape', 'partition_spec', '_cached_spec'] @staticmethod - def __new__(cls, elem: torch.Tensor, *args, **kwargs): + def __new__(cls, + elem: torch.Tensor, + mesh_shape=None, + partition_spec=None, + *args, + **kwargs): # TODO(yeounoh) wrapper can take different arguments r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, @@ -106,6 +116,13 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs): device=elem.device, requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem + + # Initialize mesh, partition, and spec information + r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance( + elem, XLAShardedTensor) else None) + r.partition_spec = partition_spec or (elem.partition_spec if isinstance( + elem, XLAShardedTensor) else None) + r._cached_spec = None return r # Shards on the devices are materialized/available after the lazy @@ -130,6 +147,9 @@ def load_local_shards_(self, shards: List[XLAShard]): devices = [s.shard_device for s in shards] torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices) + # Invalidate cached spec since the global_tensor data has changed + self.invalidate_spec_cache() + @property def sharding_spec(self): return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor) @@ -169,6 +189,62 @@ def wrap(elem): func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) return rs + @property + def _spec(self): + """ + Convert XLA sharding information to DTensorSpec for DTensor interface compatibility. + """ + # Return cached spec if available + if self._cached_spec is not None: + return self._cached_spec + + # use existing mesh_shape + if self.mesh_shape is not None: + device_count = xr.global_runtime_device_count() + device_list = list(range(device_count)) + mesh = DeviceMesh("xla", + torch.tensor(device_list).reshape(self.mesh_shape)) + else: + raise ValueError( + "mesh_shape must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. " + ) + + # use existing partition_spec + if self.partition_spec is not None: + placements = [] + for mesh_dim in range(len(self.mesh_shape)): + # find tensor dimension sharded on this mesh dimension + tensor_dim = None + for t_dim, m_dim in enumerate(self.partition_spec): + if m_dim == mesh_dim: + tensor_dim = t_dim + break + placements.append( + Shard(tensor_dim) if tensor_dim is not None else Replicate()) + else: + raise ValueError( + "partition_spec must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. " + ) + + # tensor metadata + tensor_meta = TensorMeta( + shape=self.global_tensor.shape, + stride=self.global_tensor.stride(), + dtype=self.global_tensor.dtype) + + # Create and cache the spec + self._cached_spec = DTensorSpec( + mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) + return self._cached_spec + + def invalidate_spec_cache(self): + """Invalidate the cached DTensorSpec.""" + self._cached_spec = None + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 49229b17cffe..5f4d4378e7d2 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -543,7 +543,8 @@ def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh = get_global_mesh() if mesh is None else mesh t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec) t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t)) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], @@ -560,7 +561,8 @@ def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], t = torch_xla._XLAC._spmd_shard_to_full_shape( unwrap_sharded_tensor(t), mesh.get_op_sharding(partition_spec), full_shape, t.dtype) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def annotate_custom_sharding(t: Union[torch.Tensor, @@ -594,7 +596,8 @@ def annotate_custom_sharding(t: Union[torch.Tensor, op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = torch_xla._XLAC._xla_annotate_custom_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) - return wrap_as_sharded_tensor(t) + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, @@ -651,7 +654,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = torch_xla._XLAC._xla_mark_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) - return wrap_as_sharded_tensor(t) + # Pass mesh and partition spec information for DTensor compatibility + return wrap_as_sharded_tensor( + t, mesh_shape=mesh.mesh_shape, partition_spec=partition_spec) def mark_sharding_with_gradients( @@ -755,10 +760,31 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: return t -def wrap_as_sharded_tensor( - t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor: +def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], + mesh_shape=None, + partition_spec=None) -> XLAShardedTensor: + # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): - return XLAShardedTensor(t) + # Create a new XLAShardedTensor + return XLAShardedTensor( + t, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Update existing XLAShardedTensor if needed + needs_invalidate = False + + # Always set mesh_shape and partition_spec if provided + if mesh_shape is not None: + t.mesh_shape = mesh_shape + needs_invalidate = True + + if partition_spec is not None: + t.partition_spec = partition_spec + needs_invalidate = True + + # Invalidate cached spec if resharding occurred + if needs_invalidate: + t.invalidate_spec_cache() + return t From 16b120226432a03e7e72e468ced03ed9a70e6dc7 Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Thu, 24 Jul 2025 13:09:58 -0700 Subject: [PATCH 011/133] Clean up quantized matmul condition code (#9506) --- .../experimental/pallas_kernels/quantized_matmul_kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py index 54ae75352080..6630566d3eed 100644 --- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py +++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py @@ -70,7 +70,7 @@ def matmul_kernel( assert quantize_activation assert q_x_scratch is not None assert x_scale_scratch is not None - quant = out_idx == 0 + quant = (out_idx == 0) else: assert q_x_scratch is None assert x_scale_scratch is None @@ -78,8 +78,8 @@ def matmul_kernel( if save_acc: assert acc_scratch is not None - is_first_step = in_idx == 0 - is_last_step = in_idx == n_in - 1 + is_first_step = (in_idx == 0) + is_last_step = (in_idx == (n_in - 1)) else: assert acc_scratch is None is_first_step = True From 0a1594a8b503abb97935918e00a2661f0caad7bf Mon Sep 17 00:00:00 2001 From: qihqi Date: Thu, 24 Jul 2025 14:17:00 -0700 Subject: [PATCH 012/133] Move mutable properties of env to thread local, misc changes (#9501) * Refactored jax device handling * Removed option to use CPU jax array for CPU torch tensors. - changing jax devices after the fact will use different APIs --- torchax/test/test_context.py | 36 ++-- torchax/test/test_core_aten_ops.py | 5 - torchax/test/test_flax.py | 2 +- torchax/test/test_functions.py | 11 +- torchax/test/test_interop.py | 22 +-- torchax/test/test_libraries.py | 1 - torchax/test/test_ops.py | 7 +- torchax/test/test_unbounded_dynamism.py | 5 - torchax/torchax/__init__.py | 36 ---- torchax/torchax/amp.py | 5 +- torchax/torchax/config.py | 1 - torchax/torchax/mesh_util.py | 11 +- torchax/torchax/ops/jaten.py | 5 +- torchax/torchax/ops/jtorch.py | 21 ++- torchax/torchax/tensor.py | 233 ++++++++++++------------ 15 files changed, 178 insertions(+), 223 deletions(-) diff --git a/torchax/test/test_context.py b/torchax/test/test_context.py index 1f97b7a24ef6..ace28eeb4265 100644 --- a/torchax/test/test_context.py +++ b/torchax/test/test_context.py @@ -10,16 +10,9 @@ class TestContext(unittest.TestCase): - def setUp(self): - self.old_var = xla_env.config.use_torch_native_for_cpu_tensor - xla_env.config.use_torch_native_for_cpu_tensor = False - - def tearDown(self): - xla_env.config.use_torch_native_for_cpu_tensor = self.old_var - def test_mode_context_manager(self): with xla_env: - x = torch.full((3, 3), -1) + x = torch.full((3, 3), -1, device='jax') self.assertIsInstance(x, tensor.Tensor) y = x.abs() self.assertIsInstance(y, tensor.Tensor) @@ -27,7 +20,7 @@ def test_mode_context_manager(self): @staticmethod @xla_env def _test_mode_decorator(): - x = torch.full((3, 3), -1) + x = torch.full((3, 3), -1).to('jax') y = x.abs() return x, y @@ -40,11 +33,11 @@ def test_mode_decorator(self): def test_same_manual_seed(self): with xla_env: xla_env.manual_seed(1234) - x = torch.randn((3, 3)) + x = torch.randn((3, 3), device='jax') self.assertIsInstance(x, tensor.Tensor) xla_env.manual_seed(1234) - y = torch.randn((3, 3)) + y = torch.randn((3, 3), device='jax') self.assertIsInstance(y, tensor.Tensor) self.assertTrue(torch.allclose(x, y)) @@ -52,11 +45,11 @@ def test_same_manual_seed(self): def test_different_manual_seed(self): with xla_env: xla_env.manual_seed(1234) - x = torch.randn((3, 3)) + x = torch.randn((3, 3), device='jax') self.assertIsInstance(x, tensor.Tensor) xla_env.manual_seed(12345) - y = torch.randn((3, 3)) + y = torch.randn((3, 3), device='jax') self.assertIsInstance(y, tensor.Tensor) self.assertFalse(torch.allclose(x, y)) @@ -66,21 +59,24 @@ def test_jit_with_rng(self): with xla_env: def random_op(): - x = torch.randn(3, 3) - y = torch.randn(3, 3) + x = torch.randn(3, 3, device='jax') + y = torch.randn(3, 3, device='jax') return x @ y random_jit = torchax.interop.jax_jit(random_op) self.assertIsInstance(random_jit(), tensor.Tensor) # If we run the JIT twice, the random values should be different. - with self.assertRaises(AssertionError): - torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0) + # TODO(qihqi): think about API for passing down seed + # with self.assertRaises(AssertionError): + # torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0) def test_generator_seed(self): with xla_env: - x = torch.randn(2, 3, generator=torch.Generator().manual_seed(0)) - y = torch.randn(2, 3, generator=torch.Generator().manual_seed(0)) + x = torch.randn( + 2, 3, generator=torch.Generator().manual_seed(0), device='jax') + y = torch.randn( + 2, 3, generator=torch.Generator().manual_seed(0), device='jax') # Values will be the same given the same seed. torch.testing.assert_close(x, y) @@ -97,7 +93,7 @@ def __init__(self): # Test context manager. with xla_env: - m = M() + m = M().to('jax') self.assertIsInstance(m.c, tensor.Tensor) self.assertIsInstance(m.c2, tensor.Tensor) # Test `to_xla`. diff --git a/torchax/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py index 3e7c36ef5916..7a24c8bac408 100644 --- a/torchax/test/test_core_aten_ops.py +++ b/torchax/test/test_core_aten_ops.py @@ -90,11 +90,6 @@ def setUp(self): super().setUp() torch.manual_seed(0) self.env = tensor.Environment() - self.old_var = self.env.config.use_torch_native_for_cpu_tensor - self.env.config.use_torch_native_for_cpu_tensor = False - - def tearDown(self): - self.env.config.use_torch_native_for_cpu_tensor = self.old_var def test_aten_abs_0(self): args = (torch.randn((10, 10)).to(torch.float32),) diff --git a/torchax/test/test_flax.py b/torchax/test/test_flax.py index 43d989b76738..bc5b7f219786 100644 --- a/torchax/test/test_flax.py +++ b/torchax/test/test_flax.py @@ -81,7 +81,7 @@ def forward(self, x): return res with env: - nn_module = Parent() + nn_module = Parent().to('jax') @jax_jit def jitted(weights, args): diff --git a/torchax/test/test_functions.py b/torchax/test/test_functions.py index 38bd6c9da98a..03c3778bb00e 100644 --- a/torchax/test/test_functions.py +++ b/torchax/test/test_functions.py @@ -88,8 +88,15 @@ def test_rms_norm(self): res2 = model(x) self.assertTrue(torch.allclose(res, res2.to('cpu'))) - def test_randn_requires_grad(self): - x = torch.randn((3, 3), requires_grad=True, device='jax') + @parameterized.named_parameters( + ('ones', torch.ones, ((2, 2),)), ('zeros', torch.zeros, ((2, 2),)), + ('empty', torch.empty, + ((2, 2),)), ('empty_strided', torch.empty_strided, + ((2, 2), (2, 1))), ('tensor', torch.tensor, ([2.0, 2.0],)), + ('eye', torch.eye, (2,)), ('randn', torch.randn, ((2, 2),)), + ('rand', torch.rand, ((2, 2),)), ('full', torch.full, ((2, 2), 0))) + def test_requires_grad(self, func, args): + x = func(*args, requires_grad=True, device='jax') self.assertEqual(x.requires_grad, True) diff --git a/torchax/test/test_interop.py b/torchax/test/test_interop.py index 5854e2c4ac38..fe17c95292a7 100644 --- a/torchax/test/test_interop.py +++ b/torchax/test/test_interop.py @@ -2,7 +2,7 @@ import torch import unittest import torchax -from torchax import interop, jax_device +from torchax import interop import torchax import jax import jax.numpy as jnp @@ -143,7 +143,7 @@ def test_to_jax_device(self): self.assertEqual(e.jax_device.platform, "cpu") self.assertEqual(e.device.type, "jax") - with jax_device("cpu"): + with jax.default_device(jax.devices("cpu")[0]): # move torch.tensor to torchax.tensor CPU b = a.to("jax") self.assertEqual(b.jax_device.platform, "cpu") @@ -151,27 +151,11 @@ def test_to_jax_device(self): if is_tpu_available(): # move torch.tensor to torchax.tensor TPU - with jax_device("tpu"): + with jax.default_device(jax.local_devices("tpu")[0]): c = a.to("jax") self.assertEqual(c.jax_device.platform, "tpu") self.assertEqual(c.device.type, "jax") - # move torchax.tensor on CPU to TPU - with jax_device("tpu"): - self.assertEqual(b.jax_device.platform, "cpu") - self.assertEqual(c.device.type, "jax") - c = b.to("jax") - self.assertEqual(c.jax_device.platform, "tpu") - self.assertEqual(c.device.type, "jax") - - # move torchax.tensor on TPU to CPU - with jax_device("cpu"): - self.assertEqual(c.jax_device.platform, "tpu") - self.assertEqual(c.device.type, "jax") - d = c.to("jax") - self.assertEqual(d.jax_device.platform, "cpu") - self.assertEqual(d.device.type, "jax") - def test_torch_jax_view_dtype(self): dtype = torch.float32 self.assertEqual(interop.jax_view(dtype), jnp.float32.dtype) diff --git a/torchax/test/test_libraries.py b/torchax/test/test_libraries.py index bcbc7d41e76e..69ed3c77e53b 100644 --- a/torchax/test/test_libraries.py +++ b/torchax/test/test_libraries.py @@ -54,7 +54,6 @@ class LibraryTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - torchax.default_env().config.use_torch_native_for_cpu_tensor = False def test_basic_sdpa_library(self): diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py index 444f80753e1a..54ef1c30b5a3 100644 --- a/torchax/test/test_ops.py +++ b/torchax/test/test_ops.py @@ -140,6 +140,8 @@ def run_export_and_compare(testcase, with testcase.subTest("torchax_eval"): input2, args2, kwargs2 = testcase.env.to_xla( (sample_input.input, sample_input.args, sample_input.kwargs)) + if 'device' in kwargs2: + kwargs2['device'] = 'jax' with testcase.env: res2 = func(input2, *args2, **kwargs2) res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) @@ -188,11 +190,6 @@ def setUp(self): #self.env.config.debug_accuracy_for_each_op = True self.env.config.debug_print_each_op = False torch.manual_seed(0) - self.old_var = self.env.config.use_torch_native_for_cpu_tensor - self.env.config.use_torch_native_for_cpu_tensor = False - - def tearDown(self): - self.env.config.use_torch_native_for_cpu_tensor = self.old_var # Replaces all values in the input torch_tensor that are less than the given threshold # with the threshold value itself. diff --git a/torchax/test/test_unbounded_dynamism.py b/torchax/test/test_unbounded_dynamism.py index 4638657bdaca..d15b1750678a 100644 --- a/torchax/test/test_unbounded_dynamism.py +++ b/torchax/test/test_unbounded_dynamism.py @@ -53,13 +53,8 @@ def forward(self, *args): class UnboundedDynamismExportTest(unittest.TestCase): def setUp(self): - self.env = torchax.default_env() - self.env.config.use_torch_native_for_cpu_tensor = False torchax.enable_accuracy_mode() - def tearDown(self): - self.env.config.use_torch_native_for_cpu_tensor = True - def test_add(self): args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index e0cc69dae364..fe4c1c8ff046 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -81,11 +81,6 @@ def disable_temporarily(): torch.utils.rename_privateuse1_backend('jax') unsupported_dtype = [torch.quint8] -torch.utils.generate_methods_for_privateuse1_backend( - for_tensor=True, - for_module=True, - for_storage=True, - unsupported_dtype=unsupported_dtype) import jax import torchax.device_module @@ -129,34 +124,3 @@ def compile(fn, options: Optional[CompileOptions] = None): raise RuntimeError('dynamo mode is not supported yet') elif options.mode == 'export': raise RuntimeError('export mode is not supported yet') - - -@contextmanager -def jax_device(target_device: str, env: tensor.Environment | None = None): - """ - to("jax") cannot differentiate the device/platform (cpu vs tpu). - Use this context manager to control jax array's storage device - - Examples: - - a = torch.ones(3, 3) - - with jax_device("cpu"): - b = a.to("jax") - - with jax_device("tpu"): - c = a.to("jax") - - with jax_device("tpu"): - c = b.to("jax") - - """ - if env is None: - env = default_env() - - prev_target_device = env.target_device - try: - env.target_device = target_device - yield env - finally: - env.target_device = prev_target_device diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py index ef06e884a8a8..ccbc63bead63 100644 --- a/torchax/torchax/amp.py +++ b/torchax/torchax/amp.py @@ -61,9 +61,8 @@ def autocast(device, dtype=torch.bfloat16, env=None): if env is None: import torchax env = torchax.default_env() - env.autocast_dtype, old = dtype, env.autocast_dtype - yield - env.autocast_dtype = old + with env.override_property(autocast_dtype=dtype): + yield # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327 diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py index 0ba3a2a11e5c..f439c656287b 100644 --- a/torchax/torchax/config.py +++ b/torchax/torchax/config.py @@ -27,5 +27,4 @@ class Configuration: # device treat_cuda_as_jax_device: bool = True - use_torch_native_for_cpu_tensor: bool = True internal_respect_torch_return_dtypes: bool = False diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py index 3f65b8440b59..208d86a1bac6 100644 --- a/torchax/torchax/mesh_util.py +++ b/torchax/torchax/mesh_util.py @@ -199,7 +199,7 @@ def initialize_model_sharded(self, } def model_initializer(): - with torchax.default_env(): + with torchax.default_env(), torch.device('meta'): model = model_class(*init_args, **init_kwargs) return dict(model.state_dict()) @@ -209,3 +209,12 @@ def model_initializer(): model.load_state_dict(weights_dict, assign=True) return model + + def shard_model(self, model, override_sharder=None): + sharder = override_sharder or self._sharder + states = model.state_dict() + output_shards = { + name: NamedSharding(self.jax_mesh, sharder(name, tensor)) + for name, tensor in states.items() + } + model.load_state_dict(output_shards, assign=True) diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index d26ef4233a1d..711b4bbe8b06 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -3532,7 +3532,7 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0): @op(torch.ops.aten.randn, needs_env=True) @op_base.convert_dtype() -def _randn( +def _aten_randn( *size, generator=None, out=None, @@ -3652,7 +3652,7 @@ def _aten_native_batch_norm(input, @op(torch.ops.aten.normal, needs_env=True) def _aten_normal(self, mean=0, std=1, generator=None, env=None): shape = self.shape - res = _randn(*shape, generator=generator, env=env) + res = _aten_randn(*shape, generator=generator, env=env) return res * std + mean @@ -5541,6 +5541,7 @@ def _aten_floor_divide(x, y): @op(torch.ops.aten._assert_tensor_metadata) +@op(torch.ops.aten._assert_scalar) def _aten__assert_tensor_metadata(*args, **kwargs): pass diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 935c214d78f5..b53f27d462d2 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -341,7 +341,7 @@ def empty(*size: Sequence[int], dtype=None, **kwargs): return jnp.empty(size, dtype=dtype) -@register_function(torch.arange, is_jax_function=False) +@register_function(torch.arange, is_jax_function=True) def arange( start, end=None, @@ -358,10 +358,10 @@ def arange( start = 0 if step is None: step = 1 - return torch.ops.aten.arange(start, end, step, dtype=dtype) + return jaten._aten_arange(start, end, step, dtype=dtype) -@register_function(torch.empty_strided, is_jax_function=False) +@register_function(torch.empty_strided, is_jax_function=True) def empty_strided( size, stride, @@ -372,7 +372,7 @@ def empty_strided( requires_grad=False, pin_memory=False, ): - return empty(size, dtype=dtype) + return empty(size, dtype=dtype, requires_grad=requires_grad) @register_function(torch.unravel_index) @@ -380,14 +380,14 @@ def unravel_index(indices, shape): return jnp.unravel_index(indices, shape) -@register_function(torch.rand, is_jax_function=False) +@register_function(torch.rand, is_jax_function=True, needs_env=True) def rand(*size, **kwargs): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] - return torch.ops.aten.rand(size, **kwargs) + return jaten._rand(size, **kwargs) -@register_function(torch.randn, is_jax_function=False) +@register_function(torch.randn, is_jax_function=True, needs_env=True) def randn( *size, generator=None, @@ -397,15 +397,16 @@ def randn( device=None, requires_grad=False, pin_memory=False, + env=None, ): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] - return torch.ops.aten.randn(size, generator=generator, dtype=dtype) + return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env) -@register_function(torch.randint, is_jax_function=False) +@register_function(torch.randint, is_jax_function=False, needs_env=True) def randint(*args, **kwargs): - return torch.ops.aten.randint(*args, **kwargs) + return jaten._aten_randint(*args, **kwargs) @register_function(torch.logdet) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 205a041e9a09..3916fe6501b8 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -1,3 +1,4 @@ +import threading import logging import sys import contextlib @@ -16,7 +17,6 @@ from torchax import config from torchax.ops import mappings, ops_registry from torchax import amp -from jax.experimental import mutable_array logger = logging.getLogger(__name__) @@ -25,14 +25,6 @@ class OperatorNotFound(Exception): pass -def wrap(jaxarray): - return torch_pytree.tree_map_only(jnp.ndarray, Tensor, jaxarray) - - -def unwrap(torchtensors): - return torch_pytree.tree_map_only(Tensor, lambda x: x._elem, torchtensors) - - @contextlib.contextmanager def log_nested(env, message): if env.config.debug_print_each_op: @@ -48,7 +40,7 @@ def log_nested(env, message): class Tensor(torch.Tensor): @staticmethod - def __new__(cls, elem, env): + def __new__(cls, elem, env, requires_grad=False): dtype = mappings.j2t_dtype(elem.dtype) shape = list(elem.shape) for i, s in enumerate(shape): @@ -56,15 +48,19 @@ def __new__(cls, elem, env): shape[i] = 1 if dtype is None: dtype = torch.float32 + #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1) + if not (dtype.is_floating_point or dtype.is_complex): + requires_grad = False + return torch.Tensor._make_wrapper_subclass( cls, shape, dtype=dtype, - device="meta", - requires_grad=False, + device='meta', + requires_grad=requires_grad, ) - def __init__(self, elem: jax.Array, env: "Environment"): + def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False): super().__init__() self._elem = elem self._env = env @@ -109,6 +105,8 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # TODO(hanq): figure out why is dispatch mode not sufficient if func == torch.ops._c10d_functional.wait_tensor.default: return args[0]._env.dispatch(func, types, args, kwargs) + if func == torch.ops.prim.device.default: + return torch.device('privateuseone', 0) raise AssertionError( 'torchax Tensors can only do math within the torchax environment.' 'Please wrap your code with `with torchax.default_env()` or ' @@ -298,6 +296,38 @@ def _name_of_func(func): SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"] +class RuntimeProperty: + mesh: Any + prng: Any + autocast_dtype: Any + + def __init__(self, mesh, prng, autocast_dtype): + self.mesh = mesh + self.prng = prng + self.autocast_dtype = autocast_dtype + + def override(self, **kwargs): + return OverrideProperty(self, kwargs) + + def get_and_rotate_prng_key(self): + old_key = self.prng + new_prng_key, next_key = jax.random.split(old_key) + self.prng = new_prng_key + return next_key + + +class OverrideProperty(RuntimeProperty): + + def __init__(self, parent, override): + self.parent = parent + self._override = dict(override) + + def __getattr__(self, name): + if name in self._override: + return self._override[name] + return getattr(self.parent, name) + + class Environment(contextlib.ContextDecorator): """This class holds a set of configurations and "globals" needed @@ -321,62 +351,55 @@ def __init__(self, configuration=None): self.load_ops() - self._mesh = None + _mesh = None self.config = configuration or config.Configuration() - self._manually_entered = False self.enabled = False - self._prng_key = mutable_array( - jax.random.key(torch.initial_seed() % (1 << 63))) - self.autocast_dtype = None - self._target_device = jax.local_devices()[0].platform + autocast_dtype = None - @property - def target_device(self): - return self._target_device + _prng_key = jax.random.key(torch.initial_seed() % (1 << 63)) + self._property = threading.local() + self._property.content = [ + RuntimeProperty( + mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype) + ] - @target_device.setter - def target_device(self, device: str): - self._target_device = device.lower() + @property + def param(self): + return self._property.content[-1] def manual_seed(self, key): - self._prng_key = mutable_array(jax.random.key(key)) + jax_key = jax.random.PRNGKey(key) + new_prop = self.param.override(prng=jax_key) + self._property.content.append(new_prop) @property def prng_key(self): - return self._prng_key[...] + return self.param.prng - def get_as_jax_device(self, device: Any): + def _should_use_torchax_tensor(self, device): if device is None: device = torch.get_default_device() if isinstance(device, torch.device): - device = str(device) - - if not self.config.use_torch_native_for_cpu_tensor and device.startswith( - "cpu"): - return jax.devices("cpu")[0] - - if self.config.treat_cuda_as_jax_device and device.startswith("cuda"): - return jax.local_devices()[0] - - if device.startswith("xla"): - return jax.local_devices()[0] - - # TODO (wen): jax is NOT a device type, - # once we can register more than one backend, revisit - if device.startswith("jax"): - match self.target_device: - case "cpu": - return jax.devices("cpu")[0] - case "tpu": - return jax.devices("tpu")[0] - case _: - raise AttributeError( - f"Cannot handle env.target_device {self.target_device}") - - return None # fallback to torch + device = device.type + + if ':' in device: + device = device.split(':')[0] + + match device: + case 'cpu': + return False + case 'cuda': + return self.config.treat_cuda_as_jax_device + case 'jax': + return True + case 'privateuseone': + return True + case 'meta': + return self.enabled + return False def load_ops(self): from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms @@ -423,80 +446,61 @@ def _get_from_dict(op_dict, op): return op + def _is_same_device(self, the_tensor, new_device): + if new_device is None: + return True + if new_device == 'meta' and the_tensor.device.type == 'jax': + return True + if the_tensor.device.type != new_device: + if the_tensor.device.type == 'cuda': + return self.config.treat_cuda_as_jax_device + return False + return True + def _to_copy(self, the_tensor, new_dtype, new_device): if isinstance(the_tensor, View): the_tensor = the_tensor.torch() - - if isinstance(the_tensor, Tensor): - - arr = the_tensor.jax() - - if new_dtype is not None and new_dtype != arr.dtype: - arr = arr.astype(mappings.t2j_dtype(new_dtype)) - - if new_device is not None: - match str(new_device).lower(): - case "cpu": - # converting to a non-jax device: let torch native handle it - torch_tensor = self.j2t_copy(arr) if isinstance(the_tensor, - Tensor) else arr - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return torch_tensor.to(new_device) - case "jax": - # move torchax.tensor / jax tensor between devices - # I don't know ifgit this will work after the model is jitted - if self.target_device != the_tensor.jax_device.platform: - arr = jax.device_put(the_tensor.jax(), - jax.devices(self.target_device)[0]) - return Tensor(arr, self) - case _: - logging.error(f"torchax.Tenosr cannot handle device {new_device}") - - else: - if new_dtype is not None and new_dtype != the_tensor.dtype: + if isinstance(new_device, torch.device): + new_device = new_device.type + res = the_tensor + if not self._is_same_device(the_tensor, new_device): + if isinstance(the_tensor, Tensor): + torch_tensor = self.j2t_copy(the_tensor._elem) with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - the_tensor = the_tensor.to(new_dtype) - - if new_device is None: ## device is None means don't change device - return the_tensor - - jax_device = self.get_as_jax_device(new_device) - if jax_device: + return torch_tensor.to(device=new_device, dtype=new_dtype) + else: arr = self.t2j_copy(the_tensor) - arr = jax.device_put(arr, jax_device) + res = Tensor(arr, self, the_tensor.requires_grad) + + if new_dtype is not None and new_dtype != the_tensor.dtype: + if isinstance(the_tensor, Tensor): + res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype)) else: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return the_tensor.to(new_device) - - return Tensor(arr, self) + return the_tensor.to(device=new_device, dtype=new_dtype) + return res def get_and_rotate_prng_key(self, generator: Optional[torch.Generator] = None): if generator is not None: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63)) - old_key = self._prng_key[...] - new_prng_key, next_key = jax.random.split(old_key) - self._prng_key[...] = new_prng_key - return next_key + return jax.random.PRNGKey(generator.initial_seed() % (2**63)) + return self.param.get_and_rotate_prng_key() def _handle_tensor_constructor(self, func, args, kwargs): device = kwargs.get("device") - jax_device = self.get_as_jax_device(device) - # TODO(qihqi) figure out better ways for device propagation - if not self._manually_entered and jax_device is None: - # let torch handle it - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return func(*args, **kwargs) - with jax.default_device(jax_device): + if self._should_use_torchax_tensor(device): + # don't set default device, let caller set it requires_grad = kwargs.get("requires_grad", False) op = self._get_op_or_decomp(func) + if op.needs_env: + kwargs['env'] = self res = op.func(*args, **kwargs) if isinstance(res, jax.Array): - res = Tensor(res, self) - if requires_grad: - res.requires_grad = True + res = Tensor(res, self, requires_grad) return res + else: + with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): + return func(*args, **kwargs) def _torch_Tensor_to(self, args, kwargs): the_tensor = args[0] @@ -560,11 +564,11 @@ def is_not_torchax_tensor(x): args, kwargs = self.v2t_iso((args, kwargs)) with self: - if self.autocast_dtype is not None: + if self.param.autocast_dtype is not None: autocast_policy = amp.autocast_policy.get(func) if autocast_policy is not None: args, kwargs = amp.execute_policy(autocast_policy, args, kwargs, - self.autocast_dtype) + self.param.autocast_dtype) if op.is_jax_function: args, kwargs = self.t2j_iso((args, kwargs)) @@ -609,11 +613,9 @@ def disable_torch_modes(self, *exc): def __enter__(self): self.enable_torch_modes() - self._manually_entered = True return self def __exit__(self, *exc): - self._manually_entered = False self.disable_torch_modes(*exc) def _move_one_value(self, val): @@ -701,3 +703,10 @@ def override_op_definition(self, op_to_override, op_impl): is_user_defined=True, needs_env=False, ) + + @contextlib.contextmanager + def override_property(self, **kwargs): + new_prop = self.param.override(**kwargs) + self._property.content.append(new_prop) + yield + self._property.content.pop() From 29ae4c76c026185f417a25e841d2cd5e65f087a3 Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Fri, 25 Jul 2025 23:22:51 -0700 Subject: [PATCH 013/133] Optimize w8a8 kernel vmem limit (#9508) --- .../pallas_kernels/quantized_matmul_kernel.py | 97 ++++++++++++++----- 1 file changed, 71 insertions(+), 26 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py index 6630566d3eed..6401f87a8dd7 100644 --- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py +++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py @@ -42,12 +42,12 @@ def matmul_kernel( x_abs_max_ref: jax.Array, # (1, batch_block_size) out_ref: jax.Array, # (batch_block_size, out_block_size) acc_scratch: jax.Array, # (batch_block_size, out_block_size) - q_x_scratch: jax.Array, # (batch_block_size, in_block_size) + x_q_scratch: jax.Array, # (batch_block_size, in_block_size) x_scale_scratch: jax.Array, # (batch_block_size, 1) *, quantize_activation: bool, save_acc: bool, - save_q_x: bool, + save_x_q: bool, batch_block_size: int, out_block_size: int, in_block_size: int, @@ -66,13 +66,13 @@ def matmul_kernel( assert out_ref.shape == (batch_block_size, out_block_size), "out_ref shape is not correct" - if save_q_x: + if save_x_q: assert quantize_activation - assert q_x_scratch is not None + assert x_q_scratch is not None assert x_scale_scratch is not None quant = (out_idx == 0) else: - assert q_x_scratch is None + assert x_q_scratch is None assert x_scale_scratch is None quant = quantize_activation @@ -88,18 +88,18 @@ def matmul_kernel( def matmul_body(quant, is_first_step, is_last_step): if quantize_activation: if quant: - q_x_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...]) - if save_q_x: - q_x_scratch[...] = q_x_tmp + x_q_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...]) + if save_x_q: + x_q_scratch[...] = x_q_tmp x_scale_scratch[...] = x_scale_tmp else: - assert save_q_x - q_x_tmp = q_x_scratch[...] + assert save_x_q + x_q_tmp = x_q_scratch[...] if is_last_step: x_scale_tmp = x_scale_scratch[...] acc = jax.lax.dot_general( - q_x_tmp, + x_q_tmp, w_ref[...], (((1,), (1,)), ((), ())), preferred_element_type=jnp.int32, @@ -130,6 +130,44 @@ def _next_multiple(x, multiple): return ((x + multiple - 1) // multiple) * multiple +def _get_vmem_limit(n_bs, n_out, n_in, batch_block_size, out_block_size, + in_block_size, x_bytes, w_bytes, x_q_bytes, scale_bytes, + out_bytes, acc_bytes, save_acc, save_x_q): + # Calculate in/out VMEM size. + x_size = batch_block_size * in_block_size * x_bytes + x_abs_max_val_size = batch_block_size * scale_bytes + w_size = out_block_size * in_block_size * w_bytes + scalar_size = out_block_size * scale_bytes + out_size = batch_block_size * out_block_size * out_bytes + + vmem_in_out = x_size + x_abs_max_val_size + w_size + scalar_size + out_size + vmem_in_out *= 2 # Account for compute and vreg spills. + + # Account for double buffering. + # Double buffering is used only if there are multiple blocks per in/out. + vmem_in_out += x_size if (n_bs > 1 or n_in > 1) else 0 + vmem_in_out += x_abs_max_val_size if (n_bs > 1) else 0 + vmem_in_out += w_size if (n_out > 1 or n_in > 1) else 0 + vmem_in_out += scalar_size if (n_out > 1) else 0 + vmem_in_out += out_size if (n_bs > 1 or n_out > 1) else 0 + + # Calculate scratch VMEM size. + acc_size = batch_block_size * out_block_size * acc_bytes + x_q_scratch_size = batch_block_size * in_block_size * x_q_bytes + x_scale_scratch_size = batch_block_size * scale_bytes + + vmem_scratch = acc_size if save_acc else 0 + vmem_scratch += x_q_scratch_size + x_scale_scratch_size if save_x_q else 0 + vmem_scratch *= 2 # Account for compute and vreg spills. + + # Add in/out and scratch VMEM size. + vmem_used = vmem_in_out + vmem_scratch + # Specify upper limit as 96MB. + vmem_limit_bytes = min(vmem_used, 96 * 1024 * 1024) + + return vmem_limit_bytes + + @functools.partial( jax.jit, static_argnames=[ @@ -196,17 +234,6 @@ def quantized_matmul_int8( assert x.shape[ 1] % in_block_size == 0, f"x.shape[1] ({x.shape[1]}) must be a multiple of block size ({in_block_size})" - acc_dtype = jnp.int32 if quantize_activation else x.dtype - vmem_to_be_transferred = 2 * ( - batch_block_size * in_block_size * x.dtype.itemsize + - out_block_size * in_block_size * w.dtype.itemsize + out_block_size * - scalar.dtype.itemsize + batch_block_size * x_abs_max_val.dtype.itemsize + - batch_block_size * out_block_size * x.dtype.itemsize - ) + batch_block_size * out_block_size * jnp.dtype(acc_dtype).itemsize - # Within the kernel, it will use some extra VMEM for computation or vreg spills. - vmem_used = vmem_to_be_transferred * 2 - vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024) - n_bs = padded_bs // batch_block_size n_out = padded_out_features // out_block_size n_in = padded_in_features // in_block_size @@ -214,14 +241,32 @@ def quantized_matmul_int8( save_acc = n_in > 1 # Remove redundant input quantization logic by caching quantized input. # For best performance, only enable this behavior when single input block is used per batch. - save_q_x = quantize_activation and n_in == 1 and n_out > 1 + save_x_q = quantize_activation and n_in == 1 and n_out > 1 + + acc_dtype = jnp.int32 if quantize_activation else jnp.float32 + + vmem_limit_bytes = _get_vmem_limit( + n_bs=n_bs, + n_out=n_out, + n_in=n_in, + batch_block_size=batch_block_size, + out_block_size=out_block_size, + in_block_size=in_block_size, + x_bytes=x.dtype.itemsize, + w_bytes=w.dtype.itemsize, + x_q_bytes=jnp.dtype(jnp.int8).itemsize, + scale_bytes=jnp.dtype(jnp.float32).itemsize, + out_bytes=x.dtype.itemsize, + acc_bytes=jnp.dtype(acc_dtype).itemsize, + save_acc=save_acc, + save_x_q=save_x_q) kernel = pl.pallas_call( functools.partial( matmul_kernel, quantize_activation=quantize_activation, save_acc=save_acc, - save_q_x=save_q_x, + save_x_q=save_x_q, batch_block_size=batch_block_size, out_block_size=out_block_size, in_block_size=in_block_size), @@ -243,9 +288,9 @@ def quantized_matmul_int8( pltpu.VMEM((batch_block_size, out_block_size), acc_dtype) if save_acc else None, pltpu.VMEM((batch_block_size, - in_block_size), jnp.int8) if save_q_x else None, + in_block_size), jnp.int8) if save_x_q else None, pltpu.VMEM( - (batch_block_size, 1), jnp.float32) if save_q_x else None, + (batch_block_size, 1), jnp.float32) if save_x_q else None, ], grid=(n_bs, n_out, n_in), ), From 2820f7c528daf9cb0c5c61fd9b034b68b5f56309 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 28 Jul 2025 19:18:08 -0300 Subject: [PATCH 014/133] Error Handling: return status value when loading PjRt dynamic plugin. (#9495) --- test/run_tests.sh | 1 + ...est_runtime_client_initialization_error.py | 35 +++++++++++++++++++ torch_xla/csrc/runtime/pjrt_registry.cpp | 23 ++++++++---- torch_xla/csrc/runtime/runtime.cpp | 12 ++++--- 4 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 test/test_runtime_client_initialization_error.py diff --git a/test/run_tests.sh b/test/run_tests.sh index b2cc8f751d2c..ec92a0a2691c 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -227,6 +227,7 @@ function run_xla_op_tests2 { run_test "$_TEST_DIR/test_assume_pure_spmd.py" run_test "$_TEST_DIR/test_assume_pure_torch.py" run_test "$_TEST_DIR/test_dynamic_shapes_detector.py" + run_test "$_TEST_DIR/test_runtime_client_initialization_error.py" } function run_xla_op_tests3 { diff --git a/test/test_runtime_client_initialization_error.py b/test/test_runtime_client_initialization_error.py new file mode 100644 index 000000000000..5260661948f3 --- /dev/null +++ b/test/test_runtime_client_initialization_error.py @@ -0,0 +1,35 @@ +import os +import torch_xla +import torch_xla.core.xla_env_vars as xenv +import unittest + + +class TestClientInitializationError(unittest.TestCase): + + def test(self): + + def initialize_client(device): + os.environ[xenv.PJRT_DEVICE] = device + + # The message does not change! + # After the first call with DUMMY_DEVICE, all other calls will have + # "DUMMY_DEVICE" in their message. + message = ( + f"No PjRtPlugin registered for: DUMMY_DEVICE. " + f"Make sure the environment variable {xenv.PJRT_DEVICE} is set " + "to a correct device name.") + + with self.assertRaisesRegex(RuntimeError, expected_regex=message): + torch_xla._XLAC._init_computation_client() + + # Run the initialization function the first time, ending up in an + # exception thrown. + initialize_client("DUMMY_DEVICE") + + # Even if the device exists, this call should fail, since the result + # of the first call is cached. + initialize_client("CPU") + + +if __name__ == '__main__': + unittest.main() diff --git a/torch_xla/csrc/runtime/pjrt_registry.cpp b/torch_xla/csrc/runtime/pjrt_registry.cpp index 87de9b6fd1cf..44603efca833 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cpp +++ b/torch_xla/csrc/runtime/pjrt_registry.cpp @@ -61,10 +61,21 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() { return allocator_config; } -std::shared_ptr GetPjRtPlugin( +absl::StatusOr> GetPjRtPlugin( const std::string& device_type) { - auto plugin_path = pjrt_plugins_.find(device_type); - return plugin_path != pjrt_plugins_.end() ? plugin_path->second : nullptr; + auto entry = pjrt_plugins_.find(device_type); + if (entry == pjrt_plugins_.end()) { + std::string message = absl::StrCat( + "No PjRtPlugin registered for: ", device_type, + ". Make sure the environment variable ", env::kEnvPjRtDevice, + " is set to a correct device name. See " + "https://github.com/pytorch/xla/blob/master/docs/source/" + "contribute/plugins.md for more information on " + "implementing and registering a new " + "plugin."); + return XLA_ERROR_WITH_LOCATION(absl::FailedPreconditionError(message)); + } + return entry->second; } } // namespace @@ -83,10 +94,10 @@ InitializePjRt(const std::string& device_type) { if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false) && device_type != "CPU") { - std::shared_ptr plugin = GetPjRtPlugin(device_type); + TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; + XLA_ASSIGN_OR_RETURN(std::shared_ptr plugin, + GetPjRtPlugin(device_type)); if (plugin) { - TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - // Init the absl logging to avoid the log spam. absl::InitializeLog(); diff --git a/torch_xla/csrc/runtime/runtime.cpp b/torch_xla/csrc/runtime/runtime.cpp index 7603eb524c5d..69bb53fe1df4 100644 --- a/torch_xla/csrc/runtime/runtime.cpp +++ b/torch_xla/csrc/runtime/runtime.cpp @@ -19,10 +19,6 @@ static std::atomic g_computation_client_initialized(false); // Can only be called when g_computation_client_initialized is false. static absl::StatusOr InitializeComputationClient() { - ABSL_CHECK(!g_computation_client_initialized) - << "InitializeComputationClient() can only be called once."; - g_computation_client_initialized = true; - if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) { tsl::testing::InstallStacktraceHandler(); } @@ -37,6 +33,9 @@ InitializeComputationClient() { absl::FailedPreconditionError("$PJRT_DEVICE is not set.")); } + ABSL_CHECK(!g_computation_client_initialized) + << "ComputationClient can only be initialized once."; + std::unique_ptr client; if (use_ifrt) { XLA_ASSIGN_OR_RETURN(client, IfrtComputationClient::Create()); @@ -44,6 +43,9 @@ InitializeComputationClient() { XLA_ASSIGN_OR_RETURN(client, PjRtComputationClient::Create()); } + // Set only if we actually successfully initialized a client. + g_computation_client_initialized = true; + return client.release(); } @@ -59,7 +61,7 @@ const absl::StatusOr& GetComputationClient() { } ComputationClient* absl_nonnull GetComputationClientOrDie() { - return GetComputationClient().value(); + return GetValueOrThrow(GetComputationClient()); } ComputationClient* GetComputationClientIfInitialized() { From 531c724fb0a5549efa4b3085d80e641a4d8b28b3 Mon Sep 17 00:00:00 2001 From: XiongfeiWei Date: Mon, 28 Jul 2025 17:06:03 -0700 Subject: [PATCH 015/133] Add block sizes for Qwen/Qwen2.5-32B-Instruct (#9516) --- torch_xla/experimental/custom_kernel.py | 4 + .../pallas_kernels/quantized_matmul_kernel.py | 188 +++++++++++++----- 2 files changed, 142 insertions(+), 50 deletions(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 8a71ae7432b1..0e273c5a4912 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -2,6 +2,7 @@ import os import math import warnings +import logging import torch from torch.library import impl, custom_op @@ -1097,6 +1098,9 @@ def quantized_matmul_int8( "out_block_size": out_block_size, "in_block_size": in_block_size, }) + logging.warning( + f"Couldn't find w8a8 quantized matmul kernel block sizes for {bs=}, {n_out_features=}, {n_in_features=}, {jnp.dtype(jax_dtype).name=}, {quantize_activation=}, falling back to XLA quantized matmul kernel." + ) from torch_xla.experimental.xla_quantized_matmul import quantized_matmul_xla return quantized_matmul_xla( x, w, scalar, quantize_activation=quantize_activation).to(x.dtype) diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py index 6401f87a8dd7..d7356be54be6 100644 --- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py +++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py @@ -324,70 +324,158 @@ def quantized_matmul_int8( # - out_block_size # - in_block_size TUNED_BLOCK_SIZES = { - (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256), - (6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896), - (6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096), - (6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096), - (6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512), - (6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096), - (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096), - (6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096), - (6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096), - (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096), - (6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256), - (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896), - (6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096), - (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096), - (6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096), - (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512), - (6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256), - (6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096), - (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096), - (6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096), - (6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336), - (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096), - (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096), + (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192), + (6, 1024, 13824, 5120, 'bfloat16', True): (1024, 768, 5120), + (6, 1024, 1792, 5120, 'bfloat16', True): (1024, 256, 5120), (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096), (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336), - (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896), - (6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096), - (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096), - (6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256), - (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896), - (6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096), - (6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096), - (6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192), - (6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), - (6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192), - (6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584), + (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096), + (6, 1024, 5120, 1280, 'bfloat16', True): (1024, 1280, 1280), + (6, 1024, 5120, 3456, 'bfloat16', True): (1024, 1024, 3456), + (6, 1024, 5120, 640, 'bfloat16', True): (256, 5120, 640), + (6, 1024, 5120, 6912, 'bfloat16', True): (1024, 512, 6912), + (6, 1024, 6144, 4096, 'bfloat16', True): (1024, 768, 4096), + (6, 1024, 6912, 5120, 'bfloat16', True): (1024, 768, 5120), + (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192), + (6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024), + (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584), + (6, 1024, 896, 5120, 'bfloat16', True): (1024, 896, 2560), (6, 128, 1280, 8192, 'bfloat16', True): (128, 1280, 2048), - (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), + (6, 128, 13824, 5120, 'bfloat16', True): (128, 512, 5120), + (6, 128, 1792, 5120, 'bfloat16', True): (128, 1792, 1280), + (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896), + (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 128, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), + (6, 128, 5120, 3456, 'bfloat16', True): (128, 640, 3456), + (6, 128, 5120, 640, 'bfloat16', True): (128, 2560, 640), + (6, 128, 5120, 6912, 'bfloat16', True): (128, 2560, 1152), + (6, 128, 6144, 4096, 'bfloat16', True): (128, 768, 4096), + (6, 128, 6912, 5120, 'bfloat16', True): (128, 1152, 2560), (6, 128, 7168, 8192, 'bfloat16', True): (128, 256, 8192), + (6, 128, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), (6, 128, 8192, 3584, 'bfloat16', True): (128, 8192, 512), - (6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192), - (6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024), - (6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192), - (6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512), + (6, 128, 896, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 16, 1280, 8192, 'bfloat16', True): (128, 256, 8192), + (6, 16, 13824, 5120, 'bfloat16', True): (128, 512, 5120), + (6, 16, 1792, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896), + (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 16, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), + (6, 16, 5120, 3456, 'bfloat16', True): (128, 640, 3456), + (6, 16, 5120, 640, 'bfloat16', True): (128, 2560, 640), + (6, 16, 5120, 6912, 'bfloat16', True): (128, 1280, 2304), + (6, 16, 6144, 4096, 'bfloat16', True): (128, 768, 4096), + (6, 16, 6912, 5120, 'bfloat16', True): (128, 1152, 2560), (6, 16, 7168, 8192, 'bfloat16', True): (128, 256, 8192), - (6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192), - (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024), - (6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192), - (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584), - (6, 1024, 1280, 8192, 'bfloat16', True): (1024, 256, 8192), - (6, 1024, 8192, 1024, 'bfloat16', True): (1024, 4096, 1024), - (6, 1024, 7168, 8192, 'bfloat16', True): (1024, 512, 8192), - (6, 1024, 8192, 3584, 'bfloat16', True): (1024, 1024, 3584), - (6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192), - (6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024), + (6, 16, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), (6, 16, 8192, 3584, 'bfloat16', True): (128, 1024, 3584), + (6, 16, 896, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 16384, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120), + (6, 16384, 1792, 5120, 'bfloat16', True): (1024, 1792, 5120), + (6, 16384, 5120, 1280, 'bfloat16', True): (512, 5120, 1280), + (6, 16384, 5120, 3456, 'bfloat16', True): (512, 5120, 3456), + (6, 16384, 5120, 640, 'bfloat16', True): (512, 5120, 640), + (6, 16384, 5120, 6912, 'bfloat16', True): (512, 5120, 6912), + (6, 16384, 6912, 5120, 'bfloat16', True): (512, 6912, 5120), + (6, 16384, 896, 5120, 'bfloat16', True): (1024, 896, 5120), + (6, 2048, 1280, 8192, 'bfloat16', True): (2048, 256, 8192), + (6, 2048, 13824, 5120, 'bfloat16', True): (2048, 768, 5120), + (6, 2048, 1792, 5120, 'bfloat16', True): (2048, 256, 5120), + (6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096), + (6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512), + (6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096), + (6, 2048, 5120, 1280, 'bfloat16', True): (256, 5120, 1280), + (6, 2048, 5120, 3456, 'bfloat16', True): (2048, 512, 3456), + (6, 2048, 5120, 640, 'bfloat16', True): (256, 5120, 640), + (6, 2048, 5120, 6912, 'bfloat16', True): (2048, 512, 6912), + (6, 2048, 6144, 4096, 'bfloat16', True): (2048, 512, 4096), + (6, 2048, 6912, 5120, 'bfloat16', True): (2048, 768, 5120), (6, 2048, 7168, 8192, 'bfloat16', True): (2048, 256, 8192), + (6, 2048, 8192, 1024, 'bfloat16', True): (256, 8192, 1024), (6, 2048, 8192, 3584, 'bfloat16', True): (2048, 512, 3584), + (6, 2048, 896, 5120, 'bfloat16', True): (1024, 896, 5120), + (6, 256, 1280, 8192, 'bfloat16', True): (256, 256, 8192), + (6, 256, 13824, 5120, 'bfloat16', True): (256, 512, 5120), + (6, 256, 1792, 5120, 'bfloat16', True): (256, 1792, 1280), + (6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096), + (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512), + (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096), + (6, 256, 5120, 1280, 'bfloat16', True): (256, 2560, 1280), + (6, 256, 5120, 3456, 'bfloat16', True): (256, 1024, 3456), + (6, 256, 5120, 640, 'bfloat16', True): (256, 2560, 640), + (6, 256, 5120, 6912, 'bfloat16', True): (256, 5120, 768), + (6, 256, 6144, 4096, 'bfloat16', True): (256, 512, 4096), + (6, 256, 6912, 5120, 'bfloat16', True): (256, 6912, 512), + (6, 256, 7168, 8192, 'bfloat16', True): (256, 512, 8192), + (6, 256, 8192, 1024, 'bfloat16', True): (256, 2048, 1024), + (6, 256, 8192, 3584, 'bfloat16', True): (256, 8192, 512), + (6, 256, 896, 5120, 'bfloat16', True): (256, 896, 2560), (6, 32, 1280, 8192, 'bfloat16', True): (128, 256, 8192), - (6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), + (6, 32, 13824, 5120, 'bfloat16', True): (128, 512, 5120), + (6, 32, 1792, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896), + (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 32, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), + (6, 32, 5120, 3456, 'bfloat16', True): (128, 640, 3456), + (6, 32, 5120, 640, 'bfloat16', True): (128, 2560, 640), + (6, 32, 5120, 6912, 'bfloat16', True): (128, 1280, 2304), + (6, 32, 6144, 4096, 'bfloat16', True): (128, 768, 4096), + (6, 32, 6912, 5120, 'bfloat16', True): (128, 2304, 1280), (6, 32, 7168, 8192, 'bfloat16', True): (128, 256, 8192), + (6, 32, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), (6, 32, 8192, 3584, 'bfloat16', True): (128, 1024, 3584), + (6, 32, 896, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 4096, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120), + (6, 4096, 1792, 5120, 'bfloat16', True): (512, 1792, 5120), + (6, 4096, 5120, 1280, 'bfloat16', True): (256, 5120, 1280), + (6, 4096, 5120, 3456, 'bfloat16', True): (4096, 512, 3456), + (6, 4096, 5120, 640, 'bfloat16', True): (256, 5120, 640), + (6, 4096, 5120, 6912, 'bfloat16', True): (256, 5120, 6912), + (6, 4096, 6912, 5120, 'bfloat16', True): (256, 6912, 5120), + (6, 4096, 896, 5120, 'bfloat16', True): (256, 896, 5120), + (6, 512, 1280, 8192, 'bfloat16', True): (512, 256, 8192), + (6, 512, 13824, 5120, 'bfloat16', True): (512, 13824, 512), + (6, 512, 1792, 5120, 'bfloat16', True): (512, 1792, 1280), + (6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096), + (6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336), + (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096), + (6, 512, 5120, 1280, 'bfloat16', True): (512, 2560, 1280), + (6, 512, 5120, 3456, 'bfloat16', True): (512, 1280, 3456), + (6, 512, 5120, 640, 'bfloat16', True): (512, 2560, 640), + (6, 512, 5120, 6912, 'bfloat16', True): (512, 512, 6912), + (6, 512, 6144, 4096, 'bfloat16', True): (512, 1024, 4096), + (6, 512, 6912, 5120, 'bfloat16', True): (512, 768, 5120), + (6, 512, 7168, 8192, 'bfloat16', True): (512, 512, 8192), + (6, 512, 8192, 1024, 'bfloat16', True): (512, 4096, 1024), + (6, 512, 8192, 3584, 'bfloat16', True): (512, 2048, 3584), + (6, 512, 896, 5120, 'bfloat16', True): (512, 896, 2560), (6, 64, 1280, 8192, 'bfloat16', True): (128, 256, 8192), + (6, 64, 13824, 5120, 'bfloat16', True): (128, 512, 5120), + (6, 64, 1792, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896), + (6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 64, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), + (6, 64, 5120, 3456, 'bfloat16', True): (128, 1024, 3456), + (6, 64, 5120, 640, 'bfloat16', True): (128, 2560, 640), + (6, 64, 5120, 6912, 'bfloat16', True): (128, 1280, 2304), + (6, 64, 6144, 4096, 'bfloat16', True): (128, 768, 4096), + (6, 64, 6912, 5120, 'bfloat16', True): (128, 2304, 1280), + (6, 64, 7168, 8192, 'bfloat16', True): (128, 256, 8192), (6, 64, 8192, 1024, 'bfloat16', True): (128, 2048, 1024), + (6, 64, 8192, 3584, 'bfloat16', True): (128, 1024, 3584), + (6, 64, 896, 5120, 'bfloat16', True): (128, 896, 2560), + (6, 8192, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120), + (6, 8192, 1792, 5120, 'bfloat16', True): (512, 1792, 5120), + (6, 8192, 5120, 1280, 'bfloat16', True): (256, 5120, 1280), + (6, 8192, 5120, 3456, 'bfloat16', True): (512, 5120, 3456), + (6, 8192, 5120, 640, 'bfloat16', True): (512, 5120, 640), + (6, 8192, 5120, 6912, 'bfloat16', True): (512, 5120, 6912), + (6, 8192, 6912, 5120, 'bfloat16', True): (512, 6912, 5120), + (6, 8192, 896, 5120, 'bfloat16', True): (512, 896, 5120), } From 1ed6b46c0ee8be13619500a95336daf1609f4d0b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 29 Jul 2025 12:03:31 -0300 Subject: [PATCH 016/133] Error Handling: propagate status for `ReleaseGilAndTransferData` and `XlaDataToTensors`. (#9431) --- test/cpp/test_xla_sharding.cpp | 2 +- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/tensor.cpp | 3 ++- torch_xla/csrc/tensor_util.cpp | 16 ++++++++++------ torch_xla/csrc/tensor_util.h | 4 ++-- torch_xla/csrc/xla_backend_impl.cpp | 4 +++- torch_xla/csrc/xla_graph_executor.cpp | 3 ++- 7 files changed, 21 insertions(+), 13 deletions(-) diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index 3d276f2dc263..b179c6e523cc 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -29,7 +29,7 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a, torch::lazy::BackendDataPtr b, at::ScalarType element_type) { std::vector tensors = - XlaDataToTensors({a, b}, {element_type, element_type}); + GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type})); return TensorCompare(tensors[0], tensors[1]); } } // namespace diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5b62d95efd57..8873fb434e0f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2712,7 +2712,7 @@ void InitXlaModuleBindings(py::module m) { } std::vector cpu_shards = - XlaDataToTensors(WrapXlaData(handles), element_types); + GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types)); // Populate the resulting vector of shards and device strings std::vector>> result; int shards_per_tensor = diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 1a1a7737ccfe..6459293a87ff 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -40,6 +40,7 @@ #include "torch_xla/csrc/runtime/pjrt_computation_client.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/xla_util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_graph_executor.h" @@ -512,7 +513,7 @@ at::Tensor XLATensor::ToTensor(bool detached) { // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR // XlaNode is available on the tensor. std::vector tensors = - XlaDataToTensors({GetXlaData()}, {dtype()}); + GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()})); tensor = std::move(tensors.front()); if (!detached) { SetTensorData(tensor); diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index e2cd3a025f59..26c669b1e4f8 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -896,7 +896,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape, return literal; } -std::vector ReleaseGilAndTransferData( +absl::StatusOr> ReleaseGilAndTransferData( absl::Span xla_data) { // HACK: This method may be called outside of python (mainly in C++ tests) or // when the GIL is already released, so we must check both cases here. If @@ -909,9 +909,12 @@ std::vector ReleaseGilAndTransferData( if (release_gil && Py_IsInitialized() && PyGILState_Check()) { save = PyEval_SaveThread(); } - std::vector literals = - GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(xla_data))); + + XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client, + runtime::GetComputationClient()); + XLA_ASSIGN_OR_RETURN(std::vector literals, + client->TransferFromDevice(UnwrapXlaData(xla_data))); + if (save) { PyEval_RestoreThread(save); } @@ -919,10 +922,11 @@ std::vector ReleaseGilAndTransferData( return literals; } -std::vector XlaDataToTensors( +absl::StatusOr> XlaDataToTensors( absl::Span xla_data, absl::Span dest_element_type) { - std::vector literals = ReleaseGilAndTransferData(xla_data); + XLA_ASSIGN_OR_RETURN(std::vector literals, + ReleaseGilAndTransferData(xla_data)); std::vector tensors(literals.size()); absl::BlockingCounter counter(literals.size()); for (size_t i = 0; i < tensors.size(); ++i) { diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index 0804d3e9f781..a0f6dea480f1 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -28,11 +28,11 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal, // Execution and data transfer are async in PJRT, so TransferFromDevice may // block until `DataPtr`s are ready. Release the GIL so other threads can // proceed and unblock any transfers or collective computations. -std::vector ReleaseGilAndTransferData( +absl::StatusOr> ReleaseGilAndTransferData( absl::Span xla_data); // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice -std::vector XlaDataToTensors( +absl::StatusOr> XlaDataToTensors( absl::Span xla_data, absl::Span dest_element_type); diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp index bf130e1fab73..df52770b11ef 100644 --- a/torch_xla/csrc/xla_backend_impl.cpp +++ b/torch_xla/csrc/xla_backend_impl.cpp @@ -10,6 +10,8 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" +#include "torch_xla/csrc/tensor_util.h" namespace at { // This function is defined in the codegenerated RegisterDispatchKey.cpp file. @@ -92,7 +94,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { const torch::lazy::BackendDataPtr data, std::optional logical_scalar_type) const override { // TODO(JackCaoG): handle the logical_scalar_type == nullptr case - return XlaDataToTensors({data}, {*logical_scalar_type})[0]; + return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0]; } std::unique_ptr CreateLoweringContext( diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 65eee78bc023..0931578047e7 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -497,7 +497,8 @@ std::vector XLAGraphExecutor::GetTensors( async != nullptr ? async->tensors_data : absl::Span()); - std::vector literals = ReleaseGilAndTransferData(tensors_data); + std::vector literals = + GetValueOrThrow(ReleaseGilAndTransferData(tensors_data)); return FetchTensors(tensors, literals, async != nullptr ? &async->indices : nullptr); From b0ffc4917f25918a79afe3cc5b4d6afee32ea608 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 29 Jul 2025 13:24:12 -0300 Subject: [PATCH 017/133] Error Handling: refactor `ExecuteComputation` and `ExecuteReplicated` to propagate status. (#9445) --- test/cpp/cpp_test_util.cpp | 8 +++-- test/cpp/test_replication.cpp | 4 +-- torch_xla/csrc/runtime/computation_client.h | 5 +-- .../csrc/runtime/ifrt_computation_client.cpp | 22 ++++++------ .../csrc/runtime/ifrt_computation_client.h | 4 +-- .../runtime/ifrt_computation_client_test.cpp | 7 ++-- .../csrc/runtime/pjrt_computation_client.cpp | 36 ++++++++++--------- .../csrc/runtime/pjrt_computation_client.h | 4 +-- .../runtime/pjrt_computation_client_test.cpp | 8 +++-- torch_xla/csrc/xla_backend_impl.cpp | 4 +-- torch_xla/csrc/xla_graph_executor.cpp | 34 ++++++++++-------- 11 files changed, 74 insertions(+), 62 deletions(-) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index afe573101ebc..6731f5800fd5 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -295,9 +295,11 @@ std::vector Execute( std::move(instances)); torch_xla::runtime::ComputationClient::ExecuteComputationOptions options; - return torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( - *computations.front(), UnwrapXlaData(lowering_ctx.GetParametersData()), - device.toString(), options); + return GetValueOrThrow( + torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( + *computations.front(), + UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(), + options)); } std::vector Fetch( diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index b565dc44cd08..386f9db3a9a8 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -65,13 +65,13 @@ void TestSingleReplication( torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options; for (size_t i = 0; i < device_strings.size(); ++i) { auto executor = [&, i]() { - results[i] = + results[i] = GetValueOrThrow( torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( *compiled_computations[i], {std::dynamic_pointer_cast< torch_xla::runtime::ComputationClient::Data>( tensors_data[i])}, - device_strings[i], exec_options); + device_strings[i], exec_options)); counter.DecrementCount(); }; torch_xla::thread::Schedule(std::move(executor)); diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index c2f9389a4a0a..c5b550fb6846 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -16,6 +16,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/device.h" @@ -346,7 +347,7 @@ class ComputationClient { // The passed device must match the common device of the arguments Data. // If options.explode_tuple is true, the output tuple will be decomposed into // its single elements. - virtual std::vector ExecuteComputation( + virtual absl::StatusOr> ExecuteComputation( const Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options = @@ -357,7 +358,7 @@ class ComputationClient { // as `devices`. If options.explode_tuple is true, the output tuples will be // decomposed into their single elements. Returns a vector of outputs, each // of which is sharded in the same order as `devices`. - virtual std::vector ExecuteReplicated( + virtual absl::StatusOr> ExecuteReplicated( const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index f5a6af1b267c..5538cb4a5e22 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -4,6 +4,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" @@ -416,8 +417,8 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}}, - GetLocalDevices(), execute_options); + auto sharded_results = GetValueOrThrow(ExecuteReplicated( + *computations.front(), {{handle}}, GetLocalDevices(), execute_options)); auto replicated_output = std::dynamic_pointer_cast(sharded_results[0]) ->buffer->FullyReplicatedShard( @@ -537,16 +538,16 @@ std::vector IfrtComputationClient::Compile( return computations; } -std::vector +absl::StatusOr> IfrtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { // TODO: Implement sharded exec in IFRT - XLA_ERROR() << __FUNCTION__ << " not implemented"; + return absl::UnimplementedError("ExecuteComputation not implemented"); } -std::vector +absl::StatusOr> IfrtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, const absl::Span arguments, @@ -591,11 +592,10 @@ IfrtComputationClient::ExecuteReplicated( TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for " << spmd_device_str << " Done"; - xla::ifrt::LoadedExecutable::ExecuteResult result = - ifrt_computation.executable - ->Execute(absl::MakeSpan(argument_handles), execute_options, - std::nullopt) - .value(); + XLA_ASSIGN_OR_RETURN( + xla::ifrt::LoadedExecutable::ExecuteResult result, + ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles), + execute_options, std::nullopt)); result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)]( absl::Status status) mutable { @@ -612,7 +612,7 @@ IfrtComputationClient::ExecuteReplicated( ? *ifrt_computation.output_shardings_ : std::vector(outputs.size(), xla::HloSharding::Replicate().ToProto()); - XLA_CHECK_EQ(output_shardings.size(), outputs.size()); + ABSL_CHECK_EQ(output_shardings.size(), outputs.size()); std::vector data_handles(outputs.size()); { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 46b6343dc10a..ab24d1ae357b 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -78,12 +78,12 @@ class IfrtComputationClient : public ComputationClient { std::vector Compile( std::vector instances) override; - std::vector ExecuteComputation( + absl::StatusOr> ExecuteComputation( const Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) override; - std::vector ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( const Computation& computation, const absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp index eb39f9b2e23f..d48b4337d21c 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp @@ -64,9 +64,10 @@ TEST(PjRtComputationClientTest, Init) { std::make_shared(std::move(literal_y), device)}; // Execute the graph. - std::vector results = client->ExecuteReplicated( - *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), - {device}, options); + std::vector results = + GetValueOrThrow(client->ExecuteReplicated( + *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), + {device}, options)); // Copy the output from device back to host and assert correctness.. ASSERT_EQ(results.size(), 1); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index dd4950d87f5e..d57dbf9be6ce 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -387,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; auto sharded_results = - ExecuteReplicated(*computations.front(), {sharded_data}, - GetLocalDevices(), execute_options); + GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data}, + GetLocalDevices(), execute_options)); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; XLA_CHECK(sharded_results.size() == 1) @@ -474,8 +474,8 @@ std::vector PjRtComputationClient::ReshardData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto resharded_results = ExecuteReplicated( - *computation, handles, GetLocalDevices(), execute_options); + auto resharded_results = GetValueOrThrow(ExecuteReplicated( + *computation, handles, GetLocalDevices(), execute_options)); return resharded_results; } @@ -722,7 +722,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() { return comp_env_hash_; } -std::vector +absl::StatusOr> PjRtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, @@ -742,14 +742,14 @@ PjRtComputationClient::ExecuteComputation( dynamic_cast(computation); xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device); - XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); std::vector buffers; buffers.reserve(arguments.size()); for (auto& argument : arguments) { const PjRtData* pjrt_data = dynamic_cast(argument.get()); - XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) + ABSL_CHECK(pjrt_device == pjrt_data->buffer->device()) << "The device currently being used : " << pjrt_device->DebugString() << " is different from the device where the buffer resides: " << pjrt_data->buffer->device()->DebugString(); @@ -769,8 +769,9 @@ PjRtComputationClient::ExecuteComputation( << " Done"; std::optional> returned_future; - std::vector> results = - GetValueOrThrow(pjrt_computation.executable->ExecuteSharded( + XLA_ASSIGN_OR_RETURN( + std::vector> results, + pjrt_computation.executable->ExecuteSharded( buffers, pjrt_device, execute_options, returned_future)); returned_future->OnReady(std::move( @@ -795,7 +796,7 @@ PjRtComputationClient::ExecuteComputation( return datas; } -std::vector +absl::StatusOr> PjRtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, absl::Span arguments, @@ -829,15 +830,15 @@ PjRtComputationClient::ExecuteReplicated( for (int32_t i = start; i < end; ++i) { auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); - XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) + ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size()) << "Expected one shard per device"; for (int32_t d = 0; d < devices.size(); d++) { std::shared_ptr shard = pjrt_data->shards[d]; xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); - XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); - XLA_CHECK(pjrt_device->IsAddressable()) + ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device); + ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); argument_handles[d][i] = shard->buffer.get(); @@ -873,8 +874,9 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( "PjRtComputationClient::ExecuteReplicated_execute", tsl::profiler::TraceMeLevel::kInfo); - results = GetValueOrThrow(pjrt_computation.executable->Execute( - std::move(argument_handles), execute_options, returned_futures)); + XLA_ASSIGN_OR_RETURN(results, pjrt_computation.executable->Execute( + std::move(argument_handles), + execute_options, returned_futures)); (*returned_futures)[0].OnReady( std::move([timed, op_tracker = std::move(op_tracker)]( @@ -897,7 +899,7 @@ PjRtComputationClient::ExecuteReplicated( const std::vector& output_shapes = result_shape.IsTuple() ? result_shape.tuple_shapes() : std::vector({result_shape}); - XLA_CHECK_EQ(output_shapes.size(), num_outputs); + ABSL_CHECK_EQ(output_shapes.size(), num_outputs); const std::vector& output_shardings = pjrt_computation.output_shardings_.has_value() && num_outputs > 0 @@ -906,7 +908,7 @@ PjRtComputationClient::ExecuteReplicated( // Without an explicit sharding annotation, the output is implicitly // replicated, and we mark explicitly replicated here. std::vector(num_outputs); - XLA_CHECK_EQ(output_shardings.size(), num_outputs); + ABSL_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 3a6b4478f722..9a93d2864f4e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -85,12 +85,12 @@ class PjRtComputationClient : public ComputationClient { ComputationPtr DeserializeComputation(const std::string& serialized) override; - std::vector ExecuteComputation( + absl::StatusOr> ExecuteComputation( const Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) override; - std::vector ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp index 0fe2b2a70fcb..64496312ae4d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp @@ -114,9 +114,11 @@ TEST_F(PjRtComputationClientTest, Init) { std::make_shared(std::move(literal_y), device_)}; // Execute the graph. - std::vector results = client_->ExecuteComputation( - *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)), - device_, options); + std::vector results = + GetValueOrThrow(client_->ExecuteComputation( + *computations[0], + client_->TransferToDevice(absl::MakeConstSpan(args)), device_, + options)); // Copy the output from device back to host and assert correctness. ASSERT_EQ(results.size(), 1); diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp index df52770b11ef..78f8548ff178 100644 --- a/torch_xla/csrc/xla_backend_impl.cpp +++ b/torch_xla/csrc/xla_backend_impl.cpp @@ -163,11 +163,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const torch::lazy::BackendDevice& device) const override { - std::vector results = + std::vector results = GetValueOrThrow( runtime::GetComputationClientOrDie()->ExecuteComputation( *std::dynamic_pointer_cast( computation), - UnwrapXlaData(arguments), device.toString()); + UnwrapXlaData(arguments), device.toString())); return WrapXlaData(results); } diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 0931578047e7..8ea25adcf034 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -845,10 +845,11 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options); + GetValueOrThrow( + runtime::GetComputationClientOrDie()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options)); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -943,8 +944,8 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( } std::vector result_data = - runtime::GetComputationClientOrDie()->ExecuteComputation( - *computations[0], UnwrapXlaData(arguments), device.toString()); + GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation( + *computations[0], UnwrapXlaData(arguments), device.toString())); return WrapXlaData(result_data); } @@ -1120,10 +1121,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options); + GetValueOrThrow( + runtime::GetComputationClientOrDie()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options)); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteReplicated", 1); TF_VLOG(3) << "Executing IR graph hash " @@ -1135,11 +1137,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( << torch::lazy::HashToString(hash) << " on device " << async->device << " ..."; std::vector outputs = - runtime::GetComputationClientOrDie()->ExecuteComputation( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), async->device.toString(), - {/*explode_tuple=*/true, - /*eager_mode=*/use_eager_mode}); + GetValueOrThrow( + runtime::GetComputationClientOrDie()->ExecuteComputation( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), + async->device.toString(), + {/*explode_tuple=*/true, + /*eager_mode=*/use_eager_mode})); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteComputation", 1); TF_VLOG(3) << "Executing IR graph hash " From cd3bd91f1b959c27047196855649a6a933023428 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 29 Jul 2025 14:46:53 -0300 Subject: [PATCH 018/133] Error Handling: refactor `GetXlaTensor` and related functions to use status types. (#9510) --- test/cpp/cpp_test_util.cpp | 6 +- test/cpp/test_aten_xla_tensor_1.cpp | 2 +- torch_xla/csrc/aten_autograd_ops.cpp | 22 +- torch_xla/csrc/aten_fallback.cpp | 2 +- torch_xla/csrc/aten_xla_bridge.cpp | 125 +-- torch_xla/csrc/aten_xla_bridge.h | 69 +- torch_xla/csrc/aten_xla_type.cpp | 969 ++++++++++++--------- torch_xla/csrc/cross_replica_reduces.cpp | 12 +- torch_xla/csrc/init_python_bindings.cpp | 264 +++--- torch_xla/csrc/tensor_methods.cpp | 6 +- torch_xla/csrc/tensor_util.cpp | 2 +- torch_xla/csrc/xla_graph_executor.cpp | 15 +- torch_xla/csrc/xla_manual_registration.cpp | 4 +- torch_xla/csrc/xla_sharding_util.cpp | 2 +- 14 files changed, 848 insertions(+), 652 deletions(-) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 6731f5800fd5..4ca8e4981a91 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -246,17 +246,17 @@ void WithAllDevices( } std::string GetTensorTextGraph(at::Tensor tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return DumpUtil::ToText({xtensor->GetIrValue().node.get()}); } std::string GetTensorDotGraph(at::Tensor tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return DumpUtil::ToDot({xtensor->GetIrValue().node.get()}); } std::string GetTensorHloGraph(at::Tensor tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return DumpUtil::ToHlo({xtensor->GetIrValue()}, xtensor->GetDevice()); } diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d204b344808b..694c45945639 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -27,7 +27,7 @@ TEST_F(AtenXlaTensorTest, TestStorage) { torch::Tensor a = torch::tensor({0.0}); ForEachDevice([&](const torch::Device& device) { torch::Tensor xla_a = CopyToDevice(a, device); - XLATensorPtr xla_tensor_a = bridge::GetXlaTensor(xla_a); + XLATensorPtr xla_tensor_a = GetValueOrThrow(bridge::GetXlaTensor(xla_a)); EXPECT_EQ(xla_a.device(), xla_tensor_a->Storage().device()); AllClose(a, xla_a); }); diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index 081c713fd0fa..c8fe95d536db 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -8,6 +8,7 @@ #include "torch_xla/csrc/aten_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/torch_util.h" @@ -33,7 +34,8 @@ torch::Tensor EinsumAutogradFunction::forward( } ctx->save_for_backward(vars); - std::vector xla_tensors = bridge::GetXlaTensors(tensors); + std::vector xla_tensors = + GetValueOrThrow(bridge::GetXlaTensors(tensors)); XLATensorPtr output = tensor_methods::einsum(eq_str, xla_tensors); return bridge::AtenFromXlaTensor(output); } @@ -43,11 +45,13 @@ torch::autograd::variable_list EinsumAutogradFunction::backward( torch::autograd::variable_list grad_output) { std::string equation = ctx->saved_data["equation"].toString()->string(); torch::autograd::variable_list tensors = ctx->get_saved_variables(); - std::vector xla_tensors = bridge::GetXlaTensors(tensors); + std::vector xla_tensors = + GetValueOrThrow(bridge::GetXlaTensors(tensors)); std::tuple outputs = - tensor_methods::einsum_backward(bridge::GetXlaTensor(grad_output[0]), - xla_tensors, equation); + tensor_methods::einsum_backward( + GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), xla_tensors, + equation); // For both einsum and max pool, we use "undef" as a placeholder for the // non-tensor grad inputs, in this case the equation string. @@ -190,7 +194,7 @@ torch::Tensor MaxPool3dAutogradFunction::forward( } ctx->save_for_backward({self}); auto outputs = tensor_methods::max_pool_nd( - bridge::GetXlaTensor(self), /*spatial_dim_count=*/3, + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return bridge::AtenFromXlaTensor(std::get<0>(outputs)); @@ -218,7 +222,8 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward( ceil_mode, indices); } grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - bridge::GetXlaTensor(grad_output[0]), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); @@ -234,7 +239,7 @@ torch::Tensor max_pool2d_forward(torch::Tensor self, torch::IntArrayRef padding, torch::IntArrayRef dilation, bool ceil_mode) { auto outputs = tensor_methods::max_pool_nd( - bridge::GetXlaTensor(self), /*spatial_dim_count=*/2, + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return bridge::AtenFromXlaTensor(std::get<0>(outputs)); @@ -245,7 +250,8 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, torch::IntArrayRef stride, torch::IntArrayRef padding, bool ceil_mode) { auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); return grad; diff --git a/torch_xla/csrc/aten_fallback.cpp b/torch_xla/csrc/aten_fallback.cpp index 26e33b3d1b2e..45f1c64980a9 100644 --- a/torch_xla/csrc/aten_fallback.cpp +++ b/torch_xla/csrc/aten_fallback.cpp @@ -137,7 +137,7 @@ static bool validate_tensor_list(const c10::List& tensorlist) { // Retrieve the inner XLATensorPtr, and check it lives inside CUDA. static XLATensorPtr get_xla_cuda_tensor(const at::Tensor& tensor) { - XLATensorPtr xla_tensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); const torch::lazy::BackendDevice& device = xla_tensor->GetDevice(); TORCH_CHECK(device.type() == static_cast(XlaDeviceType::CUDA), "OpenXLA CUDA fallback only supports XLA:CUDA tensors. Found a " diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 8b0516da2217..05d92101383a 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -7,10 +7,12 @@ #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_impl.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_graph_executor.h" @@ -72,72 +74,68 @@ AtenXlaDeviceMapper* AtenXlaDeviceMapper::Get() { return device_mapper; } -XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) { +static absl::StatusOr GetXlaTensorImpl( + const at::Tensor& tensor) { auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); - return dynamic_cast(inner_tensor.unsafeGetTensorImpl()); + XLATensorImpl* impl = + dynamic_cast(inner_tensor.unsafeGetTensorImpl()); + if (impl == nullptr) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "Input tensor is not an XLA tensor: ", tensor.toString()))); + } + return impl; } } // namespace XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) { + return GetXlaTensor(tensor).value_or(XLATensorPtr{}); +} + +absl::StatusOr GetXlaTensor( + const at::Tensor& tensor) { if (tensor.defined() && at::functionalization::impl::isFunctionalTensor(tensor)) { // To make sure we have the most updated version of tensor. at::functionalization::impl::sync(tensor); } - XLATensorImpl* impl = GetXlaTensorImpl(tensor); - if (impl == nullptr) { - return XLATensorPtr(); - } + XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor)); return impl->tensor(); } -std::vector TryGetXlaTensors(const at::ITensorListRef& tensors) { - std::vector xla_tensors; +absl::StatusOr> GetXlaTensors( + const at::ITensorListRef& tensors) { + std::vector xla_tensors; xla_tensors.reserve(tensors.size()); for (const auto& tensor : tensors) { - xla_tensors.push_back(bridge::TryGetXlaTensor(tensor)); + XLA_ASSIGN_OR_RETURN(XLATensorPtr ptr, bridge::GetXlaTensor(tensor)); + xla_tensors.push_back(std::move(ptr)); } return xla_tensors; } bool IsXlaTensor(const at::Tensor& tensor) { - return GetXlaTensorImpl(tensor) != nullptr; -} - -XLATensorPtr GetXlaTensor(const at::Tensor& tensor) { - auto xtensor = TryGetXlaTensor(tensor); - XLA_CHECK(xtensor) << "Input tensor is not an XLA tensor: " - << tensor.toString(); - return xtensor; + return GetXlaTensorImpl(tensor).ok(); } -void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) { - auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); - XLATensorImpl* impl = - dynamic_cast(inner_tensor.unsafeGetTensorImpl()); - XLA_CHECK(impl != nullptr) - << "Input tensor is not an XLA tensor: " << inner_tensor.toString(); +absl::Status ReplaceXlaTensor(const at::Tensor& tensor, + XLATensorPtr new_xla_tensor) { + XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor)); impl->set_tensor(std::move(new_xla_tensor)); + return absl::OkStatus(); } -void ReplaceXlaTensor(const std::vector& tensors, - const std::vector new_xla_tensors) { - XLA_CHECK(tensors.size() == new_xla_tensors.size()) - << "The size of tensors and new_xla_tensors are not equal: " - << tensors.size() << " vs. " << new_xla_tensors.size(); - for (size_t i = 0; i < tensors.size(); ++i) { - ReplaceXlaTensor(tensors[i], new_xla_tensors[i]); +absl::Status ReplaceXlaTensor(const std::vector& tensors, + const std::vector new_xla_tensors) { + if (tensors.size() != new_xla_tensors.size()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("The size of tensors and new_xla_tensors are not equal: ", + tensors.size(), " vs. ", new_xla_tensors.size()))); } -} - -std::vector GetXlaTensors(const at::ITensorListRef& tensors) { - std::vector xla_tensors; - xla_tensors.reserve(tensors.size()); - for (const auto& tensor : tensors) { - xla_tensors.push_back(bridge::GetXlaTensor(tensor)); + for (size_t i = 0; i < tensors.size(); ++i) { + XLA_RETURN_IF_ERROR(ReplaceXlaTensor(tensors[i], new_xla_tensors[i])); } - return xla_tensors; + return absl::OkStatus(); } torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber( @@ -146,7 +144,7 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber( (tensor.dim() == 0 && tensor.numel() == 1)) { return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device); } else { - return torch_xla::bridge::GetXlaTensor(tensor); + return GetValueOrThrow(torch_xla::bridge::GetXlaTensor(tensor)); } } @@ -155,22 +153,23 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor, if (!tensor.defined()) { return XLATensorPtr(); } + auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor); if (!inner_tensor.defined()) { return XLATensorPtr(); } - auto xtensor = TryGetXlaTensor(tensor); - return xtensor ? xtensor : XLATensor::Create(inner_tensor, device); + + auto xtensor = GetXlaTensor(tensor); + return xtensor.ok() ? xtensor.value() + : XLATensor::Create(inner_tensor, device); } XLATensorPtr GetOrCreateXlaTensor(const std::optional& tensor, const torch::lazy::BackendDevice& device) { - if (!IsDefined(tensor)) { + if (!tensor.has_value()) { return XLATensorPtr(); } - auto xtensor = TryGetXlaTensor(*tensor); - auto inner_tensor = torch::lazy::maybe_unwrap_functional(*tensor); - return xtensor ? xtensor : XLATensor::Create(inner_tensor, device); + return GetOrCreateXlaTensor(*tensor, device); } std::vector GetOrCreateXlaTensors( @@ -199,10 +198,10 @@ std::vector XlaCreateTensorList(const at::ITensorListRef& tensors) { continue; } - auto xtensor = TryGetXlaTensor(tensor); - if (xtensor) { + auto xtensor_status = GetXlaTensor(tensor); + if (xtensor_status.ok()) { to_translate[ix] = true; - xla_tensors.push_back(xtensor); + xla_tensors.push_back(xtensor_status.value()); } else { aten_xla_tensors[ix] = tensor; } @@ -253,13 +252,14 @@ void XlaUpdateTensors(absl::Span dest_xla_tensors, for (auto index : indices) { at::Tensor dest = dest_xla_tensors.at(index); at::Tensor source = source_cpu_tensors.at(index); - XLATensorImpl* dest_impl = GetXlaTensorImpl(dest); - if (dest_impl != nullptr) { - auto xla_source = TryGetXlaTensor(source); - if (!xla_source) { - dest_impl->tensor()->UpdateFromTensorOut(source); + auto dest_impl_status = GetXlaTensorImpl(dest); + if (dest_impl_status.ok()) { + auto dest_impl = std::move(dest_impl_status).value(); + auto xla_source_status = GetXlaTensor(source); + if (xla_source_status.ok()) { + dest_impl->tensor()->UpdateFromTensorOut(xla_source_status.value()); } else { - dest_impl->tensor()->UpdateFromTensorOut(xla_source); + dest_impl->tensor()->UpdateFromTensorOut(source); } dest_impl->force_refresh_sizes(); } else { @@ -270,11 +270,11 @@ void XlaUpdateTensors(absl::Span dest_xla_tensors, std::optional GetXlaDevice( const at::Tensor& tensor) { - auto xtensor = TryGetXlaTensor(tensor); - if (!xtensor) { + auto xtensor_status = GetXlaTensor(tensor); + if (!xtensor_status.ok()) { return std::nullopt; } - return xtensor->GetDevice(); + return xtensor_status.value()->GetDevice(); } std::optional GetXlaDevice( @@ -469,12 +469,15 @@ std::vector CreateXlaTensors( } const at::Tensor& GetRootBase(const at::Tensor& tensor) { - auto xla_tensor = TryGetXlaTensor(tensor); - if (xla_tensor && xla_tensor->Base().defined()) { - return GetRootBase(xla_tensor->Base()); - } else { + auto xla_tensor_status = GetXlaTensor(tensor); + if (!xla_tensor_status.ok()) { + return tensor; + } + auto xla_tensor = std::move(xla_tensor_status).value(); + if (!xla_tensor->Base().defined()) { return tensor; } + return GetRootBase(xla_tensor->Base()); } XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) { diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h index a862e3a72e25..d04873ec8ff3 100644 --- a/torch_xla/csrc/aten_xla_bridge.h +++ b/torch_xla/csrc/aten_xla_bridge.h @@ -7,6 +7,8 @@ #include +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/tensor.h" @@ -14,26 +16,65 @@ namespace torch_xla { namespace bridge { +// TODO(ysiraichi): remove this function once codegen does not need it. +// +// We still need this function because lazy codegen needs a function that +// returns a value of type `T`, which can be: +// +// 1. cast to boolean; and +// 2. accessed with "->" +// +// e.g. pointers and optional types +// +// A StatusOr type fulfills only (2), so we can't use it there. In order +// to do so, we have to change upstream accordingly. +// +ABSL_DEPRECATED( + "Use GetXlaTensor(), instead. " + "This function returns an null-initialized `XLATensorPtr`, instead of " + "propagating errors with StatusOr values.") XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor); +// Retrieves the underlying `XLATensorPtr` from `tensor`. +// +// This function does the following things in order to retrieve +// (if exists) the underlying `XLATensorPtr`: +// +// 1. Synchronizes the tensor, if it's a tensor wrapped in a functional tensor +// 2. Retrieves the inner `XLATensorImpl` instance +// 3. Finally, retrieves the `XLATensor` that lives inside `XLATensorImpl` +// +// An error might ocurr if, after unwrapping the wrapper functional tensor +// (if exists), the `TensorImpl` of the unwrapped tensor is not a +// `XLATensorImpl`. This might happen if: +// +// 1. `tensor` lives in another device +// 2. `tensor` wasn't created within this project +// (e.g. meta tensors whose device is XLA) +// +absl::StatusOr GetXlaTensor( + const at::Tensor& tensor); + // Same as above, applied to a list of tensors. -std::vector TryGetXlaTensors(const at::ITensorListRef& tensors); +absl::StatusOr> GetXlaTensors( + const at::ITensorListRef& tensors); bool IsXlaTensor(const at::Tensor& tensor); -// Extracts the XLATensorPtr out of our version of at::Tensor. Throws an -// exception if tensor is not an XLA tensor. -XLATensorPtr GetXlaTensor(const at::Tensor& tensor); - -// Replaces the XLA tensor embedded within the XLA TensorImpl with the new -// version. -void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor); - -void ReplaceXlaTensor(const std::vector& tensor, - const std::vector new_xla_tensor); - -// Same as above, applied to a list of tensors. -std::vector GetXlaTensors(const at::ITensorListRef& tensors); +// Replaces the XLA tensor embedded within `tensor`'s XLA TensorImpl with +// `new_xla_tensor`. +// +// Fails if `tensor` is not an XLA tensor. +absl::Status ReplaceXlaTensor(const at::Tensor& tensor, + XLATensorPtr new_xla_tensor); + +// Replaces the XLA tensor embedded within the `tensors` XLA TensorImpl +// with `new_xla_tensors`. +// +// Fails if any of `tensors` is not an XLA tensor, or if the number of `tensors` +// does not match the number of `new_xla_tensors`. +absl::Status ReplaceXlaTensor(const std::vector& tensors, + const std::vector new_xla_tensors); torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber( const at::Tensor& tensor, const torch::lazy::BackendDevice& device); diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 3cb1f6c51b95..64354c893a13 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -17,6 +17,7 @@ #include #include +#include "absl/log/absl_check.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/shape_inference.h" #include "torch/csrc/lazy/core/tensor_util.h" @@ -44,6 +45,7 @@ #include "torch_xla/csrc/runtime/metrics.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_impl.h" #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" @@ -158,13 +160,14 @@ class OpConfig { // Look for, at least, one tensor already in PyTorch/XLA. InputVector::iterator it = std::find_if( inputs.begin(), inputs.end(), [](const at::Tensor& tensor) { - return bridge::TryGetXlaTensor(tensor); + return bridge::GetXlaTensor(tensor).ok(); }); XLA_CHECK(it != inputs.end()); // Transform the inputs into a list of XLATensorPtr. // For that, either get their corresponding XLATensorPtr, or use the found // XLA tensor's BackendDevice for creating a new one. - torch::lazy::BackendDevice device = bridge::GetXlaTensor(*it)->GetDevice(); + torch::lazy::BackendDevice device = + GetValueOrThrow(bridge::GetXlaTensor(*it))->GetDevice(); XLAInputVector xla_inputs(inputs.size()); std::transform(inputs.begin(), inputs.end(), xla_inputs.begin(), [&](const at::Tensor& tensor) { @@ -332,12 +335,12 @@ std::pair GetBinaryOperands( const at::Tensor& self, const at::Tensor& other) { XLATensorPtr self_tensor; XLATensorPtr other_tensor; - auto self_xtensor = bridge::TryGetXlaTensor(self); - if (!self_xtensor) { - other_tensor = bridge::GetXlaTensor(other); + auto self_xtensor_status = bridge::GetXlaTensor(self); + if (!self_xtensor_status.ok()) { + other_tensor = GetValueOrThrow(bridge::GetXlaTensor(other)); self_tensor = bridge::GetOrCreateXlaTensor(self, other_tensor->GetDevice()); } else { - self_tensor = self_xtensor; + self_tensor = std::move(self_xtensor_status).value(); other_tensor = bridge::GetOrCreateXlaTensor(other, self_tensor->GetDevice()); } @@ -384,7 +387,7 @@ template at::Tensor DoBinaryOp(const at::Tensor& self, const at::Scalar& other, const B& bin_op) { at::ScalarType dtype = at::result_type(self, other); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLATensorPtr result = bin_op(self_tensor, other, dtype); return bridge::AtenFromXlaTensor(result); } @@ -393,7 +396,7 @@ template at::Tensor DoBinaryOp(const at::Scalar& self, const at::Tensor& other, const B& bin_op) { at::ScalarType dtype = at::result_type(self, other); - XLATensorPtr other_tensor = bridge::GetXlaTensor(other); + XLATensorPtr other_tensor = GetValueOrThrow(bridge::GetXlaTensor(other)); XLATensorPtr result = bin_op(self, other_tensor, dtype); return bridge::AtenFromXlaTensor(result); } @@ -411,7 +414,7 @@ at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self, template at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self, const at::Scalar& other, const B& bin_op) { - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLATensorPtr result = bin_op(self_tensor, other); return bridge::AtenFromXlaTensor(result); } @@ -423,7 +426,7 @@ void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other, XLA_CHECK(at::canCast(/*from=*/dtype, /*to=*/out.scalar_type())); std::pair operands = GetBinaryOperands(self, UnwrapNumber(other, dtype)); - XLATensorPtr out_tensor = bridge::GetXlaTensor(out); + XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); bin_op_out(operands.first, operands.second, out_tensor); } @@ -432,7 +435,7 @@ void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other, at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::__ilshift__(self_tensor, other); return self; } @@ -441,8 +444,9 @@ at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::__ilshift__(self_tensor, bridge::GetXlaTensor(other)); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + tensor_methods::__ilshift__(self_tensor, + GetValueOrThrow(bridge::GetXlaTensor(other))); return self; } @@ -450,7 +454,7 @@ at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::__irshift__(self_tensor, other); return self; } @@ -459,8 +463,9 @@ at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::__irshift__(self_tensor, bridge::GetXlaTensor(other)); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + tensor_methods::__irshift__(self_tensor, + GetValueOrThrow(bridge::GetXlaTensor(other))); return self; } @@ -516,7 +521,7 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d( auto common_device = torch_xla::bridge::GetXlaDevice(self); XLA_CHECK(common_device); torch::lazy::NodePtr node = torch_xla::MakeNode( - bridge::GetXlaTensor(self)->GetIrValue(), + GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue(), std::vector(output_size.begin(), output_size.end())); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -538,8 +543,8 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward( auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self); XLA_CHECK(common_device); torch::lazy::NodePtr node = torch_xla::MakeNode( - bridge::GetXlaTensor(grad_output)->GetIrValue(), - bridge::GetXlaTensor(self)->GetIrValue()); + GetValueOrThrow(bridge::GetXlaTensor(grad_output))->GetIrValue(), + GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue()); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -555,7 +560,7 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool2d( &xla_fallback, ATEN_OP(_adaptive_avg_pool2d)>::call(self, output_size); } return bridge::AtenFromXlaTensor(tensor_methods::_adaptive_avg_pool2d( - bridge::GetXlaTensor(self), output_size_list)); + GetValueOrThrow(bridge::GetXlaTensor(self)), output_size_list)); } at::Tensor XLANativeFunctions::_adaptive_avg_pool2d_backward( @@ -572,7 +577,8 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool2d_backward( } return bridge::AtenFromXlaTensor( tensor_methods::_adaptive_avg_pool2d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self))); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)))); } std::tuple XLANativeFunctions::adaptive_max_pool2d( @@ -585,8 +591,8 @@ std::tuple XLANativeFunctions::adaptive_max_pool2d( &xla_fallback, ATEN_OP(adaptive_max_pool2d)>::call(self, output_size); } std::tuple res = - tensor_methods::adaptive_max_pool2d(bridge::GetXlaTensor(self), - output_size_list); + tensor_methods::adaptive_max_pool2d( + GetValueOrThrow(bridge::GetXlaTensor(self)), output_size_list); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)), bridge::AtenFromXlaTensor(std::get<1>(res))); } @@ -606,16 +612,18 @@ at::Tensor XLANativeFunctions::adaptive_max_pool2d_backward( indices); } return bridge::AtenFromXlaTensor(tensor_methods::adaptive_max_pool2d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self))); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)))); } void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr found_inf_tensor = bridge::GetXlaTensor(found_inf); + XLATensorPtr found_inf_tensor = + GetValueOrThrow(bridge::GetXlaTensor(found_inf)); tensor_methods::_amp_foreach_non_finite_check_and_unscale_( - bridge::GetXlaTensors(self), found_inf_tensor, - bridge::GetXlaTensor(inv_scale)); + GetValueOrThrow(bridge::GetXlaTensors(self)), found_inf_tensor, + GetValueOrThrow(bridge::GetXlaTensor(inv_scale))); } at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, @@ -625,11 +633,13 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, double scale_backoff_factor, int64_t growth_interval) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr growth_tracker_tensor = bridge::GetXlaTensor(growth_tracker); - XLATensorPtr current_scale_tensor = bridge::GetXlaTensor(current_scale); + XLATensorPtr growth_tracker_tensor = + GetValueOrThrow(bridge::GetXlaTensor(growth_tracker)); + XLATensorPtr current_scale_tensor = + GetValueOrThrow(bridge::GetXlaTensor(current_scale)); tensor_methods::_amp_update_scale_( growth_tracker_tensor, current_scale_tensor, - bridge::GetXlaTensor(found_inf), scale_growth_factor, + GetValueOrThrow(bridge::GetXlaTensor(found_inf)), scale_growth_factor, scale_backoff_factor, growth_interval); return current_scale; } @@ -638,22 +648,23 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, const at::Tensor& dst, bool /*non_blocking*/) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto dst_tensor = bridge::TryGetXlaTensor(dst); - auto self_tensor = bridge::TryGetXlaTensor(self); - if (!self_tensor) { + auto dst_tensor_status = bridge::GetXlaTensor(dst); + auto self_tensor_status = bridge::GetXlaTensor(self); + ABSL_CHECK(self_tensor_status.ok() || dst_tensor_status.ok()); + if (!self_tensor_status.ok()) { static bool sync_update = runtime::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true) && !UseVirtualDevice(); - dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); - XLA_CHECK(dst_tensor); - } else if (!dst_tensor) { - at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true); + dst_tensor_status.value()->UpdateFromTensor(self, /*sync=*/sync_update); + } else if (!dst_tensor_status.ok()) { + at::Tensor tensor = self_tensor_status.value()->ToTensor(/*detached=*/true); at::Tensor typed_tensor = torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false); dst.resize_as_(typed_tensor).copy_(typed_tensor); } else { - tensor_methods::copy_(dst_tensor, self_tensor); - bridge::ReplaceXlaTensor(dst, dst_tensor); + auto dst_tensor = std::move(dst_tensor_status).value(); + tensor_methods::copy_(dst_tensor, self_tensor_status.value()); + MaybeThrow(bridge::ReplaceXlaTensor(dst, dst_tensor)); } return dst; } @@ -661,13 +672,13 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, at::Tensor XLANativeFunctions::_copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto dst_tensor = bridge::TryGetXlaTensor(dst); - auto self_tensor = bridge::TryGetXlaTensor(self); - if (!self_tensor) { - XLA_CHECK(dst_tensor); - dst_tensor->UpdateFromTensorOut(self); - } else if (!dst_tensor) { - at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true); + auto dst_tensor_status = bridge::GetXlaTensor(dst); + auto self_tensor_status = bridge::GetXlaTensor(self); + ABSL_CHECK(self_tensor_status.ok() || dst_tensor_status.ok()); + if (!self_tensor_status.ok()) { + dst_tensor_status.value()->UpdateFromTensorOut(self); + } else if (!dst_tensor_status.ok()) { + at::Tensor tensor = self_tensor_status.value()->ToTensor(/*detached=*/true); at::Tensor typed_tensor = torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false); dst.resize_as_(typed_tensor).copy_(typed_tensor); @@ -675,7 +686,8 @@ at::Tensor XLANativeFunctions::_copy_from_and_resize(const at::Tensor& self, // at this point we know dst is an XLA tensor XLATensorImpl* dest_impl = dynamic_cast(dst.unsafeGetTensorImpl()); - dest_impl->tensor()->UpdateFromTensorOut(self_tensor); + dest_impl->tensor()->UpdateFromTensorOut( + std::move(self_tensor_status).value()); dest_impl->force_refresh_sizes(); } return dst; @@ -719,7 +731,7 @@ at::Tensor XLANativeFunctions::_to_copy( if (device && device->type() != c10::kXLA) { XLA_CHECK(device->type() == c10::kCPU) << "only cpu device is supported in _to_copy."; - auto self_tensor = bridge::GetXlaTensor(self); + auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto eager_tensor = self_tensor->ToTensor(/*detached=*/true); // Use the eager .to on the eager tensor. @@ -751,7 +763,7 @@ std::tuple XLANativeFunctions::_linalg_eigh( ATEN_OP(_linalg_eigh)>::call(self, uplo, compute_v); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto outputs = tensor_methods::eigh(self_tensor, uplo); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); @@ -760,7 +772,7 @@ std::tuple XLANativeFunctions::_linalg_eigh( std::tuple XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto outputs = tensor_methods::slogdet(self_tensor); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs)), @@ -776,8 +788,9 @@ at::Tensor XLANativeFunctions::_log_softmax(const at::Tensor& self, int64_t dim, std::vector shapes{ torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; - return bridge::AtenFromXlaTensor(tensor_methods::log_softmax( - bridge::GetXlaTensor(self), dim, std::nullopt, std::move(shapes))); + return bridge::AtenFromXlaTensor( + tensor_methods::log_softmax(GetValueOrThrow(bridge::GetXlaTensor(self)), + dim, std::nullopt, std::move(shapes))); } at::Tensor XLANativeFunctions::_log_softmax_backward_data( @@ -785,7 +798,8 @@ at::Tensor XLANativeFunctions::_log_softmax_backward_data( at::ScalarType /* input_dtype */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::log_softmax_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(output)), dim)); } std::tuple XLANativeFunctions::_pack_padded_sequence( @@ -799,8 +813,8 @@ std::tuple XLANativeFunctions::_pack_padded_sequence( at::Tensor XLANativeFunctions::_softmax(const at::Tensor& self, int64_t dim, bool /* half_to_float */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::softmax(bridge::GetXlaTensor(self), dim, std::nullopt)); + return bridge::AtenFromXlaTensor(tensor_methods::softmax( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, std::nullopt)); } at::Tensor XLANativeFunctions::_softmax_backward_data( @@ -808,7 +822,8 @@ at::Tensor XLANativeFunctions::_softmax_backward_data( at::ScalarType input_dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::softmax_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output), dim)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(output)), dim)); } at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, @@ -857,16 +872,16 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self, return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(addmm)>::call( self, mat1, mat2, beta, alpha); } - return bridge::AtenFromXlaTensor( - tensor_methods::addmm(bridge::GetXlaTensor(mat1), - /*weight=*/bridge::GetXlaTensor(mat2), - /*bias=*/bridge::GetXlaTensor(self))); + return bridge::AtenFromXlaTensor(tensor_methods::addmm( + GetValueOrThrow(bridge::GetXlaTensor(mat1)), + /*weight=*/GetValueOrThrow(bridge::GetXlaTensor(mat2)), + /*bias=*/GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::alias(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::alias(bridge::GetXlaTensor(self))); + tensor_methods::alias(GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) { @@ -879,7 +894,7 @@ at::Tensor& XLANativeFunctions::arange_out(const at::Scalar& start, const at::Scalar& step, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = bridge::GetXlaTensor(out); + XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); tensor_methods::arange_out(out_tensor, start, end, step, out.scalar_type()); return out; } @@ -940,7 +955,8 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path( } } return bridge::AtenFromXlaTensor(tensor_methods::squeeze( - tensor_methods::slice(bridge::GetXlaTensor(tensor), skip_dim, 0, 1, 1), + tensor_methods::slice(GetValueOrThrow(bridge::GetXlaTensor(tensor)), + skip_dim, 0, 1, 1), skip_dim)); } // now tensor_dim.size() == stride.size() @@ -972,9 +988,9 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path( // stride. K = 1; } - return bridge::AtenFromXlaTensor( - tensor_methods::slice(bridge::GetXlaTensor(tensor), reduce_size_location, - 0, size[reduce_size_location] * K, K)); + return bridge::AtenFromXlaTensor(tensor_methods::slice( + GetValueOrThrow(bridge::GetXlaTensor(tensor)), reduce_size_location, 0, + size[reduce_size_location] * K, K)); } at::Tensor XLANativeFunctions::as_strided_copy( @@ -987,7 +1003,7 @@ at::Tensor XLANativeFunctions::as_strided_copy( // Retrieve the base tensor, if there's one. // This function actually operates on the tensor's storage. Since XLA does not // expose the actual storage, we use the originally allocated tensor. - const at::Tensor& base = bridge::GetXlaTensor(self)->Base(); + const at::Tensor& base = GetValueOrThrow(bridge::GetXlaTensor(self))->Base(); at::Tensor tensor = base.defined() ? base : self; // Fast path: using slice to replace as_strided to avoid the index copy. @@ -1004,10 +1020,10 @@ at::Tensor XLANativeFunctions::as_strided_copy( // Even though this function copies (without aliasing) tensor, it's still // treated as a view function in the functionalization layer. return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( - tensor_methods::as_strided(bridge::GetXlaTensor(tensor), - XlaHelpers::I64List(size), - XlaHelpers::I64List(stride), - XlaHelpers::I64Optional(storage_offset)), + tensor_methods::as_strided( + GetValueOrThrow(bridge::GetXlaTensor(tensor)), + XlaHelpers::I64List(size), XlaHelpers::I64List(stride), + XlaHelpers::I64Optional(storage_offset)), tensor)); } @@ -1101,7 +1117,7 @@ at::Tensor XLANativeFunctions::as_strided_scatter( at::IntArrayRef size, at::IntArrayRef stride, std::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto base_ = bridge::GetXlaTensor(base); + auto base_ = GetValueOrThrow(bridge::GetXlaTensor(base)); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); if (!AsStrided::StrideIsSupported(base_->shape(), xsize, xstride, @@ -1111,7 +1127,7 @@ at::Tensor XLANativeFunctions::as_strided_scatter( size, stride, storage_offset); } - auto mutated_view_ = bridge::GetXlaTensor(mutated_view); + auto mutated_view_ = GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); return bridge::AtenFromXlaTensor( base_->CreateFrom(torch_xla::MakeNode( base_->GetIrValue(), mutated_view_->GetIrValue(), @@ -1124,9 +1140,9 @@ at::Tensor XLANativeFunctions::atan2(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto common_device = torch_xla::bridge::GetXlaDevice(self, other); XLA_CHECK(common_device); - torch::lazy::NodePtr node = - torch_xla::MakeNode(bridge::GetXlaTensor(self)->GetIrValue(), - bridge::GetXlaTensor(other)->GetIrValue()); + torch::lazy::NodePtr node = torch_xla::MakeNode( + GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue(), + GetValueOrThrow(bridge::GetXlaTensor(other))->GetIrValue()); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -1138,7 +1154,7 @@ at::Tensor XLANativeFunctions::avg_pool2d( std::optional divisor_override) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd( - bridge::GetXlaTensor(self), /*spatial_dim_count=*/2, + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, count_include_pad, divisor_override)); @@ -1159,7 +1175,8 @@ at::Tensor XLANativeFunctions::avg_pool2d_backward( divisor_override); } return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, count_include_pad)); @@ -1171,7 +1188,7 @@ at::Tensor XLANativeFunctions::avg_pool3d( std::optional divisor_override) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd( - bridge::GetXlaTensor(self), /*spatial_dim_count=*/3, + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, count_include_pad, divisor_override)); @@ -1192,7 +1209,8 @@ at::Tensor XLANativeFunctions::avg_pool3d_backward( divisor_override); } return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, count_include_pad)); @@ -1206,8 +1224,9 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::baddbmm( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(batch1), - bridge::GetXlaTensor(batch2), beta, alpha)); + GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(batch1)), + GetValueOrThrow(bridge::GetXlaTensor(batch2)), beta, alpha)); } at::Tensor XLANativeFunctions::bernoulli( @@ -1218,7 +1237,7 @@ at::Tensor XLANativeFunctions::bernoulli( ATEN_OP(bernoulli)>::call(self, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::bernoulli(self_tensor)); } @@ -1229,7 +1248,7 @@ at::Tensor XLANativeFunctions::bernoulli( return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(bernoulli, p)>::call(self, p, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::bernoulli(self_tensor, p)); } @@ -1241,8 +1260,9 @@ at::Tensor& XLANativeFunctions::bernoulli_( return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(bernoulli_, Tensor)>::call(self, p, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::bernoulli_(self_tensor, bridge::GetXlaTensor(p)); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + tensor_methods::bernoulli_(self_tensor, + GetValueOrThrow(bridge::GetXlaTensor(p))); return self; } @@ -1286,28 +1306,30 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self, at::Tensor XLANativeFunctions::bmm(const at::Tensor& self, const at::Tensor& mat2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::bmm( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(mat2))); + return bridge::AtenFromXlaTensor( + tensor_methods::bmm(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(mat2)))); } at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::cat( - bridge::GetXlaTensors(tensors), dim, at::native::result_type(tensors))); + return bridge::AtenFromXlaTensor( + tensor_methods::cat(GetValueOrThrow(bridge::GetXlaTensors(tensors)), dim, + at::native::result_type(tensors))); } at::Tensor XLANativeFunctions::celu(const at::Tensor& self, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::celu(bridge::GetXlaTensor(self), alpha)); + tensor_methods::celu(GetValueOrThrow(bridge::GetXlaTensor(self)), alpha)); } at::Tensor& XLANativeFunctions::celu_(at::Tensor& self, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::celu_(self_tensor, alpha); return self; } @@ -1316,29 +1338,29 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self, const std::optional& min, const std::optional& max) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), min, max)); + return bridge::AtenFromXlaTensor(tensor_methods::clamp( + GetValueOrThrow(bridge::GetXlaTensor(self)), min, max)); } at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self, const at::Scalar& max) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), std::nullopt, max)); + return bridge::AtenFromXlaTensor(tensor_methods::clamp( + GetValueOrThrow(bridge::GetXlaTensor(self)), std::nullopt, max)); } at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, const at::Scalar& min) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), min, std::nullopt)); + return bridge::AtenFromXlaTensor(tensor_methods::clamp( + GetValueOrThrow(bridge::GetXlaTensor(self)), min, std::nullopt)); } at::Tensor XLANativeFunctions::clone( const at::Tensor& self, std::optional /* memory_format */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto tensor = bridge::GetXlaTensor(self); + auto tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); if (self.is_conj()) { // Materialize the conjugate if necessary. tensor = tensor_methods::conj(tensor); @@ -1351,7 +1373,8 @@ at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::constant_pad_nd( - bridge::GetXlaTensor(self), XlaHelpers::I64List(pad), value)); + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(pad), + value)); } // This functions covers the whole convolution lowering. @@ -1363,13 +1386,16 @@ at::Tensor XLANativeFunctions::convolution_overrideable( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); if (IsDefined(bias)) { return bridge::AtenFromXlaTensor(tensor_methods::convolution_overrideable( - bridge::GetXlaTensor(input), bridge::GetXlaTensor(weight), - bridge::GetXlaTensor(*bias), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, + GetValueOrThrow(bridge::GetXlaTensor(input)), + GetValueOrThrow(bridge::GetXlaTensor(weight)), + GetValueOrThrow(bridge::GetXlaTensor(*bias)), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), + XlaHelpers::I64List(dilation), transposed, XlaHelpers::I64List(output_padding), groups)); } else { return bridge::AtenFromXlaTensor(tensor_methods::convolution_overrideable( - bridge::GetXlaTensor(input), bridge::GetXlaTensor(weight), + GetValueOrThrow(bridge::GetXlaTensor(input)), + GetValueOrThrow(bridge::GetXlaTensor(weight)), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, XlaHelpers::I64List(output_padding), groups)); @@ -1385,9 +1411,11 @@ XLANativeFunctions::convolution_backward_overrideable( int64_t groups, std::array output_mask) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto gradients = tensor_methods::convolution_backward_overrideable( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(input), - bridge::GetXlaTensor(weight), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(input)), + GetValueOrThrow(bridge::GetXlaTensor(weight)), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), + XlaHelpers::I64List(dilation), transposed, XlaHelpers::I64List(output_padding), groups); return std::make_tuple( output_mask[0] ? bridge::AtenFromXlaTensor(std::get<0>(gradients)) @@ -1415,15 +1443,16 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self, const at::Tensor& other, std::optional dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::cross( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(other), - XlaHelpers::I64Optional(dim))); + return bridge::AtenFromXlaTensor( + tensor_methods::cross(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(other)), + XlaHelpers::I64Optional(dim))); } std::tuple XLANativeFunctions::cummax( const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); std::tuple res = tensor_methods::cummax(self_tensor, dim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)), @@ -1433,7 +1462,7 @@ std::tuple XLANativeFunctions::cummax( at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); std::optional promoted_dtype = PromoteIntegralType(self_tensor->dtype(), dtype); if (IsOperationOnType(promoted_dtype, self_tensor->dtype(), @@ -1449,7 +1478,7 @@ at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim, at::Tensor XLANativeFunctions::cumsum(const at::Tensor& self, int64_t dim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( tensor_methods::cumsum(self_tensor, dim, dtype)); } @@ -1457,31 +1486,33 @@ at::Tensor XLANativeFunctions::cumsum(const at::Tensor& self, int64_t dim, // TODO(alanwaketan): Let's rewrite a without reusing other native functions. at::Tensor XLANativeFunctions::detach_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(GetValueOrThrow(bridge::GetXlaTensor(self))); } at::Tensor XLANativeFunctions::diag(const at::Tensor& self, int64_t diagonal) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::diag(bridge::GetXlaTensor(self), diagonal)); + return bridge::AtenFromXlaTensor(tensor_methods::diag( + GetValueOrThrow(bridge::GetXlaTensor(self)), diagonal)); } at::Tensor XLANativeFunctions::diagonal_copy(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::diagonal(bridge::GetXlaTensor(self), offset, dim1, dim2)); + return bridge::AtenFromXlaTensor(tensor_methods::diagonal( + GetValueOrThrow(bridge::GetXlaTensor(self)), offset, dim1, dim2)); } at::Tensor XLANativeFunctions::diagonal_scatter(const at::Tensor& base, const at::Tensor& mutated_view, int64_t offset, int64_t dim1, int64_t dim2) { - auto base_ = bridge::GetXlaTensor(base); - auto mutated_view_ = bridge::GetXlaTensor(mutated_view); - int64_t base_rank = - bridge::GetXlaTensor(base)->shape().get().dimensions_size(); + auto base_ = GetValueOrThrow(bridge::GetXlaTensor(base)); + auto mutated_view_ = GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); + int64_t base_rank = GetValueOrThrow(bridge::GetXlaTensor(base)) + ->shape() + .get() + .dimensions_size(); int64_t canonical_dim1 = torch::lazy::GetCanonicalDimensionIndex(dim1, base_rank); int64_t canonical_dim2 = @@ -1512,7 +1543,7 @@ at::Tensor XLANativeFunctions::div(const at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::div(bridge::GetXlaTensor(self), other)); + tensor_methods::div(GetValueOrThrow(bridge::GetXlaTensor(self)), other)); } at::Tensor XLANativeFunctions::dot(const at::Tensor& self, @@ -1534,8 +1565,9 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self, return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(dot)>::call( self, tensor); } - return bridge::AtenFromXlaTensor(tensor_methods::matmul( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor))); + return bridge::AtenFromXlaTensor( + tensor_methods::matmul(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(tensor)))); } at::Tensor XLANativeFunctions::einsum(std::string_view equation, @@ -1548,14 +1580,14 @@ at::Tensor XLANativeFunctions::einsum(std::string_view equation, [](unsigned char x) { return std::isspace(x); }), cleansed_equation.end()); - std::vector xla_tensors = bridge::TryGetXlaTensors(tensors); - bool all_xla_tensors_are_valid = true; - for (const XLATensorPtr xla_tensor : xla_tensors) { - if (!xla_tensor) { - all_xla_tensors_are_valid = false; - break; - } - } + std::vector xla_tensors; + std::transform(tensors.begin(), tensors.end(), + std::back_inserter(xla_tensors), [](const at::Tensor& tensor) { + return bridge::GetXlaTensor(tensor).value_or(XLATensorPtr{}); + }); + bool all_xla_tensors_are_valid = std::all_of( + xla_tensors.begin(), xla_tensors.end(), + [](const auto& xla_tensor) { return static_cast(xla_tensor); }); TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Einsum operations with more than 2 operands, like bilinear operations, are @@ -1581,8 +1613,8 @@ at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output, << "In-place elu backward calculation is triggered with a negative slope " "which is not supported."; return bridge::AtenFromXlaTensor(tensor_methods::elu_backward( - bridge::GetXlaTensor(grad_output), alpha, scale, input_scale, - bridge::GetXlaTensor(self_or_result))); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), alpha, scale, + input_scale, GetValueOrThrow(bridge::GetXlaTensor(self_or_result)))); } at::Tensor XLANativeFunctions::embedding_dense_backward( @@ -1590,8 +1622,9 @@ at::Tensor XLANativeFunctions::embedding_dense_backward( int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::embedding_dense_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(indices), - num_weights, padding_idx, scale_grad_by_freq)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(indices)), num_weights, padding_idx, + scale_grad_by_freq)); } std::tuple @@ -1610,16 +1643,17 @@ XLANativeFunctions::_embedding_bag_forward_only( include_last_offset, padding_idx); } - auto indices_tensor = bridge::GetXlaTensor(indices); + auto indices_tensor = GetValueOrThrow(bridge::GetXlaTensor(indices)); auto sample_weights = per_sample_weights.has_value() && per_sample_weights.value().defined() - ? bridge::GetXlaTensor(per_sample_weights.value()) + ? GetValueOrThrow(bridge::GetXlaTensor(per_sample_weights.value())) : tensor_methods::full_like(indices_tensor, 1.0, *torch_xla::bridge::GetXlaDevice(weight), at::ScalarType::Float); auto result = tensor_methods::embedding_bag( - bridge::GetXlaTensor(weight), indices_tensor, - bridge::GetXlaTensor(offsets), mode, sample_weights, include_last_offset); + GetValueOrThrow(bridge::GetXlaTensor(weight)), indices_tensor, + GetValueOrThrow(bridge::GetXlaTensor(offsets)), mode, sample_weights, + include_last_offset); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(result)), bridge::AtenFromXlaTensor(std::get<1>(result)), bridge::AtenFromXlaTensor(std::get<2>(result)), @@ -1708,13 +1742,14 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::optional size = c10::asIntArrayRefSlowOpt(sym_size); if (size.has_value()) { - return bridge::AtenFromXlaTensor(tensor_methods::expand( - bridge::GetXlaTensor(self), torch::lazy::ToVector(*size))); + return bridge::AtenFromXlaTensor( + tensor_methods::expand(GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(*size))); } else { // at least one of the dimension is symbolic, use the sym_int version of the // node - return bridge::AtenFromXlaTensor( - tensor_methods::expand_symint(bridge::GetXlaTensor(self), sym_size)); + return bridge::AtenFromXlaTensor(tensor_methods::expand_symint( + GetValueOrThrow(bridge::GetXlaTensor(self)), sym_size)); } } @@ -1728,21 +1763,21 @@ at::Tensor& XLANativeFunctions::exponential_( generator); } XLA_CHECK_GE(lambd, 0.0); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::exponential_(self_tensor, lambd); return self; } at::Tensor& XLANativeFunctions::eye_out(int64_t n, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = bridge::GetXlaTensor(out); + XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); tensor_methods::eye_out(out_tensor, n, n); return out; } at::Tensor& XLANativeFunctions::eye_out(int64_t n, int64_t m, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = bridge::GetXlaTensor(out); + XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); tensor_methods::eye_out(out_tensor, n, m); return out; } @@ -1750,7 +1785,7 @@ at::Tensor& XLANativeFunctions::eye_out(int64_t n, int64_t m, at::Tensor& out) { at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::fill_(self_tensor, value); return self; } @@ -1768,7 +1803,7 @@ at::Tensor XLANativeFunctions::flip(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::flip( - bridge::GetXlaTensor(self), XlaHelpers::I64List(dims))); + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims))); } at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self, @@ -1828,15 +1863,16 @@ at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, const at::Tensor& index, bool /* sparse_grad */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::gather( - bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index))); + return bridge::AtenFromXlaTensor( + tensor_methods::gather(GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + GetValueOrThrow(bridge::GetXlaTensor(index)))); } at::Tensor XLANativeFunctions::gelu(const at::Tensor& self, std::string_view approximate) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::gelu(bridge::GetXlaTensor(self), approximate)); + return bridge::AtenFromXlaTensor(tensor_methods::gelu( + GetValueOrThrow(bridge::GetXlaTensor(self)), approximate)); } at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad, @@ -1845,16 +1881,17 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); at::ScalarType result_type = at::result_type(grad, self); return bridge::AtenFromXlaTensor(tensor_methods::gelu_backward( - bridge::GetXlaTensor(grad.to(result_type)), - bridge::GetXlaTensor(self.to(result_type)), approximate)); + GetValueOrThrow(bridge::GetXlaTensor(grad.to(result_type))), + GetValueOrThrow(bridge::GetXlaTensor(self.to(result_type))), + approximate)); } at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::clamp(bridge::GetXlaTensor(self), min_val, max_val)); + return bridge::AtenFromXlaTensor(tensor_methods::clamp( + GetValueOrThrow(bridge::GetXlaTensor(self)), min_val, max_val)); } at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, @@ -1863,8 +1900,8 @@ at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, const at::Scalar& max_val) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::hardtanh_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), min_val, - max_val)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), min_val, max_val)); } at::Tensor XLANativeFunctions::index( @@ -1902,27 +1939,28 @@ at::Tensor XLANativeFunctions::index_add(const at::Tensor& self, int64_t dim, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::index_add( - bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(source), alpha)); + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(source)), alpha)); } at::Tensor XLANativeFunctions::index_copy(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - return bridge::AtenFromXlaTensor( - tensor_methods::index_copy(self_tensor, dim, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(source))); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::index_copy( + self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(source)))); } at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::index_fill_(self_tensor, dim, bridge::GetXlaTensor(index), - value); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + tensor_methods::index_fill_( + self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), value); return self; } @@ -1930,9 +1968,10 @@ at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::index_fill_(self_tensor, dim, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(value)); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + tensor_methods::index_fill_(self_tensor, dim, + GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(value))); return self; } @@ -1975,7 +2014,8 @@ at::Tensor XLANativeFunctions::index_select(const at::Tensor& self, int64_t dim, const at::Tensor& index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::index_select( - bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index))); + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + GetValueOrThrow(bridge::GetXlaTensor(index)))); } at::Tensor XLANativeFunctions::kl_div(const at::Tensor& self, @@ -1988,8 +2028,8 @@ at::Tensor XLANativeFunctions::kl_div(const at::Tensor& self, std::tuple XLANativeFunctions::kthvalue( const at::Tensor& self, int64_t k, int64_t dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = - tensor_methods::kthvalue(bridge::GetXlaTensor(self), k, dim, keepdim); + auto results = tensor_methods::kthvalue( + GetValueOrThrow(bridge::GetXlaTensor(self)), k, dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -2005,9 +2045,9 @@ at::Tensor XLANativeFunctions::leaky_relu_backward( torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen( negative_slope, *common_device); torch::lazy::NodePtr node = torch_xla::MakeNode( - bridge::GetXlaTensor(grad_output)->GetIrValue(), - bridge::GetXlaTensor(self)->GetIrValue(), node_negative_slope, - self_is_result); + GetValueOrThrow(bridge::GetXlaTensor(grad_output))->GetIrValue(), + GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue(), + node_negative_slope, self_is_result); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); } @@ -2022,9 +2062,10 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, XLA_CHECK_EQ(self.dtype(), weight.dtype()) << "expected dtype " << self.dtype() << " for `weight` but got dtype " << weight.dtype(); - return bridge::AtenFromXlaTensor(tensor_methods::lerp( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), - bridge::GetXlaTensor(weight))); + return bridge::AtenFromXlaTensor( + tensor_methods::lerp(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(end)), + GetValueOrThrow(bridge::GetXlaTensor(weight)))); } at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, @@ -2034,8 +2075,9 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, XLA_CHECK_EQ(self.dtype(), end.dtype()) << "expected dtype " << self.dtype() << " for `end` but got dtype " << end.dtype(); - return bridge::AtenFromXlaTensor(tensor_methods::lerp( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight)); + return bridge::AtenFromXlaTensor( + tensor_methods::lerp(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(end)), weight)); } at::Tensor XLANativeFunctions::lift(const at::Tensor& tensor) { @@ -2064,8 +2106,8 @@ std::tuple XLANativeFunctions::linalg_inv_ex( } auto common_device = torch_xla::bridge::GetXlaDevice(self); TORCH_INTERNAL_ASSERT(common_device); - torch::lazy::NodePtr node = - torch_xla::MakeNode(bridge::GetXlaTensor(self)->GetIrValue()); + torch::lazy::NodePtr node = torch_xla::MakeNode( + GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue()); auto result = torch_xla::XLATensor::Create(std::move(node), *common_device); auto info = tensor_methods::full_like(result, 0, result->GetDevice(), at::ScalarType::Int); @@ -2095,62 +2137,67 @@ at::Tensor XLANativeFunctions::linspace(const at::Scalar& start, at::Tensor XLANativeFunctions::log(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::log(bridge::GetXlaTensor(self))); + tensor_methods::log(GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::logit(const at::Tensor& self, std::optional eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::logit(bridge::GetXlaTensor(self), eps)); + tensor_methods::logit(GetValueOrThrow(bridge::GetXlaTensor(self)), eps)); } at::Tensor XLANativeFunctions::log10(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::log_base( - bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log10), 10.0)); + return bridge::AtenFromXlaTensor( + tensor_methods::log_base(GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::OpKind(at::aten::log10), 10.0)); } at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::log1p(bridge::GetXlaTensor(self))); + tensor_methods::log1p(GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::log2(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::log_base( - bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log2), 2.0)); + return bridge::AtenFromXlaTensor( + tensor_methods::log_base(GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::OpKind(at::aten::log2), 2.0)); } at::Tensor XLANativeFunctions::logsumexp(const at::Tensor& self, at::IntArrayRef dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::logsumexp( - bridge::GetXlaTensor(self), torch::lazy::ToVector(dim), - /*keep_reduced_dimensions=*/keepdim)); + return bridge::AtenFromXlaTensor( + tensor_methods::logsumexp(GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(dim), + /*keep_reduced_dimensions=*/keepdim)); } at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::xlogy( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); + return bridge::AtenFromXlaTensor( + tensor_methods::xlogy(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(other)))); } at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self, const at::Tensor& mask, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::masked_scatter( - self_tensor, bridge::GetXlaTensor(mask), bridge::GetXlaTensor(source))); + self_tensor, GetValueOrThrow(bridge::GetXlaTensor(mask)), + GetValueOrThrow(bridge::GetXlaTensor(source)))); } at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, const at::Tensor& mask) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); // Initially make XLA handled masked_select() handling experimental, and // opt-in. if (!DebugUtil::ExperimentEnabled("masked_select")) { @@ -2158,20 +2205,21 @@ at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, ATEN_OP(masked_select)>::call(self, mask); } - return bridge::AtenFromXlaTensor( - tensor_methods::masked_select(self_tensor, bridge::GetXlaTensor(mask))); + return bridge::AtenFromXlaTensor(tensor_methods::masked_select( + self_tensor, GetValueOrThrow(bridge::GetXlaTensor(mask)))); } at::Tensor XLANativeFunctions::max(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::max(bridge::GetXlaTensor(self))); + tensor_methods::max(GetValueOrThrow(bridge::GetXlaTensor(self)))); } std::tuple XLANativeFunctions::max( const at::Tensor& self, int64_t dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto outputs = tensor_methods::max(bridge::GetXlaTensor(self), dim, keepdim); + auto outputs = tensor_methods::max( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -2180,10 +2228,12 @@ std::tuple XLANativeFunctions::max_out( const at::Tensor& self, int64_t dim, bool keepdim, at::Tensor& max, at::Tensor& max_values) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr max_tensor = bridge::GetXlaTensor(max); - XLATensorPtr max_values_tensor = bridge::GetXlaTensor(max_values); + XLATensorPtr max_tensor = GetValueOrThrow(bridge::GetXlaTensor(max)); + XLATensorPtr max_values_tensor = + GetValueOrThrow(bridge::GetXlaTensor(max_values)); tensor_methods::max_out(max_tensor, max_values_tensor, - bridge::GetXlaTensor(self), dim, keepdim); + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + keepdim); return std::forward_as_tuple(max, max_values); } @@ -2209,7 +2259,7 @@ std::tuple XLANativeFunctions::max_pool2d_with_indices( ceil_mode); } auto outputs = tensor_methods::max_pool_nd( - bridge::GetXlaTensor(self), /*spatial_dim_count=*/2, + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), @@ -2232,7 +2282,8 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward( ceil_mode, indices); } return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); } @@ -2261,7 +2312,8 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward( ceil_mode, indices); } return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); } @@ -2280,7 +2332,7 @@ std::tuple XLANativeFunctions::max_pool3d_with_indices( ceil_mode); } auto outputs = tensor_methods::max_pool_nd( - bridge::GetXlaTensor(self), /*spatial_dim_count=*/3, + GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), @@ -2291,9 +2343,10 @@ at::Tensor XLANativeFunctions::max_unpool2d(const at::Tensor& self, const at::Tensor& indices, at::IntArrayRef output_size) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::max_unpool( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices), - torch::lazy::ToVector(output_size))); + return bridge::AtenFromXlaTensor( + tensor_methods::max_unpool(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(indices)), + torch::lazy::ToVector(output_size))); } at::Tensor XLANativeFunctions::max_unpool3d(const at::Tensor& self, @@ -2302,15 +2355,16 @@ at::Tensor XLANativeFunctions::max_unpool3d(const at::Tensor& self, at::IntArrayRef stride, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::max_unpool( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(indices), - torch::lazy::ToVector(output_size))); + return bridge::AtenFromXlaTensor( + tensor_methods::max_unpool(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(indices)), + torch::lazy::ToVector(output_size))); } at::Tensor XLANativeFunctions::mean(const at::Tensor& self, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::mean( self_tensor, torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), @@ -2321,7 +2375,7 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::mean( self_tensor, dim ? torch::lazy::ToVector(*dim) @@ -2333,13 +2387,14 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::Tensor XLANativeFunctions::min(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::min(bridge::GetXlaTensor(self))); + tensor_methods::min(GetValueOrThrow(bridge::GetXlaTensor(self)))); } std::tuple XLANativeFunctions::min( const at::Tensor& self, int64_t dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto outputs = tensor_methods::min(bridge::GetXlaTensor(self), dim, keepdim); + auto outputs = tensor_methods::min( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -2347,26 +2402,28 @@ std::tuple XLANativeFunctions::min( at::Tensor XLANativeFunctions::mish(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::mish(bridge::GetXlaTensor(self))); + tensor_methods::mish(GetValueOrThrow(bridge::GetXlaTensor(self)))); } std::tuple XLANativeFunctions::min_out( const at::Tensor& self, int64_t dim, bool keepdim, at::Tensor& min, at::Tensor& min_indices) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr min_tensor = bridge::GetXlaTensor(min); - XLATensorPtr min_indices_tensor = bridge::GetXlaTensor(min_indices); + XLATensorPtr min_tensor = GetValueOrThrow(bridge::GetXlaTensor(min)); + XLATensorPtr min_indices_tensor = + GetValueOrThrow(bridge::GetXlaTensor(min_indices)); tensor_methods::min_out(min_tensor, min_indices_tensor, - bridge::GetXlaTensor(self), dim, keepdim); + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + keepdim); return std::forward_as_tuple(min, min_indices); } at::Tensor XLANativeFunctions::mm(const at::Tensor& self, const at::Tensor& mat2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::mm(/*input=*/bridge::GetXlaTensor(self), - /*weight=*/bridge::GetXlaTensor(mat2))); + return bridge::AtenFromXlaTensor(tensor_methods::mm( + /*input=*/GetValueOrThrow(bridge::GetXlaTensor(self)), + /*weight=*/GetValueOrThrow(bridge::GetXlaTensor(mat2)))); } at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self, @@ -2374,7 +2431,8 @@ at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self, int64_t reduction) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::mse_loss( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction)); + GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(target)), reduction)); } at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, @@ -2383,8 +2441,9 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, int64_t reduction) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::mse_loss_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), - bridge::GetXlaTensor(target), reduction)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(target)), reduction)); } at::Tensor XLANativeFunctions::mul(const at::Tensor& self, @@ -2428,7 +2487,7 @@ at::Tensor XLANativeFunctions::multinomial( replacement, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( tensor_methods::multinomial(self_tensor, num_samples, replacement)); } @@ -2436,16 +2495,18 @@ at::Tensor XLANativeFunctions::multinomial( at::Tensor XLANativeFunctions::mv(const at::Tensor& self, const at::Tensor& vec) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::mv( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(vec))); + return bridge::AtenFromXlaTensor( + tensor_methods::mv(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(vec)))); } at::Tensor& XLANativeFunctions::mv_out(const at::Tensor& self, const at::Tensor& vec, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = bridge::GetXlaTensor(out); - tensor_methods::mv_out(out_tensor, bridge::GetXlaTensor(self), - bridge::GetXlaTensor(vec)); + XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); + tensor_methods::mv_out(out_tensor, + GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(vec))); return out; } @@ -2458,7 +2519,7 @@ at::Tensor XLANativeFunctions::nan_to_num(const at::Tensor& self, if (!at::native::is_floating_point(self)) { return torch::lazy::CopyTensor(self); } - XLATensorPtr input_tensor = bridge::GetXlaTensor(self); + XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); auto element_type = MakeXlaPrimitiveType(self.scalar_type(), &device); XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type); @@ -2485,14 +2546,15 @@ XLANativeFunctions::native_batch_norm( const std::optional& running_var, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr input_tensor = bridge::GetXlaTensor(input); + XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); XLATensorPtr running_mean_tensor = bridge::GetOrCreateXlaTensor(running_mean, device); XLATensorPtr running_var_tensor = bridge::GetOrCreateXlaTensor(running_var, device); auto outputs = tensor_methods::native_batch_norm( - bridge::GetXlaTensor(input), bridge::GetOrCreateXlaTensor(weight, device), + GetValueOrThrow(bridge::GetXlaTensor(input)), + bridge::GetOrCreateXlaTensor(weight, device), bridge::GetOrCreateXlaTensor(bias, device), running_mean_tensor, running_var_tensor, training, momentum, eps); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), @@ -2506,12 +2568,15 @@ XLANativeFunctions::_native_batch_norm_legit( const std::optional& bias, at::Tensor& running_mean, at::Tensor& running_var, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr input_tensor = bridge::GetXlaTensor(input); + XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - XLATensorPtr running_mean_tensor = bridge::GetXlaTensor(running_mean); - XLATensorPtr running_var_tensor = bridge::GetXlaTensor(running_var); + XLATensorPtr running_mean_tensor = + GetValueOrThrow(bridge::GetXlaTensor(running_mean)); + XLATensorPtr running_var_tensor = + GetValueOrThrow(bridge::GetXlaTensor(running_var)); auto outputs = tensor_methods::native_batch_norm( - bridge::GetXlaTensor(input), bridge::GetOrCreateXlaTensor(weight, device), + GetValueOrThrow(bridge::GetXlaTensor(input)), + bridge::GetOrCreateXlaTensor(weight, device), bridge::GetOrCreateXlaTensor(bias, device), running_mean_tensor, running_var_tensor, training, momentum, eps); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), @@ -2525,12 +2590,13 @@ XLANativeFunctions::_native_batch_norm_legit( const std::optional& bias, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr input_tensor = bridge::GetXlaTensor(input); + XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); XLATensorPtr null_running_mean_tensor = XLATensorPtr(); XLATensorPtr null_running_var_tensor = XLATensorPtr(); auto outputs = tensor_methods::native_batch_norm( - bridge::GetXlaTensor(input), bridge::GetOrCreateXlaTensor(weight, device), + GetValueOrThrow(bridge::GetXlaTensor(input)), + bridge::GetOrCreateXlaTensor(weight, device), bridge::GetOrCreateXlaTensor(bias, device), null_running_mean_tensor, null_running_var_tensor, training, momentum, eps); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), @@ -2548,10 +2614,12 @@ XLANativeFunctions::native_batch_norm_backward( const std::optional& save_invstd, bool train, double eps, std::array output_mask) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_out_tensor = bridge::GetXlaTensor(grad_out); + XLATensorPtr grad_out_tensor = + GetValueOrThrow(bridge::GetXlaTensor(grad_out)); const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice(); auto gradients = tensor_methods::native_batch_norm_backward( - bridge::GetXlaTensor(grad_out), bridge::GetXlaTensor(input), + GetValueOrThrow(bridge::GetXlaTensor(grad_out)), + GetValueOrThrow(bridge::GetXlaTensor(input)), bridge::GetOrCreateXlaTensor(weight, device), bridge::GetOrCreateXlaTensor(save_mean, device), bridge::GetOrCreateXlaTensor(save_invstd, device), train, eps); @@ -2568,7 +2636,7 @@ XLANativeFunctions::native_batch_norm_backward( std::tuple XLANativeFunctions::native_dropout( const at::Tensor& self, double p, std::optional train) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto results = tensor_methods::native_dropout(self_tensor, p, train); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); @@ -2581,7 +2649,7 @@ at::Tensor XLANativeFunctions::neg(const at::Tensor& self) { "you are trying to invert a mask, use the `~` or `logical_not()` " "operator instead."; return bridge::AtenFromXlaTensor( - tensor_methods::neg(bridge::GetXlaTensor(self))); + tensor_methods::neg(GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::nll_loss2d_backward( @@ -2589,7 +2657,7 @@ at::Tensor XLANativeFunctions::nll_loss2d_backward( const at::Tensor& target, const std::optional& weight, int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLATensorPtr weight_tensor = bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()); XLATensorPtr total_weight_tensor; @@ -2598,9 +2666,9 @@ at::Tensor XLANativeFunctions::nll_loss2d_backward( bridge::GetOrCreateXlaTensor(total_weight, self_tensor->GetDevice()); } return bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d_backward( - bridge::GetXlaTensor(grad_output), self_tensor, - bridge::GetXlaTensor(target), weight_tensor, reduction, ignore_index, - total_weight_tensor)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), self_tensor, + GetValueOrThrow(bridge::GetXlaTensor(target)), weight_tensor, reduction, + ignore_index, total_weight_tensor)); } std::tuple XLANativeFunctions::nll_loss2d_forward( @@ -2608,12 +2676,12 @@ std::tuple XLANativeFunctions::nll_loss2d_forward( const std::optional& weight, int64_t reduction, int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLATensorPtr total_weight = tensor_methods::full( {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d( - self_tensor, bridge::GetXlaTensor(target), + self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()), reduction, ignore_index)), bridge::AtenFromXlaTensor(total_weight)); @@ -2624,7 +2692,7 @@ at::Tensor XLANativeFunctions::nll_loss_backward( const at::Tensor& target, const std::optional& weight, int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLATensorPtr weight_tensor = bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()); XLATensorPtr total_weight_tensor; @@ -2633,9 +2701,9 @@ at::Tensor XLANativeFunctions::nll_loss_backward( bridge::GetOrCreateXlaTensor(total_weight, self_tensor->GetDevice()); } return bridge::AtenFromXlaTensor(tensor_methods::nll_loss_backward( - bridge::GetXlaTensor(grad_output), self_tensor, - bridge::GetXlaTensor(target), weight_tensor, reduction, ignore_index, - total_weight_tensor)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), self_tensor, + GetValueOrThrow(bridge::GetXlaTensor(target)), weight_tensor, reduction, + ignore_index, total_weight_tensor)); } std::tuple XLANativeFunctions::nll_loss_forward( @@ -2643,12 +2711,12 @@ std::tuple XLANativeFunctions::nll_loss_forward( const std::optional& weight, int64_t reduction, int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLATensorPtr total_weight = tensor_methods::full( {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss( - self_tensor, bridge::GetXlaTensor(target), + self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()), reduction, ignore_index)), bridge::AtenFromXlaTensor(total_weight)); @@ -2656,7 +2724,7 @@ std::tuple XLANativeFunctions::nll_loss_forward( at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); // Initially make XLA handled nonzero() handling experimental, and opt-in. if (!DebugUtil::ExperimentEnabled("nonzero")) { return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(nonzero)>::call( @@ -2675,8 +2743,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(norm, ScalarOpt_dtype)>::call(self, p, dtype); } - return bridge::AtenFromXlaTensor(tensor_methods::norm( - bridge::GetXlaTensor(self), p, dtype, {}, /*keepdim=*/false)); + return bridge::AtenFromXlaTensor( + tensor_methods::norm(GetValueOrThrow(bridge::GetXlaTensor(self)), p, + dtype, {}, /*keepdim=*/false)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2688,8 +2757,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, return at::native::call_fallback_fn<&xla_fallback, ATEN_OP2(norm, Scalar)>::call(self, p); } - return bridge::AtenFromXlaTensor(tensor_methods::norm( - bridge::GetXlaTensor(self), p, std::nullopt, {}, /*keepdim=*/false)); + return bridge::AtenFromXlaTensor( + tensor_methods::norm(GetValueOrThrow(bridge::GetXlaTensor(self)), p, + std::nullopt, {}, /*keepdim=*/false)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2705,8 +2775,8 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, keepdim, dtype); } - return bridge::AtenFromXlaTensor( - tensor_methods::norm(bridge::GetXlaTensor(self), p, dtype, dim, keepdim)); + return bridge::AtenFromXlaTensor(tensor_methods::norm( + GetValueOrThrow(bridge::GetXlaTensor(self)), p, dtype, dim, keepdim)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2720,8 +2790,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, &xla_fallback, ATEN_OP2(norm, ScalarOpt_dim)>::call(self, p, dim, keepdim); } - return bridge::AtenFromXlaTensor(tensor_methods::norm( - bridge::GetXlaTensor(self), p, std::nullopt, dim, keepdim)); + return bridge::AtenFromXlaTensor( + tensor_methods::norm(GetValueOrThrow(bridge::GetXlaTensor(self)), p, + std::nullopt, dim, keepdim)); } at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, @@ -2733,7 +2804,7 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, generator); } return bridge::AtenFromXlaTensor( - tensor_methods::normal(bridge::GetXlaTensor(mean), std)); + tensor_methods::normal(GetValueOrThrow(bridge::GetXlaTensor(mean)), std)); } at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, @@ -2745,7 +2816,7 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, generator); } return bridge::AtenFromXlaTensor( - tensor_methods::normal(mean, bridge::GetXlaTensor(std))); + tensor_methods::normal(mean, GetValueOrThrow(bridge::GetXlaTensor(std)))); } at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, @@ -2757,8 +2828,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, &xla_fallback, ATEN_OP2(normal, Tensor_Tensor)>::call(mean, std, generator); } - return bridge::AtenFromXlaTensor(tensor_methods::normal( - bridge::GetXlaTensor(mean), bridge::GetXlaTensor(std))); + return bridge::AtenFromXlaTensor( + tensor_methods::normal(GetValueOrThrow(bridge::GetXlaTensor(mean)), + GetValueOrThrow(bridge::GetXlaTensor(std)))); } at::Tensor& XLANativeFunctions::normal_( @@ -2769,7 +2841,7 @@ at::Tensor& XLANativeFunctions::normal_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(normal_)>::call( self, mean, std, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::normal_(self_tensor, mean, std); return self; } @@ -2778,7 +2850,7 @@ at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::permute( - bridge::GetXlaTensor(self), XlaHelpers::I64List(dims))); + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims))); } at::Tensor XLANativeFunctions::pow(const at::Tensor& self, @@ -2824,8 +2896,8 @@ at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self, << weight_num << " and channel size = " << channel_size; } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - XLATensorPtr weight_tensor = bridge::GetXlaTensor(weight); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLATensorPtr weight_tensor = GetValueOrThrow(bridge::GetXlaTensor(weight)); return bridge::AtenFromXlaTensor( tensor_methods::prelu(self_tensor, weight_tensor)); @@ -2836,9 +2908,10 @@ std::tuple XLANativeFunctions::_prelu_kernel_backward( const at::Tensor& weight) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - XLATensorPtr weight_tensor = bridge::GetXlaTensor(weight); + XLATensorPtr grad_output_tensor = + GetValueOrThrow(bridge::GetXlaTensor(grad_output)); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLATensorPtr weight_tensor = GetValueOrThrow(bridge::GetXlaTensor(weight)); auto outputs = tensor_methods::prelu_backward(grad_output_tensor, self_tensor, weight_tensor); @@ -2849,7 +2922,7 @@ std::tuple XLANativeFunctions::_prelu_kernel_backward( at::Tensor XLANativeFunctions::prod(const at::Tensor& self, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::prod( self_tensor, torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), @@ -2861,9 +2934,9 @@ at::Tensor XLANativeFunctions::prod(const at::Tensor& self, int64_t dim, bool keepdim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::prod(bridge::GetXlaTensor(self), {dim}, keepdim, - PromoteIntegralType(self.scalar_type(), dtype))); + return bridge::AtenFromXlaTensor(tensor_methods::prod( + GetValueOrThrow(bridge::GetXlaTensor(self)), {dim}, keepdim, + PromoteIntegralType(self.scalar_type(), dtype))); } void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, @@ -2874,8 +2947,8 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, // for in-place ops we have in hands. // 1) Aid XLA's InputOutputAlias. - auto input_tensor = bridge::GetXlaTensor(input); - auto output_tensor = bridge::GetXlaTensor(output); + auto input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + auto output_tensor = GetValueOrThrow(bridge::GetXlaTensor(output)); if (input_tensor->CurrentDataHandle() != nullptr || (input_tensor->CurrentIrValue().node != nullptr && torch_xla::DeviceData::Cast( @@ -2922,16 +2995,18 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, at::Tensor& XLANativeFunctions::put_(at::Tensor& self, const at::Tensor& index, const at::Tensor& source, bool accumulate) { - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - tensor_methods::put_(self_tensor, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(source), accumulate); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + tensor_methods::put_( + self_tensor, GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(source)), accumulate); return self; } std::tuple XLANativeFunctions::qr( const at::Tensor& self, bool some) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = tensor_methods::qr(bridge::GetXlaTensor(self), some); + auto results = + tensor_methods::qr(GetValueOrThrow(bridge::GetXlaTensor(self)), some); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -2946,7 +3021,7 @@ at::Tensor& XLANativeFunctions::random_( &xla_fallback, ATEN_OP2(random_, from)>::call(self, from, to, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); at::ScalarType dtype = self_tensor->dtype(); // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; @@ -2966,7 +3041,7 @@ at::Tensor& XLANativeFunctions::random_( ATEN_OP2(random_, to)>::call(self, to, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLA_CHECK_GT(to, 0); CheckRangeValues(self_tensor->dtype(), 0, to - 1); tensor_methods::random_(self_tensor, 0, to); @@ -2981,7 +3056,7 @@ at::Tensor& XLANativeFunctions::random_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(random_)>::call( self, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); at::ScalarType dtype = self_tensor->dtype(); // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; @@ -3018,7 +3093,8 @@ at::Tensor XLANativeFunctions::reflection_pad1d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad1d( - bridge::GetXlaTensor(self), torch::lazy::ToVector(padding))); + GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad1d_backward( @@ -3026,7 +3102,8 @@ at::Tensor XLANativeFunctions::reflection_pad1d_backward( at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad1d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), torch::lazy::ToVector(padding))); } @@ -3034,7 +3111,8 @@ at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad2d( - bridge::GetXlaTensor(self), torch::lazy::ToVector(padding))); + GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad2d_backward( @@ -3042,7 +3120,8 @@ at::Tensor XLANativeFunctions::reflection_pad2d_backward( at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad2d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), torch::lazy::ToVector(padding))); } @@ -3050,7 +3129,8 @@ at::Tensor XLANativeFunctions::reflection_pad3d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad3d( - bridge::GetXlaTensor(self), torch::lazy::ToVector(padding))); + GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad3d_backward( @@ -3058,29 +3138,32 @@ at::Tensor XLANativeFunctions::reflection_pad3d_backward( at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad3d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::remainder(const at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::remainder( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); + return bridge::AtenFromXlaTensor( + tensor_methods::remainder(GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(other)))); } at::Tensor XLANativeFunctions::remainder(const at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::remainder(bridge::GetXlaTensor(self), other)); + return bridge::AtenFromXlaTensor(tensor_methods::remainder( + GetValueOrThrow(bridge::GetXlaTensor(self)), other)); } at::Tensor XLANativeFunctions::replication_pad1d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad1d( - bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); + GetValueOrThrow(bridge::GetXlaTensor(self)), + XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad1d_backward( @@ -3088,7 +3171,8 @@ at::Tensor XLANativeFunctions::replication_pad1d_backward( at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad1d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(padding))); } @@ -3096,7 +3180,8 @@ at::Tensor XLANativeFunctions::replication_pad2d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad2d( - bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); + GetValueOrThrow(bridge::GetXlaTensor(self)), + XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad2d_backward( @@ -3104,7 +3189,8 @@ at::Tensor XLANativeFunctions::replication_pad2d_backward( at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad2d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(padding))); } @@ -3112,7 +3198,8 @@ at::Tensor XLANativeFunctions::replication_pad3d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad3d( - bridge::GetXlaTensor(self), XlaHelpers::I64List(padding))); + GetValueOrThrow(bridge::GetXlaTensor(self)), + XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad3d_backward( @@ -3120,7 +3207,8 @@ at::Tensor XLANativeFunctions::replication_pad3d_backward( at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad3d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(padding))); } @@ -3128,7 +3216,7 @@ const at::Tensor& XLANativeFunctions::resize_( const at::Tensor& self, at::IntArrayRef size, std::optional /* memory_format */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::resize_(self_tensor, XlaHelpers::I64List(size)); return self; } @@ -3138,7 +3226,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::roll( - bridge::GetXlaTensor(self), XlaHelpers::I64List(shifts), + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(shifts), XlaHelpers::I64List(dims))); } @@ -3155,9 +3243,10 @@ at::Tensor XLANativeFunctions::rrelu_with_noise( upper, training, generator); } - XLATensorPtr noise_tensor = bridge::GetXlaTensor(noise); + XLATensorPtr noise_tensor = GetValueOrThrow(bridge::GetXlaTensor(noise)); return bridge::AtenFromXlaTensor(tensor_methods::rrelu_with_noise( - bridge::GetXlaTensor(self), noise_tensor, lower, upper, training)); + GetValueOrThrow(bridge::GetXlaTensor(self)), noise_tensor, lower, upper, + training)); } at::Tensor XLANativeFunctions::rrelu_with_noise_backward( @@ -3167,10 +3256,11 @@ at::Tensor XLANativeFunctions::rrelu_with_noise_backward( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); double negative_slope = (lower.to() + upper.to()) / 2; XLA_CHECK(!self_is_result || negative_slope > 0.0); - XLATensorPtr noise_tensor = bridge::GetXlaTensor(noise); + XLATensorPtr noise_tensor = GetValueOrThrow(bridge::GetXlaTensor(noise)); return bridge::AtenFromXlaTensor(tensor_methods::rrelu_with_noise_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), - noise_tensor, lower, upper, training)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), noise_tensor, lower, upper, + training)); } at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, @@ -3200,15 +3290,15 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src, std::optional reduce) { - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); if (!reduce.has_value()) { - return bridge::AtenFromXlaTensor( - tensor_methods::scatter(self_tensor, dim, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(src))); + return bridge::AtenFromXlaTensor(tensor_methods::scatter( + self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(src)))); } else if (*reduce == "add") { return bridge::AtenFromXlaTensor(tensor_methods::scatter_add( - self_tensor, dim, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(src))); + self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(src)))); } else { // TODO: implement scatter_mul return at::native::call_fallback_fn< @@ -3222,13 +3312,13 @@ at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, const at::Scalar& value, std::optional reduce) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); if (!reduce.has_value()) { return bridge::AtenFromXlaTensor(tensor_methods::scatter( - self_tensor, dim, bridge::GetXlaTensor(index), value)); + self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), value)); } else if (*reduce == "add") { return bridge::AtenFromXlaTensor(tensor_methods::scatter_add( - self_tensor, dim, bridge::GetXlaTensor(index), value)); + self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), value)); } else { // TODO: implement scatter_mul return at::native::call_fallback_fn< @@ -3284,8 +3374,9 @@ at::Tensor XLANativeFunctions::scatter_reduce( reduce == "amax") && include_self) { return bridge::AtenFromXlaTensor(tensor_methods::scatter_reduce( - bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index), - bridge::GetXlaTensor(src), reduce, include_self)); + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + GetValueOrThrow(bridge::GetXlaTensor(index)), + GetValueOrThrow(bridge::GetXlaTensor(src)), reduce, include_self)); } else { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(scatter_reduce, two)>::call(self, dim, index, @@ -3297,17 +3388,18 @@ at::Tensor XLANativeFunctions::scatter_reduce( at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim, int64_t index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::select(bridge::GetXlaTensor(self), dim, index)); + return bridge::AtenFromXlaTensor(tensor_methods::select( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim, index)); } at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, int64_t index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto base_tensor = bridge::GetXlaTensor(base); + auto base_tensor = GetValueOrThrow(bridge::GetXlaTensor(base)); auto base_tensor_shape = base_tensor->shape(); - auto mutated_view_tensor = bridge::GetXlaTensor(mutated_view); + auto mutated_view_tensor = + GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); auto mutated_view_tensor_shape = mutated_view_tensor->shape(); auto common_device = torch_xla::bridge::GetXlaDevice(base); @@ -3333,7 +3425,7 @@ at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base, // TODO(JackCaoG): Remove after elu being codegened at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::selu_(self_tensor); return self; } @@ -3341,8 +3433,8 @@ at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { at::Tensor& XLANativeFunctions::set_(at::Tensor& self, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr source_tensor = bridge::GetXlaTensor(source); - bridge::ReplaceXlaTensor(self, source_tensor); + XLATensorPtr source_tensor = GetValueOrThrow(bridge::GetXlaTensor(source)); + MaybeThrow(bridge::ReplaceXlaTensor(self, source_tensor)); return self; } @@ -3350,7 +3442,8 @@ at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::sigmoid_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output))); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(output)))); } at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, @@ -3361,8 +3454,8 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( - tensor_methods::slice(bridge::GetXlaTensor(self), dim, start_val, end_val, - step), + tensor_methods::slice(GetValueOrThrow(bridge::GetXlaTensor(self)), dim, + start_val, end_val, step), self)); } @@ -3370,8 +3463,8 @@ at::Tensor XLANativeFunctions::slice_scatter( const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, std::optional start, std::optional end, int64_t step) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto base_ = bridge::GetXlaTensor(base); - auto mutated_view_ = bridge::GetXlaTensor(mutated_view); + auto base_ = GetValueOrThrow(bridge::GetXlaTensor(base)); + auto mutated_view_ = GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; @@ -3401,8 +3494,8 @@ at::Tensor XLANativeFunctions::smooth_l1_loss(const at::Tensor& self, int64_t reduction, double beta) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::smooth_l1_loss( - bridge::GetXlaTensor(self), bridge::GetXlaTensor(target), reduction, - beta)); + GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(target)), reduction, beta)); } at::Tensor XLANativeFunctions::smooth_l1_loss_backward( @@ -3410,16 +3503,17 @@ at::Tensor XLANativeFunctions::smooth_l1_loss_backward( const at::Tensor& target, int64_t reduction, double beta) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::smooth_l1_loss_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), - bridge::GetXlaTensor(target), reduction, beta)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), + GetValueOrThrow(bridge::GetXlaTensor(target)), reduction, beta)); } at::Tensor XLANativeFunctions::softplus(const at::Tensor& self, const at::Scalar& beta, const at::Scalar& threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::softplus(bridge::GetXlaTensor(self), beta, threshold)); + return bridge::AtenFromXlaTensor(tensor_methods::softplus( + GetValueOrThrow(bridge::GetXlaTensor(self)), beta, threshold)); } at::Tensor XLANativeFunctions::softplus_backward(const at::Tensor& grad_output, @@ -3428,16 +3522,16 @@ at::Tensor XLANativeFunctions::softplus_backward(const at::Tensor& grad_output, const at::Scalar& threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::softplus_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), beta, - threshold)); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), beta, threshold)); } std::tuple XLANativeFunctions::sort( const at::Tensor& self, int64_t dim, bool descending) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = - tensor_methods::topk(bridge::GetXlaTensor(self), self.size(dim), dim, - descending, /*sorted=*/true, /*stable=*/false); + auto results = tensor_methods::topk( + GetValueOrThrow(bridge::GetXlaTensor(self)), self.size(dim), dim, + descending, /*sorted=*/true, /*stable=*/false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -3447,7 +3541,8 @@ std::tuple XLANativeFunctions::sort( bool descending) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto results = tensor_methods::topk( - bridge::GetXlaTensor(self), self.size(dim), dim, descending, + GetValueOrThrow(bridge::GetXlaTensor(self)), self.size(dim), dim, + descending, /*sorted=*/false, /*stable=*/stable.has_value() ? stable.value() : false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), @@ -3458,8 +3553,8 @@ std::vector XLANativeFunctions::split_copy(const at::Tensor& self, int64_t split_size, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto xla_tensors = - tensor_methods::split(bridge::GetXlaTensor(self), split_size, dim); + auto xla_tensors = tensor_methods::split( + GetValueOrThrow(bridge::GetXlaTensor(self)), split_size, dim); return bridge::AtenFromXlaTensors(xla_tensors); } @@ -3467,28 +3562,30 @@ std::vector XLANativeFunctions::split_with_sizes_copy( const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto xla_tensors = tensor_methods::split_with_sizes( - bridge::GetXlaTensor(self), XlaHelpers::I64List(split_sizes), dim); + GetValueOrThrow(bridge::GetXlaTensor(self)), + XlaHelpers::I64List(split_sizes), dim); return bridge::AtenFromXlaTensors(xla_tensors); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::squeeze(bridge::GetXlaTensor(self))); + tensor_methods::squeeze(GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::squeeze(bridge::GetXlaTensor(self), dim)); + return bridge::AtenFromXlaTensor(tensor_methods::squeeze( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim)); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, at::IntArrayRef dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::squeeze( - bridge::GetXlaTensor(self), torch::lazy::ToVector(dim))); + return bridge::AtenFromXlaTensor( + tensor_methods::squeeze(GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(dim))); } at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { @@ -3497,13 +3594,13 @@ at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { std::vector c_tensors(tensors.size()); std::transform(tensors.begin(), tensors.end(), c_tensors.begin(), [=](const at::Tensor& t) { return t.to(result_type); }); - return bridge::AtenFromXlaTensor( - tensor_methods::stack(bridge::GetXlaTensors(c_tensors), dim)); + return bridge::AtenFromXlaTensor(tensor_methods::stack( + GetValueOrThrow(bridge::GetXlaTensors(c_tensors)), dim)); } at::Tensor XLANativeFunctions::std(const at::Tensor& self, bool unbiased) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::std( self_tensor, torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), @@ -3514,7 +3611,7 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::std( self_tensor, dim ? torch::lazy::ToVector(*dim) @@ -3528,7 +3625,7 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::std( self_tensor, dim ? torch::lazy::ToVector(*dim) @@ -3541,7 +3638,7 @@ std::tuple XLANativeFunctions::std_mean( const at::Tensor& self, at::OptionalIntArrayRef dim, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto results = tensor_methods::std_mean( self_tensor, dim ? torch::lazy::ToVector(*dim) @@ -3587,7 +3684,7 @@ at::Tensor XLANativeFunctions::sub(const at::Tensor& self, at::Tensor XLANativeFunctions::sum(const at::Tensor& self, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::sum( self_tensor, torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), @@ -3598,7 +3695,7 @@ at::Tensor XLANativeFunctions::sum(const at::Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::sum( self_tensor, dim ? torch::lazy::ToVector(*dim) @@ -3610,8 +3707,8 @@ at::Tensor XLANativeFunctions::sum(const at::Tensor& self, std::tuple XLANativeFunctions::svd( const at::Tensor& self, bool some, bool compute_uv) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = - tensor_methods::svd(bridge::GetXlaTensor(self), some, compute_uv); + auto results = tensor_methods::svd( + GetValueOrThrow(bridge::GetXlaTensor(self)), some, compute_uv); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results)), bridge::AtenFromXlaTensor(std::get<2>(results))); @@ -3619,23 +3716,25 @@ std::tuple XLANativeFunctions::svd( at::Tensor XLANativeFunctions::t_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::transpose(bridge::GetXlaTensor(self), 0, 1)); + return bridge::AtenFromXlaTensor(tensor_methods::transpose( + GetValueOrThrow(bridge::GetXlaTensor(self)), 0, 1)); } at::Tensor XLANativeFunctions::tanh_backward(const at::Tensor& grad_output, const at::Tensor& output) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::tanh_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(output))); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(output)))); } at::Tensor XLANativeFunctions::threshold(const at::Tensor& self, const at::Scalar& threshold, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::threshold( - bridge::GetXlaTensor(self), threshold.to(), value.to())); + return bridge::AtenFromXlaTensor( + tensor_methods::threshold(GetValueOrThrow(bridge::GetXlaTensor(self)), + threshold.to(), value.to())); } at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output, @@ -3643,15 +3742,16 @@ at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output, const at::Scalar& threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::threshold_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), - threshold.to())); + GetValueOrThrow(bridge::GetXlaTensor(grad_output)), + GetValueOrThrow(bridge::GetXlaTensor(self)), threshold.to())); } std::tuple XLANativeFunctions::topk( const at::Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = tensor_methods::topk(bridge::GetXlaTensor(self), k, dim, - largest, sorted, /*stable=*/false); + auto results = + tensor_methods::topk(GetValueOrThrow(bridge::GetXlaTensor(self)), k, dim, + largest, sorted, /*stable=*/false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -3659,14 +3759,14 @@ std::tuple XLANativeFunctions::topk( at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( - tensor_methods::trace(bridge::GetXlaTensor(self))); + tensor_methods::trace(GetValueOrThrow(bridge::GetXlaTensor(self)))); } at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, int64_t dim0, int64_t dim1) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::transpose(bridge::GetXlaTensor(self), dim0, dim1)); + return bridge::AtenFromXlaTensor(tensor_methods::transpose( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim0, dim1)); } std::tuple XLANativeFunctions::triangular_solve( @@ -3676,8 +3776,9 @@ std::tuple XLANativeFunctions::triangular_solve( // Currently, ATen doesn't have a left_side option. Once this // is added, this API will have to be changed. auto results = tensor_methods::triangular_solve( - bridge::GetXlaTensor(b), bridge::GetXlaTensor(A), /*left_side=*/true, - upper, transpose, unitriangular); + GetValueOrThrow(bridge::GetXlaTensor(b)), + GetValueOrThrow(bridge::GetXlaTensor(A)), /*left_side=*/true, upper, + transpose, unitriangular); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -3686,7 +3787,7 @@ std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensors( - tensor_methods::unbind(bridge::GetXlaTensor(self), dim)); + tensor_methods::unbind(GetValueOrThrow(bridge::GetXlaTensor(self)), dim)); } at::Tensor& XLANativeFunctions::uniform_( @@ -3697,7 +3798,7 @@ at::Tensor& XLANativeFunctions::uniform_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call( self, from, to, generator); } - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::uniform_(self_tensor, from, to); return self; } @@ -3705,15 +3806,15 @@ at::Tensor& XLANativeFunctions::uniform_( at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::unsqueeze(bridge::GetXlaTensor(self), dim)); + return bridge::AtenFromXlaTensor(tensor_methods::unsqueeze( + GetValueOrThrow(bridge::GetXlaTensor(self)), dim)); } at::Tensor XLANativeFunctions::upsample_bilinear2d( const at::Tensor& self, at::IntArrayRef output_size, bool align_corners, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); absl::Span input_dims = self_tensor->shape().get().dimensions(); std::vector scaled_output_size = @@ -3736,7 +3837,8 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d_backward( at::IntArrayRef input_size, bool align_corners, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); + XLATensorPtr grad_output_tensor = + GetValueOrThrow(bridge::GetXlaTensor(grad_output)); // Only the XLA TPU backend for now implements the CustomCall required by // our XLA lowering. XlaDeviceType hw_type = @@ -3768,7 +3870,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d( const at::Tensor& self, at::IntArrayRef output_size, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); absl::Span input_dims = self_tensor->shape().get().dimensions(); std::vector scaled_output_size = @@ -3791,7 +3893,8 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward( at::IntArrayRef input_size, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); + XLATensorPtr grad_output_tensor = + GetValueOrThrow(bridge::GetXlaTensor(grad_output)); // Only the XLA TPU backend for now implements the CustomCall required by // our XLA lowering. XlaDeviceType hw_type = @@ -3825,20 +3928,23 @@ at::Tensor XLANativeFunctions::var(const at::Tensor& self, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - return bridge::AtenFromXlaTensor(tensor_methods::var( - self_tensor, - dim ? XlaHelpers::I64List(*dim) - : torch::lazy::Iota( - bridge::GetXlaTensor(self)->shape().get().dimensions_size()), - correction ? correction->toDouble() : 1.0, keepdim)); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::var(self_tensor, + dim ? XlaHelpers::I64List(*dim) + : torch::lazy::Iota( + GetValueOrThrow(bridge::GetXlaTensor(self)) + ->shape() + .get() + .dimensions_size()), + correction ? correction->toDouble() : 1.0, keepdim)); } std::tuple XLANativeFunctions::var_mean( const at::Tensor& self, at::OptionalIntArrayRef dim, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto results = tensor_methods::var_mean( self_tensor, dim ? torch::lazy::ToVector(*dim) @@ -3859,7 +3965,7 @@ at::Tensor XLANativeFunctions::view_as_complex_copy(const at::Tensor& self) { "tensors, but got a tensor of scalar type: " << self.scalar_type(); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( tensor_methods::view_as_complex_copy(self_tensor)); } @@ -3871,7 +3977,7 @@ at::Tensor XLANativeFunctions::view_as_real_copy(const at::Tensor& self) { "tensors, but got a tensor of scalar type: " << self.scalar_type(); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( tensor_methods::view_as_real_copy(self_tensor)); } @@ -3881,7 +3987,7 @@ at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::optional int_shape = c10::asIntArrayRefSlowOpt(shape); bool input_shape_static = int_shape.has_value(); - XLATensorPtr xla_input = bridge::GetXlaTensor(self); + XLATensorPtr xla_input = GetValueOrThrow(bridge::GetXlaTensor(self)); bool input_has_dyn_shape = xla_input->shape().get().is_dynamic(); XLA_CHECK(!(input_has_dyn_shape && input_shape_static)) @@ -3899,14 +4005,15 @@ at::Tensor XLANativeFunctions::where(const at::Tensor& condition, c10::MaybeOwned b_condition, b_self, b_other; std::tie(b_condition, b_self, b_other) = xla_expand_outplace(condition, self, other, "where"); - return bridge::AtenFromXlaTensor(tensor_methods::where( - bridge::GetXlaTensor(*b_condition), bridge::GetXlaTensor(*b_self), - bridge::GetXlaTensor(*b_other))); + return bridge::AtenFromXlaTensor( + tensor_methods::where(GetValueOrThrow(bridge::GetXlaTensor(*b_condition)), + GetValueOrThrow(bridge::GetXlaTensor(*b_self)), + GetValueOrThrow(bridge::GetXlaTensor(*b_other)))); } at::Tensor& XLANativeFunctions::zero_(at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); tensor_methods::zero_(self_tensor); return self; } @@ -3919,7 +4026,7 @@ std::tuple XLANativeFunctions::_linalg_svd( // As per https://pytorch.org/docs/stable/generated/torch.svd.html, // The second boolean argument is exactly opposite between // torch::svd and torch::_linalg_svd, hence the negation of full_matrices. - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto results = tensor_methods::svd(self_tensor, !full_matrices, compute_uv); auto u = std::get<0>(results); auto s = std::get<1>(results); @@ -3940,7 +4047,7 @@ std::tuple XLANativeFunctions::_linalg_svd( at::Scalar XLANativeFunctions::_local_scalar_dense(const at::Tensor& self) { if (DebugUtil::ExperimentEnabled("early_sync")) { // sync tensors in order to save computation when step is marked later. - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&self_tensor->GetDevice(), /*devices=*/{}, /*wait=*/true); @@ -3983,17 +4090,21 @@ at::Tensor XLANativeFunctions::_cdist_forward( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0"; return bridge::AtenFromXlaTensor(tensor_methods::cdist_forward( - bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p)); + GetValueOrThrow(bridge::GetXlaTensor(x1)), + GetValueOrThrow(bridge::GetXlaTensor(x2)), p)); } at::Tensor XLANativeFunctions::_pdist_forward(const at::Tensor& self, double p) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0"; - XLA_CHECK(bridge::GetXlaTensor(self)->shape().get().dimensions_size() == 2) + XLA_CHECK(GetValueOrThrow(bridge::GetXlaTensor(self)) + ->shape() + .get() + .dimensions_size() == 2) << "pdist only support 2d dimension"; - return bridge::AtenFromXlaTensor( - tensor_methods::pdist_forward(bridge::GetXlaTensor(self), p)); + return bridge::AtenFromXlaTensor(tensor_methods::pdist_forward( + GetValueOrThrow(bridge::GetXlaTensor(self)), p)); } // All of the below ops correspond to CompositeExplicitAutograd kernels from @@ -4077,7 +4188,7 @@ XLANativeFunctions::convolution_backward( at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, std::optional dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr xla_tensor = bridge::GetXlaTensor(self); + XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); std::vector dims; if (dim) { dims = torch::lazy::GetCanonicalDimensionIndices( @@ -4090,7 +4201,7 @@ at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, at::IntArrayRef dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr xla_tensor = bridge::GetXlaTensor(self); + XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); std::vector canonical_dims = torch::lazy::GetCanonicalDimensionIndices( @@ -4129,7 +4240,8 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, // TODO: We need to make use of the TPU embedding core here eventually. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::embedding( - bridge::GetXlaTensor(weight), bridge::GetXlaTensor(indices))); + GetValueOrThrow(bridge::GetXlaTensor(weight)), + GetValueOrThrow(bridge::GetXlaTensor(indices)))); } at::Tensor XLANativeFunctions::_euclidean_dist(const at::Tensor& x1, @@ -4167,7 +4279,7 @@ at::Tensor XLANativeFunctions::narrow_copy_symint(const at::Tensor& self, at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self, int64_t upscale_factor) { return bridge::AtenFromXlaTensor(tensor_methods::pixel_shuffle( - bridge::GetXlaTensor(self), upscale_factor)); + GetValueOrThrow(bridge::GetXlaTensor(self)), upscale_factor)); } at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self, @@ -4252,7 +4364,7 @@ at::Tensor XLANativeFunctions::linalg_vector_norm( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(at::isFloatingType(self.scalar_type())) << "Input must be a floating type"; - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::linalg_vector_norm( self_tensor, ord, dim ? torch::lazy::ToVector(*dim) @@ -4298,7 +4410,7 @@ at::Tensor XLANativeFunctions::as_strided( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, std::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, @@ -4316,7 +4428,7 @@ const at::Tensor& XLANativeFunctions::as_strided_( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, std::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, @@ -4333,8 +4445,8 @@ const at::Tensor& XLANativeFunctions::as_strided_( at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::diagonal(bridge::GetXlaTensor(self), offset, dim1, dim2)); + return bridge::AtenFromXlaTensor(tensor_methods::diagonal( + GetValueOrThrow(bridge::GetXlaTensor(self)), offset, dim1, dim2)); } at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, @@ -4343,13 +4455,14 @@ at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::optional size = c10::asIntArrayRefSlowOpt(sym_size); if (size.has_value()) { - return bridge::AtenFromXlaTensor(tensor_methods::expand( - bridge::GetXlaTensor(self), torch::lazy::ToVector(*size))); + return bridge::AtenFromXlaTensor( + tensor_methods::expand(GetValueOrThrow(bridge::GetXlaTensor(self)), + torch::lazy::ToVector(*size))); } else { // at least one of the dimension is symbolic, use the sym_int version of the // node - return bridge::AtenFromXlaTensor( - tensor_methods::expand_symint(bridge::GetXlaTensor(self), sym_size)); + return bridge::AtenFromXlaTensor(tensor_methods::expand_symint( + GetValueOrThrow(bridge::GetXlaTensor(self)), sym_size)); } } @@ -4361,7 +4474,7 @@ at::Tensor XLANativeFunctions::view_symint(const at::Tensor& self, auto size = C10_AS_INTARRAYREF_SLOW(sym_size); TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor(tensor_methods::view( - bridge::GetXlaTensor(self), XlaHelpers::I64List(size))); + GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(size))); } } // namespace torch_xla diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 56eeac4e6a41..aaa982e6ebd3 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -116,7 +116,7 @@ std::shared_ptr CreateToken( at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp, std::string /*group_name*/) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto self_tensor = bridge::GetXlaTensor(self); + auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); // TODO(alanwaketan): Use group_name to generate groups. Currently we just // use {} as a workaround. Scale is always 1.0 here, and we always pin // layout. @@ -270,7 +270,7 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim, at::Tensor all_gather_into_tensor(const at::Tensor& self, int64_t group_size, std::string group_name) { TORCH_LAZY_FN_COUNTER("xla::"); - auto self_tensor = bridge::GetXlaTensor(self); + auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); std::vector all_groups(group_size); std::iota(all_groups.begin(), all_groups.end(), 0); auto result = tensor_methods::all_gather(self_tensor, 0, group_size, @@ -349,9 +349,9 @@ at::Tensor all_to_all_single(const at::Tensor& input, } XLATensorPtr result_ptr; torch::lazy::Value new_token; - std::tie(result_ptr, new_token) = - tensor_methods::all_to_all(bridge::GetXlaTensor(input), token, 0, 0, - split_count, {all_groups}, pin_layout); + std::tie(result_ptr, new_token) = tensor_methods::all_to_all( + GetValueOrThrow(bridge::GetXlaTensor(input)), token, 0, 0, split_count, + {all_groups}, pin_layout); at::Tensor result = bridge::AtenFromXlaTensor(std::move(result_ptr)); at::Tensor result_with_grad = torch::autograd::make_variable( @@ -481,7 +481,7 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input, at::Tensor reduce_scatter_tensor(const at::Tensor& input, std::string reduce_op, int64_t group_size, std::string group_name) { TORCH_LAZY_FN_COUNTER("xla::"); - auto self = bridge::GetXlaTensor(input); + auto self = GetValueOrThrow(bridge::GetXlaTensor(input)); std::vector all_groups(group_size); std::iota(all_groups.begin(), all_groups.end(), 0); int64_t shard_count = group_size; diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8873fb434e0f..bd4152aee811 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -238,7 +238,7 @@ std::string GetTensorsDump( std::vector nodes; std::vector values; for (auto& tensor : tensors) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); values.push_back(xtensor->GetIrValue()); nodes.push_back(values.back().node.get()); } @@ -274,13 +274,13 @@ std::vector GetXlaTensors(const std::vector& tensors, xtensors.reserve(tensors.size()); if (want_all) { for (auto& tensor : tensors) { - xtensors.push_back(bridge::GetXlaTensor(tensor)); + xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor))); } } else { for (auto& tensor : tensors) { - auto xtensor = bridge::TryGetXlaTensor(tensor); - if (xtensor) { - xtensors.push_back(xtensor); + auto xtensor_status = bridge::GetXlaTensor(tensor); + if (xtensor_status.ok()) { + xtensors.push_back(std::move(xtensor_status).value()); } } } @@ -288,7 +288,7 @@ std::vector GetXlaTensors(const std::vector& tensors, } bool IsNonDeviceDataIR(const at::Tensor& tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return xtensor->CurrentIrValue() && !DeviceData::Cast(xtensor->CurrentIrValue().node.get()); } @@ -316,10 +316,12 @@ std::vector XlaCustomCall( if (is_tpu) { return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call( - bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes)); + GetValueOrThrow(bridge::GetXlaTensors(inputs)), payload, output_shapes, + dtypes)); } return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call( - bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes)); + GetValueOrThrow(bridge::GetXlaTensors(inputs)), payload, output_shapes, + dtypes)); } std::vector> ExtractXlaDotGeneralDimVectors( @@ -371,7 +373,8 @@ at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs, ->scalar_type; } return bridge::AtenFromXlaTensor(tensor_methods::xla_dot_general( - bridge::GetXlaTensor(lhs), bridge::GetXlaTensor(rhs), dim_vectors, + GetValueOrThrow(bridge::GetXlaTensor(lhs)), + GetValueOrThrow(bridge::GetXlaTensor(rhs)), dim_vectors, at_preferred_element_type)); } @@ -398,7 +401,7 @@ void AllReduceInPlace(const std::string& reduce_type, replica_groups, pin_layout); std::vector new_xtensors = GetXlaTensors(tensors, /*want_all=*/true); - bridge::ReplaceXlaTensor(tensors, new_xtensors); + MaybeThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors)); } at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, @@ -406,9 +409,9 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto result = tensor_methods::all_reduce(bridge::GetXlaTensor(input), - GetReduceType(reduce_type), scale, - replica_groups, pin_layout); + auto result = tensor_methods::all_reduce( + GetValueOrThrow(bridge::GetXlaTensor(input)), GetReduceType(reduce_type), + scale, replica_groups, pin_layout); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -417,8 +420,8 @@ at::Tensor DynamicExpand(const at::Tensor& input, const at::Tensor& src_tensor, int src_dim, int target_dim) { XLATensorPtr result = tensor_methods::dynamic_expand( - bridge::GetXlaTensor(input), size, bridge::GetXlaTensor(src_tensor), - src_dim, target_dim); + GetValueOrThrow(bridge::GetXlaTensor(input)), size, + GetValueOrThrow(bridge::GetXlaTensor(src_tensor)), src_dim, target_dim); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -427,15 +430,16 @@ at::Tensor DynamicView(const at::Tensor& input, const at::Tensor& src_tensor, int src_dim, int target_dim, float mul_scaler) { XLATensorPtr result = tensor_methods::dynamic_view( - bridge::GetXlaTensor(input), size, bridge::GetXlaTensor(src_tensor), - src_dim, target_dim, mul_scaler); + GetValueOrThrow(bridge::GetXlaTensor(input)), size, + GetValueOrThrow(bridge::GetXlaTensor(src_tensor)), src_dim, target_dim, + mul_scaler); return bridge::AtenFromXlaTensor(std::move(result)); } at::Tensor CastInt4(const at::Tensor& weight, const std::vector& int4_weight_values) { - auto result = tensor_methods::cast_int4(bridge::GetXlaTensor(weight), - int4_weight_values); + auto result = tensor_methods::cast_int4( + GetValueOrThrow(bridge::GetXlaTensor(weight)), int4_weight_values); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -445,8 +449,8 @@ at::Tensor QuantizeTensor(const at::Tensor& input, int quant_min, int quant_max, const std::string& dtype, int axis) { auto result = tensor_methods::quantize_tensor( - bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min, - quant_max, dtype, axis); + GetValueOrThrow(bridge::GetXlaTensor(input)), scale_list, zero_point_list, + quant_min, quant_max, dtype, axis); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -456,8 +460,8 @@ at::Tensor DequantizeTensor(const at::Tensor& input, int quant_min, int quant_max, const std::string& dtype, int axis) { auto result = tensor_methods::dequantize_tensor( - bridge::GetXlaTensor(input), scale_list, zero_point_list, quant_min, - quant_max, dtype, axis); + GetValueOrThrow(bridge::GetXlaTensor(input)), scale_list, zero_point_list, + quant_min, quant_max, dtype, axis); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -472,9 +476,9 @@ std::pair> ReduceScatter( XLATensorPtr result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::reduce_scatter( - bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type), scale, - scatter_dim, shard_count, replica_groups, pin_layout, channel_id, - use_global_device_ids); + GetValueOrThrow(bridge::GetXlaTensor(input)), *token, + GetReduceType(reduce_type), scale, scatter_dim, shard_count, + replica_groups, pin_layout, channel_id, use_global_device_ids); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -486,11 +490,12 @@ std::shared_ptr ReduceScatterOut( int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out = bridge::GetXlaTensor(output); + XLATensorPtr out = GetValueOrThrow(bridge::GetXlaTensor(output)); torch::lazy::Value new_token; new_token = tensor_methods::reduce_scatter_out( - out, bridge::GetXlaTensor(input), *token, GetReduceType(reduce_type), - scale, scatter_dim, shard_count, replica_groups, pin_layout); + out, GetValueOrThrow(bridge::GetXlaTensor(input)), *token, + GetReduceType(reduce_type), scale, scatter_dim, shard_count, + replica_groups, pin_layout); return std::make_shared(new_token); } @@ -537,8 +542,8 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count, std::optional use_global_device_ids = std::nullopt) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto result = tensor_methods::all_gather( - bridge::GetXlaTensor(input), dim, shard_count, replica_groups, pin_layout, - channel_id, use_global_device_ids); + GetValueOrThrow(bridge::GetXlaTensor(input)), dim, shard_count, + replica_groups, pin_layout, channel_id, use_global_device_ids); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -548,11 +553,11 @@ std::shared_ptr AllGatherOut( int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out = bridge::GetXlaTensor(output); + XLATensorPtr out = GetValueOrThrow(bridge::GetXlaTensor(output)); torch::lazy::Value new_token; - new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input), - *token, dim, shard_count, - replica_groups, pin_layout); + new_token = tensor_methods::all_gather_out( + out, GetValueOrThrow(bridge::GetXlaTensor(input)), *token, dim, + shard_count, replica_groups, pin_layout); return std::make_shared(new_token); } @@ -598,8 +603,8 @@ std::pair> AllToAll( XLATensorPtr result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::all_to_all( - bridge::GetXlaTensor(input), *token, split_dimension, concat_dimension, - split_count, replica_groups, pin_layout); + GetValueOrThrow(bridge::GetXlaTensor(input)), *token, split_dimension, + concat_dimension, split_count, replica_groups, pin_layout); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -611,7 +616,8 @@ std::pair> CollectivePermute( XLATensorPtr result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::collective_permute( - bridge::GetXlaTensor(input), *token, source_target_pairs); + GetValueOrThrow(bridge::GetXlaTensor(input)), *token, + source_target_pairs); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -628,8 +634,8 @@ std::pair> Send( int64_t channel_id) { XLATensorPtr result; torch::lazy::Value new_token; - std::tie(result, new_token) = - tensor_methods::send(bridge::GetXlaTensor(input), *token, channel_id); + std::tie(result, new_token) = tensor_methods::send( + GetValueOrThrow(bridge::GetXlaTensor(input)), *token, channel_id); return {bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)}; } @@ -637,7 +643,7 @@ std::pair> Send( std::pair> Recv( at::Tensor& output, const std::shared_ptr& token, int64_t channel_id) { - XLATensorPtr out = bridge::GetXlaTensor(output); + XLATensorPtr out = GetValueOrThrow(bridge::GetXlaTensor(output)); XLATensorPtr result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::recv(out, *token, channel_id); @@ -713,10 +719,11 @@ std::string GetXLAShardingSpec(const XLATensorPtr xtensor) { } std::string GetXLATensorDebugInfo(const at::Tensor& tensor) { - auto xtensor = bridge::TryGetXlaTensor(tensor); - if (!xtensor) { + auto xtensor_status = bridge::GetXlaTensor(tensor); + if (!xtensor_status.ok()) { return "Not a XLATensor\n"; } + auto xtensor = std::move(xtensor_status).value(); std::stringstream ss; ss << "XLATensor {\n"; ss << "TensorID: " << xtensor->GetUniqueId() << "\n"; @@ -796,12 +803,12 @@ void ClearPendingIrs(const std::string& device_str) { } std::ptrdiff_t GetTensorViewAliasId(const at::Tensor& tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return xtensor->GetViewAliasId(); } std::ptrdiff_t GetTensorId(const at::Tensor& tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return xtensor->GetUniqueId(); } @@ -832,7 +839,7 @@ std::vector GetXlaTensorsFromAten( } at::Tensor GetXlaTensorDimensionSize(const at::Tensor& tensor, int64_t dim) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); return bridge::AtenFromXlaTensor( tensor_methods::get_dimensions_size(xtensor, {dim})); } @@ -908,9 +915,9 @@ runtime::ComputationClient::ComputationPtr CreateComputationFromProto( xla::Shape GetTensorShape(const at::Tensor& tensor, const std::string& device_str) { - auto xtensor = bridge::TryGetXlaTensor(tensor); - if (xtensor) { - return xtensor->shape(); + auto xtensor_status = bridge::GetXlaTensor(tensor); + if (xtensor_status.ok()) { + return xtensor_status.value()->shape(); } torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str); return CreateComputationShapeFromTensor(tensor, &device); @@ -969,8 +976,8 @@ void MapXlaEnvVarsToLazy() { } at::Tensor MarkTensor(const at::Tensor& input, const std::string& info) { - XLATensorPtr result = - tensor_methods::mark_tensor(bridge::GetXlaTensor(input), info); + XLATensorPtr result = tensor_methods::mark_tensor( + GetValueOrThrow(bridge::GetXlaTensor(input)), info); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -1186,7 +1193,8 @@ class PyLoweringContext { local_builder->GetProgramShape()->parameters_size(); int64_t additional_inputs_list_size = additional_inputs_list.size(); for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) { - XLATensorPtr xtensor = bridge::GetXlaTensor(additional_inputs_list[i]); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(additional_inputs_list[i])); xla::Shape shape = xtensor->shape().get(); xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, "UnusedArgumentsPlaceholder"); @@ -1282,7 +1290,7 @@ class PyLoweringContext { // remain parameters. int64_t GetTensorParameterId(at::Tensor tensor) { // Convert tensor into the backing lazy node - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); torch::lazy::Value value = xtensor->GetIrValue(); const torch::lazy::Node* node = value.node.get(); if (node->op() != xla_device_data) { @@ -1588,12 +1596,10 @@ void InitXlaModuleBindings(py::module m) { }) .def("_get_xla_tensor_shape_type", [](const at::Tensor& tensor) -> std::string { - XLATensorPtr xla_tensor = bridge::TryGetXlaTensor(tensor); - if (xla_tensor) { - xla::Shape shape = xla_tensor->shape().get(); - return xla::primitive_util::LowercasePrimitiveTypeName( - shape.element_type()); - } + auto xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + xla::Shape shape = xla_tensor->shape().get(); + return xla::primitive_util::LowercasePrimitiveTypeName( + shape.element_type()); }) .def( "_xla_tensors_from_aten", @@ -1688,7 +1694,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_get_device_hw_type", [](const at::Tensor& tensor) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); XlaDeviceType xla_device_type = static_cast(xtensor->GetDevice().type()); return DeviceType(xla_device_type).toString(); @@ -1823,8 +1830,8 @@ void InitXlaModuleBindings(py::module m) { std::vector> replica_groups = CreateReduceGroups(groups); auto result = tensor_methods::all_reduce( - bridge::GetXlaTensor(input), GetReduceType(reduce_type), scale, - std::move(replica_groups)); + GetValueOrThrow(bridge::GetXlaTensor(input)), + GetReduceType(reduce_type), scale, std::move(replica_groups)); return bridge::AtenFromXlaTensor(std::move(result)); }) .def( @@ -2032,8 +2039,9 @@ void InitXlaModuleBindings(py::module m) { std::vector> replica_groups = CreateReduceGroups(groups); auto result = tensor_methods::reduce_scatter( - bridge::GetXlaTensor(input), GetReduceType(reduce_type), scale, - scatter_dim, shard_count, replica_groups); + GetValueOrThrow(bridge::GetXlaTensor(input)), + GetReduceType(reduce_type), scale, scatter_dim, shard_count, + replica_groups); return bridge::AtenFromXlaTensor(std::move(result)); }) .def("_xla_reduce_scatter", @@ -2471,11 +2479,14 @@ void InitXlaModuleBindings(py::module m) { bool maximize) { { NoGilSection nogil; - XLATensorPtr found_inf_xla = bridge::GetXlaTensor(found_inf); - XLATensorPtr step_xla = bridge::GetXlaTensor(step); - XLATensorPtr param_xla = bridge::GetXlaTensor(param); - XLATensorPtr d_p_xla = bridge::GetXlaTensor(d_p); - XLATensorPtr buf_xla = bridge::GetXlaTensor(buf); + XLATensorPtr found_inf_xla = + GetValueOrThrow(bridge::GetXlaTensor(found_inf)); + XLATensorPtr step_xla = + GetValueOrThrow(bridge::GetXlaTensor(step)); + XLATensorPtr param_xla = + GetValueOrThrow(bridge::GetXlaTensor(param)); + XLATensorPtr d_p_xla = GetValueOrThrow(bridge::GetXlaTensor(d_p)); + XLATensorPtr buf_xla = GetValueOrThrow(bridge::GetXlaTensor(buf)); tensor_methods::sgd_optimizer_step_( found_inf_xla, step_xla, param_xla, buf_xla, d_p_xla, weight_decay, momentum, lr, dampening, nesterov, maximize); @@ -2489,14 +2500,20 @@ void InitXlaModuleBindings(py::module m) { bool use_adamw) { { NoGilSection nogil; - XLATensorPtr found_inf_xla = bridge::GetXlaTensor(found_inf); - XLATensorPtr step_xla = bridge::GetXlaTensor(step); - XLATensorPtr param_xla = bridge::GetXlaTensor(param); - XLATensorPtr grad_xla = bridge::GetXlaTensor(grad); - XLATensorPtr exp_avg_xla = bridge::GetXlaTensor(exp_avg); - XLATensorPtr exp_avg_sq_xla = bridge::GetXlaTensor(exp_avg_sq); + XLATensorPtr found_inf_xla = + GetValueOrThrow(bridge::GetXlaTensor(found_inf)); + XLATensorPtr step_xla = + GetValueOrThrow(bridge::GetXlaTensor(step)); + XLATensorPtr param_xla = + GetValueOrThrow(bridge::GetXlaTensor(param)); + XLATensorPtr grad_xla = + GetValueOrThrow(bridge::GetXlaTensor(grad)); + XLATensorPtr exp_avg_xla = + GetValueOrThrow(bridge::GetXlaTensor(exp_avg)); + XLATensorPtr exp_avg_sq_xla = + GetValueOrThrow(bridge::GetXlaTensor(exp_avg_sq)); XLATensorPtr max_exp_avg_sq_xla = - bridge::GetXlaTensor(max_exp_avg_sq); + GetValueOrThrow(bridge::GetXlaTensor(max_exp_avg_sq)); tensor_methods::adam_optimizer_step_( found_inf_xla, step_xla, param_xla, grad_xla, exp_avg_xla, exp_avg_sq_xla, max_exp_avg_sq_xla, beta1, beta2, lr, @@ -2509,7 +2526,7 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_annotate_custom_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding); }) .def("_mark_manual_sharding", @@ -2521,7 +2538,7 @@ void InitXlaModuleBindings(py::module m) { .def( "_spmd_full_to_shard_shape", [](const at::Tensor& input) -> at::Tensor { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); auto sharding_spec = xtensor->sharding_spec(); XLA_CHECK(sharding_spec != nullptr) << "Input tensor is not sharded"; @@ -2542,7 +2559,7 @@ void InitXlaModuleBindings(py::module m) { [](const at::Tensor& input, const xla::OpSharding& sharding, const std::vector& output_shape, const py::object& output_dtype) -> at::Tensor { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); auto sharding_spec = xtensor->sharding_spec(); XLA_CHECK(sharding_spec != nullptr && sharding_spec->sharding.type() == xla::OpSharding::MANUAL) @@ -2564,17 +2581,17 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_clear_sharding", [](const at::Tensor& input) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); xtensor->ClearShardingSpec(); }) .def("_get_xla_sharding_spec", [](const at::Tensor& input) -> std::string { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); return GetXLAShardingSpec(xtensor); }) .def("_get_xla_op_sharding", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); XLATensor::ShardingSpecPtr sharding_spec = xtensor ? xtensor->sharding_spec() : nullptr; if (sharding_spec != nullptr) { @@ -2591,14 +2608,14 @@ void InitXlaModuleBindings(py::module m) { std::vector sharding_specs; sharding_specs.reserve(tensors.size()); for (const at::Tensor& tensor : tensors) { - sharding_specs.push_back( - GetXLAShardingSpec(bridge::GetXlaTensor(tensor))); + sharding_specs.push_back(GetXLAShardingSpec( + GetValueOrThrow(bridge::GetXlaTensor(tensor)))); } return sharding_specs; }) .def("_get_xla_sharding_type", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); auto sharding_spec = xtensor->sharding_spec(); if (sharding_spec != nullptr) { return ShardingUtil::GetShardingType(sharding_spec->sharding); @@ -2693,7 +2710,8 @@ void InitXlaModuleBindings(py::module m) { std::vector element_types; // Find all shard handles for transfer for (auto& tensor : input) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->GetXlaData() != nullptr) << "Shard data is not available"; XLA_CHECK(xtensor->sharding_spec() != nullptr) @@ -2746,7 +2764,8 @@ void InitXlaModuleBindings(py::module m) { -> std::vector>> { std::vector>> result; for (auto& tensor : input_tensors) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; auto handle = @@ -2803,7 +2822,8 @@ void InitXlaModuleBindings(py::module m) { "_load_local_shards", [](const at::Tensor& tensor, std::vector& shards, std::vector& devices) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Cannot load local shards into a non sharded tensor"; XLA_CHECK(devices.size() == @@ -2880,7 +2900,7 @@ void InitXlaModuleBindings(py::module m) { }) .def("_is_placecholder", [](at::Tensor& input) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); return xtensor->CurrentDataHandle() && !xtensor->CurrentDataHandle()->HasValue(); }) @@ -2967,9 +2987,9 @@ void InitXlaModuleBindings(py::module m) { } auto xtensors = tensor_methods::custom_call( - bridge::GetXlaTensors(inputs), target, output_shapes, dtypes, - has_side_effect, backend_config, api_version, - frontend_attributes); + GetValueOrThrow(bridge::GetXlaTensors(inputs)), target, + output_shapes, dtypes, has_side_effect, backend_config, + api_version, frontend_attributes); return bridge::AtenFromXlaTensors(std::move(xtensors)); }) .def("_xla_tpu_custom_call", @@ -3005,7 +3025,7 @@ void InitXlaModuleBindings(py::module m) { .def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, size_t max_call_stack_depth) -> bool { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); std::shared_ptr user_meta = std::make_shared(op_name_prefix, max_call_stack_depth); @@ -3032,7 +3052,8 @@ void InitXlaModuleBindings(py::module m) { std::vector handles; handles.reserve(tensors.size()); for (auto& tensor : tensors) { - handles.push_back(bridge::GetXlaTensor(tensor)->GetHandle()); + handles.push_back( + GetValueOrThrow(bridge::GetXlaTensor(tensor))->GetHandle()); } return handles; }) @@ -3049,7 +3070,8 @@ void InitXlaModuleBindings(py::module m) { .def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) { TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(input)); xtensor->MarkDynamicDimension(dim); }) .def("_xla_dynamic_expand", @@ -3090,7 +3112,8 @@ void InitXlaModuleBindings(py::module m) { // Note that donated buffers can not be used after being donated. "_set_buffer_donation", [](at::Tensor& tensor, bool should_donate) -> bool { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); bool buffer_donation_updated = false; if (xtensor->CurrentDataHandle() != nullptr) { auto data = @@ -3113,7 +3136,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_get_buffer_donation", [](const at::Tensor& input) -> bool { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(input)); if (!xtensor) { return false; } else if (xtensor->CurrentDataHandle() != nullptr) { @@ -3134,7 +3158,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_on_ready_callback", [](const at::Tensor& tensor, const std::function& callback) { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor) << "The input is not an XLA tensor."; // Wait for placeholder `Data`s to be assigned XLAGraphExecutor::Get()->WaitDeviceOps({}); @@ -3158,8 +3183,7 @@ void InitXlaModuleBindings(py::module m) { }) .def("_unsafe_buffer_pointer", [](const at::Tensor& input) -> std::uintptr_t { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLA_CHECK(xtensor) << "The input is not an XLA tensor."; + auto xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); if (xtensor->CurrentDataHandle() != nullptr) { std::shared_ptr data = std::dynamic_pointer_cast( @@ -3226,9 +3250,10 @@ void InitXlaModuleBindings(py::module m) { -> std::pair, std::vector> { std::vector roots; for (const at::Tensor& tensor : output_tensors) { - auto xtensor = bridge::TryGetXlaTensor(tensor); - if (xtensor) { - roots.push_back(xtensor->GetIrValue().node.get()); + auto xtensor_status = bridge::GetXlaTensor(tensor); + if (xtensor_status.ok()) { + roots.push_back( + xtensor_status.value()->GetIrValue().node.get()); } } @@ -3294,12 +3319,13 @@ void InitXlaModuleBindings(py::module m) { std::vector xtensors; xtensors.reserve(tensors.size()); for (const at::Tensor& tensor : tensors) { - xtensors.push_back(bridge::TryGetXlaTensor(tensor)); + xtensors.push_back( + bridge::GetXlaTensor(tensor).value_or(XLATensorPtr{})); } return check_materialization_helper(xtensors); }) .def( - // Return true if value of the any tensor in this devicerequires a + // Return true if value of the any tensor in this device requires a // computation. "_check_device_tensor_need_materialization", [](const std::string& device_str) -> std::vector { @@ -3309,18 +3335,19 @@ void InitXlaModuleBindings(py::module m) { opt_device ? &opt_device.value() : nullptr); return check_materialization_helper(xtensors); }) - .def("_get_graph_hash", - [](const std::vector& tensors) { - std::vector xtensors; - xtensors.reserve(tensors.size()); - for (auto& tensor : tensors) { - xtensors.push_back(bridge::GetXlaTensor(tensor)); - } - torch::lazy::hash_t hash = - XLAGraphExecutor::Get()->GetGraphHash(xtensors); - std::string bin((const char*)&hash, sizeof(hash)); - return py::bytes(bin); - }) + .def( + "_get_graph_hash", + [](const std::vector& tensors) { + std::vector xtensors; + xtensors.reserve(tensors.size()); + for (auto& tensor : tensors) { + xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor))); + } + torch::lazy::hash_t hash = + XLAGraphExecutor::Get()->GetGraphHash(xtensors); + std::string bin((const char*)&hash, sizeof(hash)); + return py::bytes(bin); + }) .def("_clear_pending_irs", [](const std::string& device) { // Use with caution. Those tensor whole ir was cleared @@ -3332,7 +3359,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_unique_id_for_ir_and_data", [](const at::Tensor& tensor) -> std::string { - XLATensorPtr xtensor = bridge::GetXlaTensor(tensor); + XLATensorPtr xtensor = + GetValueOrThrow(bridge::GetXlaTensor(tensor)); if (xtensor->CurrentIrValue()) { torch::lazy::Value value = xtensor->CurrentIrValue(); return std::to_string((uintptr_t)value.node.get()) + ", " + diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 80d799076048..5916376c1061 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2400,8 +2400,10 @@ std::tuple native_batch_norm( } } else { at::Tensor at_input = bridge::AtenFromXlaTensor(input); - mean = bridge::GetXlaTensor(at::empty({0}, at_input.options())); - variance_inverse = bridge::GetXlaTensor(at::empty({0}, at_input.options())); + mean = GetValueOrThrow( + bridge::GetXlaTensor(at::empty({0}, at_input.options()))); + variance_inverse = GetValueOrThrow( + bridge::GetXlaTensor(at::empty({0}, at_input.options()))); } XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get(); diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 26c669b1e4f8..bf5f7966f8f0 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -1056,7 +1056,7 @@ xla::PrimitiveType GetShapeDimensionType( std::shared_ptr get_data_handle( const at::Tensor& input) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); if (xtensor->CurrentDataHandle() != nullptr) { TF_VLOG(4) << "The xla tensor has a current data handle."; return std::dynamic_pointer_cast( diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 8ea25adcf034..15f38cae2333 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -805,15 +805,17 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( // setup the arguments for (auto& ivalue : graph_inputs) { torch::lazy::BackendDataPtr dataptr; - if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) { + auto xla_tensor_status = bridge::GetXlaTensor(ivalue.toTensor()); + if (xla_tensor_status.ok()) { + auto xla_tensor = std::move(xla_tensor_status).value(); bool is_non_data_ir = - xla_tensor_ptr->CurrentIrValue().node != nullptr && + xla_tensor->CurrentIrValue().node != nullptr && (torch_xla::DeviceData::Cast( - xla_tensor_ptr->CurrentIrValue().node.get()) == nullptr); + xla_tensor->CurrentIrValue().node.get()) == nullptr); XLA_CHECK(!is_non_data_ir) << "input data to dynamo graph can not be a pending ir, please set " "`torch_xla._dynamo.config.skip_input_data_check` to False"; - dataptr = xla_tensor_ptr->GetXlaData(); + dataptr = xla_tensor->GetXlaData(); } else { XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD) << "SPMD device data should already be on the XLA backend " @@ -934,8 +936,9 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( // setup the arguments for (auto& ivalue : graph_inputs) { torch::lazy::BackendDataPtr dataptr; - if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) { - dataptr = xla_tensor_ptr->GetXlaData(); + auto xla_tensor_status = bridge::GetXlaTensor(ivalue.toTensor()); + if (xla_tensor_status.ok()) { + dataptr = xla_tensor_status.value()->GetXlaData(); } else { dataptr = torch_xla::TensorToXlaData(ivalue.toTensor(), device); } diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp index 79db2b2307ce..4a5beea20988 100644 --- a/torch_xla/csrc/xla_manual_registration.cpp +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -37,8 +37,8 @@ at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, XLA_CHECK_EQ(boxes.size(0), scores.size(0)) << "nms(): boxes and scores should have the same size for dimension 0."; - XLATensorPtr xla_boxes = bridge::GetXlaTensor(boxes); - XLATensorPtr xla_scores = bridge::GetXlaTensor(scores); + XLATensorPtr xla_boxes = GetValueOrThrow(bridge::GetXlaTensor(boxes)); + XLATensorPtr xla_scores = GetValueOrThrow(bridge::GetXlaTensor(scores)); return bridge::AtenFromXlaTensor( tensor_methods::nms(xla_boxes, xla_scores, iou_threshold), /*skip_functionalization=*/true); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 3ddec53d7004..abc18206d420 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -767,7 +767,7 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input, << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN) << "Can't explicilty annotate with UNKNOWN sharding type."; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); // For Non DeviceData IR values, we directly attach the sharding spec to the // xtensor. From 7aa466e2310072c4b1175be640a4c5898299440f Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 29 Jul 2025 20:15:39 -0300 Subject: [PATCH 019/133] Dump C++ and Status propagation stacktraces. (#9492) --- test/cpp/test_status_common.h | 262 +++++++++++++++++++++++++--------- torch_xla/csrc/BUILD | 1 + torch_xla/csrc/status.cpp | 121 +++++++++++----- torch_xla/csrc/status.h | 89 ++++++++---- 4 files changed, 345 insertions(+), 128 deletions(-) diff --git a/test/cpp/test_status_common.h b/test/cpp/test_status_common.h index 7cb63d4f38a7..4d4b173f6431 100644 --- a/test/cpp/test_status_common.h +++ b/test/cpp/test_status_common.h @@ -18,9 +18,12 @@ #ifndef XLA_TEST_CPP_TEST_STATUS_COMMON_H_ #define XLA_TEST_CPP_TEST_STATUS_COMMON_H_ +#include +#include #include #include +#include #include #include "absl/status/status.h" @@ -30,7 +33,7 @@ namespace torch_xla { -// Enum to control whether C++ error context is shown in status messages +// Enum to control whether C++ error context is shown in status messages. enum class CppStacktracesMode { kShow, kHide, @@ -74,10 +77,45 @@ class StatusTest : public testing::TestWithParam { namespace testing { -constexpr inline char new_message[] = "New test error message"; -constexpr inline char message[] = "Test error message"; -constexpr inline char test_file[] = "test_file.cpp"; -constexpr inline int32_t line = 42; +constexpr inline char kNewMessage[] = "New test error message"; +constexpr inline char kMessage[] = "Test error message"; +constexpr inline char kFile[] = "test_file.cpp"; +constexpr inline char kFunction[] = "foo"; +constexpr inline char kEntryPrefix[] = "\n "; +constexpr inline int32_t kLine = 42; + +// The PyTorch C++ stacktrace is ALWAYS appended to the error message. +// More specifically, when `what()` function is called. +// +// However, it's only when the raised `c10::Error` gets translated to a +// Python exception that PyTorch checks the value of the +// `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually +// controls whether the stacktrace will get shown or not by calling +// `what_without_backtraces()`, instead. +// +// Therefore, we need to mimic this behavior. +#define THROW_RUNTIME_ERROR_FROM_C10_ERROR(block) \ + try { \ + block; \ + } catch (const c10::Error& error) { \ + throw std::runtime_error(IsShowCppStacktracesMode() \ + ? error.what() \ + : error.what_without_backtrace()); \ + } + +// Prefix of the C++ stacktrace PyTorch adds to the error message. +constexpr inline char kTorchCppStacktracePrefix[] = + "Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:"; + +inline std::string GetStatusPropagationTrace(const absl::Status& status) { + if (status.ok()) { + return ""; + } + auto status_propagation_trace = status.GetPayload(kStatusPropagationTraceKey); + return status_propagation_trace.has_value() + ? std::string(status_propagation_trace->Flatten()) + : ""; +} TEST_P(StatusTest, MaybeThrowWithOkStatus) { absl::Status ok_status = absl::OkStatus(); @@ -85,8 +123,22 @@ TEST_P(StatusTest, MaybeThrowWithOkStatus) { } TEST_P(StatusTest, MaybeThrowWithErrorStatus) { - absl::Status error_status = absl::InvalidArgumentError(message); - EXPECT_THROW(MaybeThrow(error_status), std::runtime_error); + auto throw_exception = [=]() { + THROW_RUNTIME_ERROR_FROM_C10_ERROR({ + absl::Status error_status = absl::InvalidArgumentError(kMessage); + MaybeThrow(error_status); + }); + }; + + if (IsShowCppStacktracesMode()) { + std::string expected_prefix = + absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix); + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::StartsWith(expected_prefix))); + } else { + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::Eq(kMessage))); + } } TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) { @@ -97,44 +149,75 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) { } TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) { - absl::StatusOr status_or = absl::InvalidArgumentError(message); - EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error); + auto throw_exception = [=]() { + THROW_RUNTIME_ERROR_FROM_C10_ERROR({ + absl::StatusOr error_status = absl::InvalidArgumentError(kMessage); + int value = GetValueOrThrow(error_status); + }); + }; + if (IsShowCppStacktracesMode()) { + std::string expected_prefix = + absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix); + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::StartsWith(expected_prefix))); + } else { + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::Eq(kMessage))); + } } TEST_P(StatusTest, MaybeWithLocationPropagatesErrorStatus) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = MaybeWithLocation(error_status, test_file, line); + absl::Status error_status = absl::InvalidArgumentError(kMessage); + absl::Status result = + status_internal::MaybeWithLocation(error_status, kFile, kLine, kFunction); + + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), error_status.code()); + EXPECT_EQ(result.message(), error_status.message()); + if (IsShowCppStacktracesMode()) { - ASSERT_NE(result, error_status); - EXPECT_FALSE(result.ok()); - EXPECT_EQ(result.code(), error_status.code()); - EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)"); + EXPECT_NE(result, error_status); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile, + ":", kLine, " (error: ", kMessage, ")")); } else { EXPECT_EQ(result, error_status); } } TEST_P(StatusTest, MaybeWithNewMessageEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = MaybeWithNewMessage(error_status, test_file, line); - EXPECT_EQ(result, error_status); + absl::Status error_status = absl::InvalidArgumentError(kMessage); + absl::Status result = status_internal::MaybeWithNewMessage( + error_status, kFile, kLine, kFunction); + + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), error_status.code()); + EXPECT_EQ(result.message(), error_status.message()); + + if (IsShowCppStacktracesMode()) { + EXPECT_NE(result, error_status); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile, + ":", kLine)); + } else { + EXPECT_EQ(result, error_status); + } } TEST_P(StatusTest, MaybeWithNewMessageNonEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = - MaybeWithNewMessage(error_status, test_file, line, new_message); + absl::Status error_status = absl::InvalidArgumentError(kMessage); + absl::Status result = status_internal::MaybeWithNewMessage( + error_status, kFile, kLine, kFunction, kNewMessage); - ASSERT_NE(result, error_status); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), error_status.code()); + EXPECT_EQ(result.message(), std::string_view(kNewMessage)); + EXPECT_NE(result, error_status); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), - absl::StrCat("New test error message (at test_file.cpp:42)\n" - "From Error: Test error message")); - } else { - EXPECT_EQ(result.message(), std::string_view(new_message)); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", kFunction, " at ", kFile, + ":", kLine, " (error: ", kNewMessage, ")")); } } @@ -154,7 +237,7 @@ TEST_P(StatusTest, MacroReturnIfError) { TEST_P(StatusTest, MacroReturnIfErrorWithError) { auto test_function = [=]() -> absl::Status { - absl::Status error_status = absl::InvalidArgumentError(message); + absl::Status error_status = absl::InvalidArgumentError(kMessage); XLA_RETURN_IF_ERROR(error_status); return absl::OkStatus(); }; @@ -162,21 +245,22 @@ TEST_P(StatusTest, MacroReturnIfErrorWithError) { absl::Status result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.message(), std::string_view(message)); + EXPECT_EQ(result.message(), std::string_view(kMessage)); } TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) { - int32_t errline = 0; - auto inner_test_function = [&errline]() -> absl::Status { - errline = __LINE__ + 1; - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message)); + int32_t errline0 = __LINE__ + 2; + auto inner_test_function = []() -> absl::Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); }; + int32_t errline1 = __LINE__ + 2; auto test_function = [&]() -> absl::Status { XLA_RETURN_IF_ERROR(inner_test_function()); return absl::OkStatus(); }; + int32_t errline2 = __LINE__ + 2; auto outer_test_function = [&]() -> absl::Status { XLA_RETURN_IF_ERROR(test_function()); return absl::OkStatus(); @@ -185,34 +269,37 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) { absl::Status result = outer_test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(kMessage)); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", - __FILE__, ":", errline, ")")); - } else { - EXPECT_EQ(result.message(), std::string_view(message)); + auto frame0 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, + ":", errline0, " (error: ", kMessage, ")"); + auto frame1 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, + ":", errline1); + auto frame2 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, + ":", errline2); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(frame0, frame1, frame2)); } } TEST_P(StatusTest, MacroReturnIfErrorWithErrorWithNewMessage) { - int32_t errline = 0; - auto test_function = [&errline]() -> absl::Status { - absl::Status error_status = absl::InvalidArgumentError(message); - errline = __LINE__ + 1; - XLA_RETURN_IF_ERROR(error_status, new_message); + int32_t errline = __LINE__ + 3; + auto test_function = []() -> absl::Status { + absl::Status error_status = absl::InvalidArgumentError(kMessage); + XLA_RETURN_IF_ERROR(error_status, kNewMessage); return absl::OkStatus(); }; absl::Status result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(kNewMessage)); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), - absl::StrCat("New test error message (at ", __FILE__, ":", - errline, ")\nFrom Error: Test error message")); - } else { - EXPECT_EQ(result.message(), std::string_view(new_message)); + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":", + errline, " (error: ", kNewMessage, ")")); } } @@ -233,7 +320,7 @@ TEST_P(StatusTest, MacroAssignOrReturn) { TEST_P(StatusTest, MacroAssignOrReturnWithError) { auto test_function = []() -> absl::StatusOr { - absl::StatusOr status_or = absl::InvalidArgumentError(message); + absl::StatusOr status_or = absl::InvalidArgumentError(kMessage); XLA_ASSIGN_OR_RETURN(int value, status_or); return value * 2; }; @@ -241,43 +328,90 @@ TEST_P(StatusTest, MacroAssignOrReturnWithError) { absl::StatusOr result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.status().message(), std::string_view(message)); + EXPECT_EQ(result.status().message(), std::string_view(kMessage)); } TEST_P(StatusTest, MacroAssignOrReturnWithErrorWithNewMessage) { - int32_t errline = 0; - - auto test_function = [&errline]() -> absl::StatusOr { - absl::StatusOr status_or = absl::InvalidArgumentError(message); - errline = __LINE__ + 1; - XLA_ASSIGN_OR_RETURN(int value, status_or, new_message); + int32_t errline = __LINE__ + 3; + auto test_function = []() -> absl::StatusOr { + absl::StatusOr status_or = absl::InvalidArgumentError(kMessage); + XLA_ASSIGN_OR_RETURN(int value, status_or, kNewMessage); return value * 2; }; absl::StatusOr result = test_function(); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), std::string_view(kNewMessage)); if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.status().message(), - absl::StrCat("New test error message (at ", __FILE__, ":", - errline, ")\nFrom Error: Test error message")); - } else { - EXPECT_EQ(result.status().message(), std::string_view(new_message)); + EXPECT_EQ(GetStatusPropagationTrace(result.status()), + absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, ":", + errline, " (error: ", kNewMessage, ")")); } } TEST_P(StatusTest, MacroErrorWithLocation) { - absl::Status error_status = absl::InvalidArgumentError(message); + absl::Status error_status = absl::InvalidArgumentError(kMessage); int32_t errline = __LINE__ + 1; absl::Status result = XLA_ERROR_WITH_LOCATION(error_status); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(kMessage)); + if (IsShowCppStacktracesMode()) { + EXPECT_EQ(GetStatusPropagationTrace(result), + absl::StrCat(kEntryPrefix, "From: ", __FUNCTION__, " at ", + __FILE__, ":", errline, " (error: ", kMessage, ")")); + } +} + +TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) { + int32_t errline0 = __LINE__ + 2; + auto innerfn = [&]() -> absl::Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); + }; + + int32_t errline1 = __LINE__ + 2; + auto midfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(innerfn(), kNewMessage); + return absl::OkStatus(); + }; + + int32_t errline2 = __LINE__ + 2; + auto outerfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(midfn()); + return absl::OkStatus(); + }; + + auto throw_exception = [&]() { + THROW_RUNTIME_ERROR_FROM_C10_ERROR(MaybeThrow(outerfn())); + }; + if (IsShowCppStacktracesMode()) { - EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", - __FILE__, ":", errline, ")")); + // Expected Error Message Prefix + // ============================= + // + // New test error kMessage + // + // Status Propagation Stacktrace: + // From: ./test/cpp/test_status_common.h:329 (error: Test error + // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test + // error kMessage) From: ./test/cpp/test_status_common.h:342 + // + // C++ Stacktrace: + // + std::string expected_prefix = absl::StrCat( + kNewMessage, "\n\nStatus Propagation Trace:", kEntryPrefix, + "From: operator() at ", __FILE__, ":", errline0, " (error: ", kMessage, + ")", kEntryPrefix, "From: operator() at ", __FILE__, ":", errline1, + " (error: ", kNewMessage, ")", kEntryPrefix, "From: operator() at ", + __FILE__, ":", errline2, "\n\n", kTorchCppStacktracePrefix); + + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::StartsWith(expected_prefix))); } else { - EXPECT_EQ(result.message(), std::string_view(message)); + EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( + ::testing::Eq(kNewMessage))); } } diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 6c34eca14502..31ab65dbbcaf 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -377,6 +377,7 @@ cc_library( hdrs = ["status.h"], deps = [ "@torch//:headers", + "@tsl//tsl/platform:stacktrace", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", ], diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 1eb3511cb33e..270f34878675 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -1,19 +1,49 @@ #include "torch_xla/csrc/status.h" +#include #include +#include +#include +#include + #include "absl/log/absl_check.h" +#include "tsl/platform/stacktrace.h" namespace torch_xla { -// Common function for generating file location information with a space in the -// beginning. -static std::string LocationStrWithSpace(const char* file, const int32_t line) { - return absl::StrCat(" (at ", file, ":", line, ")"); +// Indent the stack frame representation so that it's easier to see. +constexpr char kFramePrefix[] = "\n "; + +// Creates the stack frame representation for the status propagation trace +// entry. +// +// The resulting string will be appended to the existing status propagation +// trace of the status currently being processed. +// +// Example: +// \n From: at : [(error: )] +// +static std::string GetStackFrame(const char* file, const int32_t line, + const char* function, + const std::string_view new_message) { + auto error_suffix = + new_message.empty() ? "" : absl::StrCat(" (error: ", new_message, ")"); + return absl::StrCat(kFramePrefix, "From: ", function, " at ", file, ":", line, + error_suffix); +} + +// Convenient function that retrieves the status propagation trace payload +// if it exists. Otherwise, returns an empty absl::Cord. +static absl::Cord GetStatusPropagationTraceOrEmpty(const absl::Status& status) { + auto opt = status.GetPayload(kStatusPropagationTraceKey); + return opt.has_value() ? *opt : absl::Cord(); } -absl::Status MaybeWithLocation(const absl::Status& status, const char* file, - const int32_t line) { +absl::Status status_internal::MaybeWithLocation(const absl::Status& status, + const char* file, + const int32_t line, + const char* function) { ABSL_CHECK(!status.ok()); // Return the same status if we don't need to add the C++ source location. @@ -21,14 +51,19 @@ absl::Status MaybeWithLocation(const absl::Status& status, const char* file, return status; } - return absl::Status( - status.code(), - absl::StrCat(status.message(), LocationStrWithSpace(file, line))); + // Make sure this is only called on fresh `status` instances. + ABSL_CHECK(GetStatusPropagationTraceOrEmpty(status).empty()); + + // Adding source location to `status` has the same semantics as overwriting + // the status message: + // 1. An stack frame will be added to the status propagation trace + // 2. The status' message will be the same + return MaybeWithNewMessage(status, file, line, function, status.message()); } -absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, - const int32_t line, - const std::string_view new_message) { +absl::Status status_internal::MaybeWithNewMessage( + const absl::Status& status, const char* file, const int32_t line, + const char* function, const std::string_view new_message) { ABSL_CHECK(!status.ok()); // Return the same status if: @@ -38,39 +73,55 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, return status; } - std::string_view old_message = status.message(); - // Replace the old status message with `new_message`, if it's not empty. // // The idea is that whenever `new_message` is given, it should have more // context to give a better error message to the user. - std::string_view message = new_message.empty() ? old_message : new_message; + auto new_status = absl::Status( + status.code(), new_message.empty() ? status.message() : new_message); - // If `TORCH_SHOW_CPP_STACKTRACES` is set, show the context of this error. - // In other words, show: - // 1. The error location - // 2. The old messages that were replaced by `new_message`. + // If `TORCH_SHOW_CPP_STACKTRACES` is set: // - // This should give more context for developers. Showing the older error - // messages alongside their debug information. + // 1. append the current stack frame to the status propagation trace + // payload // - // Note that we also condition showing source location information by (2) - // (i.e. `new_message` is not empty) because we don't really wish to show - // a stacktrace. Instead, we show only the history of error messages that - // has led to the current error. - const std::string context = - (torch::get_cpp_stacktraces_enabled() && !new_message.empty()) - ? absl::StrCat(LocationStrWithSpace(file, line), - "\nFrom Error: ", old_message) - : ""; - - return absl::Status(status.code(), absl::StrCat(message, context)); + // 2. append the new error message, if not empty + if (torch::get_cpp_stacktraces_enabled()) { + auto status_propagation_trace = GetStatusPropagationTraceOrEmpty(status); + status_propagation_trace.Append( + GetStackFrame(file, line, function, new_message)); + new_status.SetPayload(kStatusPropagationTraceKey, status_propagation_trace); + } + + return new_status; +} + +// Get a formatted string representation of the status propagation trace +// if it's not empty. +static std::string GetFormattedStatusPropagationTrace( + const absl::Status& status) { + auto status_propagation_trace = GetStatusPropagationTraceOrEmpty(status); + return status_propagation_trace.empty() + ? "" + : absl::StrCat("\nStatus Propagation Trace:", + status_propagation_trace.Flatten(), "\n"); +} + +// Get the status message followed by a line break, if we are printing the +// C++ stacktraces. +// +// This is needed so we have a blank line in between the status message and +// the dumped C++ traces (either the status propagation one, or the C++ +// stacktrace). +static std::string MaybeGetMessageWithLineBreak(const absl::Status& status) { + return torch::get_cpp_stacktraces_enabled() + ? absl::StrCat(status.message(), "\n") + : std::string(status.message()); } void MaybeThrow(const absl::Status& status) { - if (!status.ok()) { - throw std::runtime_error(std::string(status.message())); - } + TORCH_CHECK(status.ok(), MaybeGetMessageWithLineBreak(status), + GetFormattedStatusPropagationTrace(status)); } } // namespace torch_xla diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index d64a78ba58de..2f53b37381fb 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -14,6 +14,24 @@ namespace torch_xla { +// `type_url` for retrieving the status propagation trace payload of a given +// status. +// +// The payload is composed of multiple lines, where each line represents a stack +// frame in the status propagation trace. Each line is in the following format: +// +// \n From: :[ErrorSuffix] +// | ---- | +// | | |_ error message produced in that source +// | | location (it might be overwritten later). +// | | +// | |_ leading 4 spaces for improved readability. +// | +// |_ start with a line break. +// +constexpr char kStatusPropagationTraceKey[] = + "type.googleapis.com/torch_xla.status_trace"; + // If `TORCH_SHOW_CPP_STACKTRACES` is set, creates a new Status instance, // appending the current location (e.g. file and line information) to the // status message. @@ -28,10 +46,12 @@ namespace torch_xla { // // If `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be: // -// Error message. (at :) +// RuntimeError: Error message. +// From: : (error: Error message.) // -#define XLA_ERROR_WITH_LOCATION(status) \ - ::torch_xla::MaybeWithLocation(status, __FILE__, __LINE__) +#define XLA_ERROR_WITH_LOCATION(status) \ + ::torch_xla::status_internal::MaybeWithLocation(status, __FILE__, __LINE__, \ + __FUNCTION__) #define XLA_CONCAT_(a, b) XLA_CONCAT_IMPL_(a, b) #define XLA_CONCAT_IMPL_(a, b) a##b @@ -41,15 +61,15 @@ namespace torch_xla { // Provides a flexible way to handle error checking with optional message // modification. It evaluates `expr`, checks if it's OK, and either: -// 1. Returns early with an error status (potentially modified by the provided -// additional messages) -// 2. Proceeds with the given `then` block if successful -#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \ - auto var = (expr); \ - if (!var.ok()) { \ - return ::torch_xla::MaybeWithNewMessage( \ - ::torch_xla::GetStatus(var), __FILE__, __LINE__, ##__VA_ARGS__); \ - } \ +// 1. Returns early with an error status +// 2. Proceeds with the given `then` block if successful +#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \ + auto var = (expr); \ + if (!var.ok()) { \ + return ::torch_xla::status_internal::MaybeWithNewMessage( \ + ::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \ + __FUNCTION__, ##__VA_ARGS__); \ + } \ then // Propagates `rexpr`, in case it's a non-ok status. @@ -65,9 +85,13 @@ namespace torch_xla { // we early return a non-ok status. Then, if `TORCH_SHOW_CPP_STACKTRACES` is // set, the error shown will be: // -// New error message. (at :) -// Previous error message. (at :) -// ... +// RuntimeError: New error message. +// +// Status Propagation Stacktrace: +// ... +// From: : (error: Previous error message.) +// ... +// From: : (error: New error message.) // #define XLA_RETURN_IF_ERROR(rexpr, ...) \ do { \ @@ -93,26 +117,29 @@ namespace torch_xla { // If the function call results in an ok status, execution continues with // `result` set to `ret.value()`, where `ret` is the returned value of the // function. Otherwise, we early return a non-ok status. Then, if -// `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be: -// -// New error message. (at :) -// Previous error message. (at :) -// ... +// `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be similar to +// the one above. // #define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \ XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, \ lhs = std::move(XLA_STATUS_VAR_).value(), \ ##__VA_ARGS__) -// Maybe shows location information in the status message. +namespace status_internal { + +// Adds source location information to the status propagation trace if +// `TORCH_SHOW_CPP_STACKTRACES` is set. // -// This function assumes that `status` is a non-ok status. +// This function assumes that: +// +// 1. `status` is a non-ok status. +// 2. `status` doesn't have a status propagation trace payload +// +// If any of the above assumptions is false, this function crashes the +// whole program. // -// If `TORCH_SHOW_CPP_STACKTRACES` is set, appends the current source -// location information to the status message. Otherwise, it simply returns -// `status`. absl::Status MaybeWithLocation(const absl::Status& status, const char* file, - int32_t line); + int32_t line, const char* function); // Returns an `absl::Status` from an `absl::Status`. // In this case, this function is a no-op. It simply returns the argument. @@ -126,7 +153,8 @@ const absl::Status& GetStatus(const absl::StatusOr& status) { return status.status(); } -// Maybe replace the current `status` message with `new_message`. +// Maybe replace the current `status` message with `new_message`, and also +// add source location information if enabled. // // This function assumes that `status` is a non-ok status. // @@ -137,12 +165,15 @@ const absl::Status& GetStatus(const absl::StatusOr& status) { // Rationale: if given, `new_message` has more context, which makes it possible // to construct better error messages to the user. // -// This function also appends file location information to the error message, if +// This function also appends the source location information to the status +// propagation trace payload (creates a new one if needed), if // `TORCH_SHOW_CPP_STACKTRACES` is set. absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, - int32_t line, + int32_t line, const char* function, std::string_view new_message = ""); +} // namespace status_internal + // Maybe throws an exception if `status` has a non-ok code. // // Ideally, this function should be used only used in the project's From 199a9bd5e1fe9c4ad034369240a9bd7cec19962a Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Tue, 29 Jul 2025 21:02:15 -0700 Subject: [PATCH 020/133] Add w8a8 kernel blocks for Qwen 2.5 7B (#9517) --- .../pallas_kernels/quantized_matmul_kernel.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py index d7356be54be6..b4bd0c081f07 100644 --- a/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py +++ b/torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py @@ -328,8 +328,12 @@ def quantized_matmul_int8( (6, 1024, 13824, 5120, 'bfloat16', True): (1024, 768, 5120), (6, 1024, 1792, 5120, 'bfloat16', True): (1024, 256, 5120), (6, 1024, 28672, 4096, 'bfloat16', True): (1024, 2048, 4096), + (6, 1024, 3584, 18944, 'bfloat16', True): (1024, 3584, 512), + (6, 1024, 3584, 3584, 'bfloat16', True): (1024, 512, 3584), + (6, 1024, 37888, 3584, 'bfloat16', True): (1024, 1024, 3584), (6, 1024, 4096, 14336, 'bfloat16', True): (1024, 256, 14336), (6, 1024, 4096, 4096, 'bfloat16', True): (1024, 512, 4096), + (6, 1024, 4608, 3584, 'bfloat16', True): (1024, 768, 3584), (6, 1024, 5120, 1280, 'bfloat16', True): (1024, 1280, 1280), (6, 1024, 5120, 3456, 'bfloat16', True): (1024, 1024, 3456), (6, 1024, 5120, 640, 'bfloat16', True): (256, 5120, 640), @@ -344,8 +348,12 @@ def quantized_matmul_int8( (6, 128, 13824, 5120, 'bfloat16', True): (128, 512, 5120), (6, 128, 1792, 5120, 'bfloat16', True): (128, 1792, 1280), (6, 128, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 128, 3584, 18944, 'bfloat16', True): (128, 256, 18944), + (6, 128, 3584, 3584, 'bfloat16', True): (128, 3584, 896), + (6, 128, 37888, 3584, 'bfloat16', True): (128, 1024, 3584), (6, 128, 4096, 14336, 'bfloat16', True): (128, 4096, 896), (6, 128, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 128, 4608, 3584, 'bfloat16', True): (128, 768, 3584), (6, 128, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), (6, 128, 5120, 3456, 'bfloat16', True): (128, 640, 3456), (6, 128, 5120, 640, 'bfloat16', True): (128, 2560, 640), @@ -360,8 +368,12 @@ def quantized_matmul_int8( (6, 16, 13824, 5120, 'bfloat16', True): (128, 512, 5120), (6, 16, 1792, 5120, 'bfloat16', True): (128, 896, 2560), (6, 16, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 16, 3584, 18944, 'bfloat16', True): (128, 256, 18944), + (6, 16, 3584, 3584, 'bfloat16', True): (128, 896, 3584), + (6, 16, 37888, 3584, 'bfloat16', True): (128, 1024, 3584), (6, 16, 4096, 14336, 'bfloat16', True): (128, 4096, 896), (6, 16, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 16, 4608, 3584, 'bfloat16', True): (128, 768, 3584), (6, 16, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), (6, 16, 5120, 3456, 'bfloat16', True): (128, 640, 3456), (6, 16, 5120, 640, 'bfloat16', True): (128, 2560, 640), @@ -374,6 +386,10 @@ def quantized_matmul_int8( (6, 16, 896, 5120, 'bfloat16', True): (128, 896, 2560), (6, 16384, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120), (6, 16384, 1792, 5120, 'bfloat16', True): (1024, 1792, 5120), + (6, 16384, 3584, 18944, 'bfloat16', True): (256, 3584, 18944), + (6, 16384, 3584, 3584, 'bfloat16', True): (512, 3584, 3584), + (6, 16384, 37888, 3584, 'bfloat16', True): (4096, 512, 3584), + (6, 16384, 4608, 3584, 'bfloat16', True): (512, 4608, 3584), (6, 16384, 5120, 1280, 'bfloat16', True): (512, 5120, 1280), (6, 16384, 5120, 3456, 'bfloat16', True): (512, 5120, 3456), (6, 16384, 5120, 640, 'bfloat16', True): (512, 5120, 640), @@ -384,8 +400,12 @@ def quantized_matmul_int8( (6, 2048, 13824, 5120, 'bfloat16', True): (2048, 768, 5120), (6, 2048, 1792, 5120, 'bfloat16', True): (2048, 256, 5120), (6, 2048, 28672, 4096, 'bfloat16', True): (2048, 1024, 4096), + (6, 2048, 3584, 18944, 'bfloat16', True): (2048, 3584, 512), + (6, 2048, 3584, 3584, 'bfloat16', True): (2048, 512, 3584), + (6, 2048, 37888, 3584, 'bfloat16', True): (2048, 1024, 3584), (6, 2048, 4096, 14336, 'bfloat16', True): (2048, 4096, 512), (6, 2048, 4096, 4096, 'bfloat16', True): (2048, 512, 4096), + (6, 2048, 4608, 3584, 'bfloat16', True): (2048, 512, 3584), (6, 2048, 5120, 1280, 'bfloat16', True): (256, 5120, 1280), (6, 2048, 5120, 3456, 'bfloat16', True): (2048, 512, 3456), (6, 2048, 5120, 640, 'bfloat16', True): (256, 5120, 640), @@ -400,8 +420,12 @@ def quantized_matmul_int8( (6, 256, 13824, 5120, 'bfloat16', True): (256, 512, 5120), (6, 256, 1792, 5120, 'bfloat16', True): (256, 1792, 1280), (6, 256, 28672, 4096, 'bfloat16', True): (256, 2048, 4096), + (6, 256, 3584, 18944, 'bfloat16', True): (256, 256, 18944), + (6, 256, 3584, 3584, 'bfloat16', True): (256, 896, 3584), + (6, 256, 37888, 3584, 'bfloat16', True): (256, 4736, 896), (6, 256, 4096, 14336, 'bfloat16', True): (256, 4096, 512), (6, 256, 4096, 4096, 'bfloat16', True): (256, 512, 4096), + (6, 256, 4608, 3584, 'bfloat16', True): (256, 768, 3584), (6, 256, 5120, 1280, 'bfloat16', True): (256, 2560, 1280), (6, 256, 5120, 3456, 'bfloat16', True): (256, 1024, 3456), (6, 256, 5120, 640, 'bfloat16', True): (256, 2560, 640), @@ -416,8 +440,12 @@ def quantized_matmul_int8( (6, 32, 13824, 5120, 'bfloat16', True): (128, 512, 5120), (6, 32, 1792, 5120, 'bfloat16', True): (128, 896, 2560), (6, 32, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 32, 3584, 18944, 'bfloat16', True): (128, 128, 18944), + (6, 32, 3584, 3584, 'bfloat16', True): (128, 896, 3584), + (6, 32, 37888, 3584, 'bfloat16', True): (128, 1024, 3584), (6, 32, 4096, 14336, 'bfloat16', True): (128, 4096, 896), (6, 32, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 32, 4608, 3584, 'bfloat16', True): (128, 768, 3584), (6, 32, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), (6, 32, 5120, 3456, 'bfloat16', True): (128, 640, 3456), (6, 32, 5120, 640, 'bfloat16', True): (128, 2560, 640), @@ -430,6 +458,10 @@ def quantized_matmul_int8( (6, 32, 896, 5120, 'bfloat16', True): (128, 896, 2560), (6, 4096, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120), (6, 4096, 1792, 5120, 'bfloat16', True): (512, 1792, 5120), + (6, 4096, 3584, 18944, 'bfloat16', True): (2048, 3584, 512), + (6, 4096, 3584, 3584, 'bfloat16', True): (4096, 256, 3584), + (6, 4096, 37888, 3584, 'bfloat16', True): (4096, 512, 3584), + (6, 4096, 4608, 3584, 'bfloat16', True): (4096, 512, 3584), (6, 4096, 5120, 1280, 'bfloat16', True): (256, 5120, 1280), (6, 4096, 5120, 3456, 'bfloat16', True): (4096, 512, 3456), (6, 4096, 5120, 640, 'bfloat16', True): (256, 5120, 640), @@ -440,8 +472,12 @@ def quantized_matmul_int8( (6, 512, 13824, 5120, 'bfloat16', True): (512, 13824, 512), (6, 512, 1792, 5120, 'bfloat16', True): (512, 1792, 1280), (6, 512, 28672, 4096, 'bfloat16', True): (512, 2048, 4096), + (6, 512, 3584, 18944, 'bfloat16', True): (512, 256, 18944), + (6, 512, 3584, 3584, 'bfloat16', True): (512, 1792, 3584), + (6, 512, 37888, 3584, 'bfloat16', True): (512, 18944, 512), (6, 512, 4096, 14336, 'bfloat16', True): (512, 256, 14336), (6, 512, 4096, 4096, 'bfloat16', True): (512, 1024, 4096), + (6, 512, 4608, 3584, 'bfloat16', True): (512, 768, 3584), (6, 512, 5120, 1280, 'bfloat16', True): (512, 2560, 1280), (6, 512, 5120, 3456, 'bfloat16', True): (512, 1280, 3456), (6, 512, 5120, 640, 'bfloat16', True): (512, 2560, 640), @@ -456,8 +492,12 @@ def quantized_matmul_int8( (6, 64, 13824, 5120, 'bfloat16', True): (128, 512, 5120), (6, 64, 1792, 5120, 'bfloat16', True): (128, 896, 2560), (6, 64, 28672, 4096, 'bfloat16', True): (128, 28672, 256), + (6, 64, 3584, 18944, 'bfloat16', True): (128, 256, 18944), + (6, 64, 3584, 3584, 'bfloat16', True): (128, 896, 3584), + (6, 64, 37888, 3584, 'bfloat16', True): (128, 1024, 3584), (6, 64, 4096, 14336, 'bfloat16', True): (128, 4096, 896), (6, 64, 4096, 4096, 'bfloat16', True): (128, 512, 4096), + (6, 64, 4608, 3584, 'bfloat16', True): (128, 768, 3584), (6, 64, 5120, 1280, 'bfloat16', True): (128, 1280, 1280), (6, 64, 5120, 3456, 'bfloat16', True): (128, 1024, 3456), (6, 64, 5120, 640, 'bfloat16', True): (128, 2560, 640), @@ -470,6 +510,10 @@ def quantized_matmul_int8( (6, 64, 896, 5120, 'bfloat16', True): (128, 896, 2560), (6, 8192, 13824, 5120, 'bfloat16', True): (2048, 1536, 5120), (6, 8192, 1792, 5120, 'bfloat16', True): (512, 1792, 5120), + (6, 8192, 3584, 18944, 'bfloat16', True): (2048, 3584, 512), + (6, 8192, 3584, 3584, 'bfloat16', True): (4096, 512, 3584), + (6, 8192, 37888, 3584, 'bfloat16', True): (4096, 1024, 3584), + (6, 8192, 4608, 3584, 'bfloat16', True): (4096, 512, 3584), (6, 8192, 5120, 1280, 'bfloat16', True): (256, 5120, 1280), (6, 8192, 5120, 3456, 'bfloat16', True): (512, 5120, 3456), (6, 8192, 5120, 640, 'bfloat16', True): (512, 5120, 640), From cb64f4c7d7979bdf1077f0da9b06ed64c8efa74e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 30 Jul 2025 11:00:07 -0300 Subject: [PATCH 021/133] Deduplicate `GetXlaTensors()` function. (#9518) --- torch_xla/csrc/init_python_bindings.cpp | 63 +++++++++++++------------ 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index bd4152aee811..da2701bb21db 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include #include #include @@ -268,20 +270,18 @@ std::vector GetXlaDevices( return xla_devices; } -std::vector GetXlaTensors(const std::vector& tensors, - bool want_all) { +// Collects all valid `XLATensorPtr` out of `tensors`. +// +// Iterates through `tensors`, collecting every `XLATensorPtr` value, +// ignoring those that return with a non-ok status. +static std::vector CollectXlaTensors( + const std::vector& tensors) { std::vector xtensors; - xtensors.reserve(tensors.size()); - if (want_all) { - for (auto& tensor : tensors) { - xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor))); - } - } else { - for (auto& tensor : tensors) { - auto xtensor_status = bridge::GetXlaTensor(tensor); - if (xtensor_status.ok()) { - xtensors.push_back(std::move(xtensor_status).value()); - } + for (auto& tensor : tensors) { + auto xla_tensor_status = bridge::GetXlaTensor(tensor); + if (xla_tensor_status.ok()) { + // Insert only those that can be successfully retrieved. + xtensors.push_back(std::move(xla_tensor_status).value()); } } return xtensors; @@ -396,11 +396,11 @@ void AllReduceInPlace(const std::string& reduce_type, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale, replica_groups, pin_layout); std::vector new_xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); MaybeThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors)); } @@ -506,7 +506,8 @@ ReduceScatterCoalesced(const std::string& reduce_type, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + std::vector xtensors = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); std::vector result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced( @@ -526,8 +527,9 @@ std::shared_ptr ReduceScatterCoalescedOut( int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors_out = - GetXlaTensors(outputs, /*want_all=*/true); - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(outputs)); + std::vector xtensors = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); torch::lazy::Value new_token; new_token = tensor_methods::reduce_scatter_coalesced_out( xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale, @@ -568,7 +570,7 @@ AllGatherCoalesced(const std::vector& tensors, const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); std::vector result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::all_gather_coalesced( @@ -586,8 +588,9 @@ std::shared_ptr AllGatherCoalescedOut( int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors_out = - GetXlaTensors(outputs, /*want_all=*/true); - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(outputs)); + std::vector xtensors = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); torch::lazy::Value new_token; new_token = tensor_methods::all_gather_coalesced_out( xtensors_out, xtensors, *token, dim, shard_count, replica_groups, @@ -624,8 +627,7 @@ std::pair> CollectivePermute( } void OptimizationBarrier_(std::vector& tensors) { - std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/false); + auto xtensors = CollectXlaTensors(tensors); tensor_methods::optimization_barrier_(xtensors); } @@ -654,8 +656,7 @@ std::pair> Recv( void SyncTensors(const std::vector& tensors, const std::vector& devices, bool wait, bool sync_xla_data, bool warm_up_cache_only = false) { - std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/false); + std::vector xtensors = CollectXlaTensors(tensors); XLAGraphExecutor::Get()->SyncTensorsGraph(&xtensors, devices, wait, sync_xla_data, warm_up_cache_only); } @@ -704,8 +705,7 @@ uint64_t GetRngSeed(const std::string& device_str) { std::string GetTensorsHloGraph(const std::vector& tensors, EmitMode mode) { - std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/false); + std::vector xtensors = CollectXlaTensors(tensors); return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode); } @@ -884,7 +884,8 @@ py::object GetRevisions() { std::vector XlaUserComputation( const std::string& opname, const std::vector& inputs, runtime::ComputationClient::ComputationPtr computation) { - std::vector xinputs = GetXlaTensors(inputs, /*want_all=*/true); + std::vector xinputs = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); std::vector xresults = tensor_methods::user_computation(opname, xinputs, std::move(computation)); std::vector results; @@ -1141,7 +1142,7 @@ class PyLoweringContext { void Build(std::vector tensors) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); // Get the lazy IR value from the output XLA tensors std::vector ir_values; @@ -1168,7 +1169,7 @@ class PyLoweringContext { std::vector additional_inputs_list = {}) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); // Get the lazy IR value from the output XLA tensors std::vector ir_values; @@ -2285,7 +2286,7 @@ void InitXlaModuleBindings(py::module m) { xtensors = XLAGraphExecutor::Get()->GetLiveTensors(&backend_device); } else { - xtensors = GetXlaTensors(tensors, /*want_all=*/false); + xtensors = CollectXlaTensors(tensors); } return py::bytes( XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode)); From 95bee8f085a698b99b902a67d4b9e7bdc7d6013e Mon Sep 17 00:00:00 2001 From: Hoomaaan <33916130+Hoomaaan@users.noreply.github.com> Date: Wed, 30 Jul 2025 09:35:40 -0700 Subject: [PATCH 022/133] [XLA] Add placements property to XLAShardedTensor for DTensor compatibility (#9509) --- test/neuron/run_tests.sh | 1 + test/run_tests.sh | 1 + test/spmd/test_xla_dtensor_placements.py | 95 +++++++++++++++++++ test/tpu/run_tests.sh | 1 + .../distributed/spmd/xla_sharded_tensor.py | 21 +++- 5 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 test/spmd/test_xla_dtensor_placements.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index f7671cc3d827..ecc302aa30fc 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -270,6 +270,7 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/test_persistent_cache.py" run_test "$_TEST_DIR/test_devices.py" run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_placements.py" #python3 examples/data_parallel/train_resnet_xla_ddp.py # compiler error #python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py diff --git a/test/run_tests.sh b/test/run_tests.sh index ec92a0a2691c..66c8bbff0406 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -270,6 +270,7 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/test_persistent_cache.py" run_test "$_TEST_DIR/test_devices.py" run_test "$_TEST_DIR/test_manual_xla_registration.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_placements.py" # NOTE: this line below is testing export and don't care about GPU PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py" run_test "$_TEST_DIR/test_pallas.py" diff --git a/test/spmd/test_xla_dtensor_placements.py b/test/spmd/test_xla_dtensor_placements.py new file mode 100644 index 000000000000..e48cfade1e4a --- /dev/null +++ b/test/spmd/test_xla_dtensor_placements.py @@ -0,0 +1,95 @@ +import os +import sys + +import torch +from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate + +import torch_xla +import torch_xla.runtime as xr +from torch_xla.distributed.spmd import XLAShardedTensor +from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor + +import unittest +import test_xla_sharding_base + + +class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_placements_basic(self): + """Test that placements property works when XLAShardedTensor is properly initialized.""" + + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(world_size)) + big_tensor = torch.randn(100_000, 88) + + # Create a sharded tensor with placements + my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) + + # Test that placements property works on XLAShardedTensor + assert hasattr( + my_dtensor, + 'placements'), "XLAShardedTensor should have placements property" + assert my_dtensor.placements == ( + Shard(0),), f"Expected (Shard(0),), got {my_dtensor.placements}" + + def test_placements_failure(self): + """Test that placements property provides a helpful error message when sharding info is missing.""" + big_tensor = torch.randn(100_000, 88) + + # Create XLAShardedTensor without sharding information + xla_tensor = XLAShardedTensor(big_tensor) + + # Test that accessing placements raises the expected error + with self.assertRaises(ValueError) as context: + _ = xla_tensor.placements + + expected_message = ( + "Placements not available: XLAShardedTensor requires mesh_shape and " + "partition_spec to be set. Use mark_sharding() to properly initialize sharding information." + ) + self.assertEqual( + str(context.exception), expected_message, + "Error message should match exactly for user clarity") + + def test_placements_caching_behavior(self): + """Test that placements property uses caching correctly.""" + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(world_size)) + big_tensor = torch.randn(100_000, 88) + + # Create properly sharded tensor + my_dtensor = distribute_tensor(big_tensor, mesh, [Replicate()]) + + # First access should create the cache + placements1 = my_dtensor.placements + self.assertIsNotNone(my_dtensor._cached_spec, + "Cache should be created after first access") + + # Second access should use cache + placements2 = my_dtensor.placements + self.assertEqual(placements1, placements2, + "Cached placements should be identical") + self.assertEqual(placements1, (Replicate(),), + f"Expected (Replicate(),), got {placements1}") + + # Invalidate cache and verify third access + my_dtensor.invalidate_spec_cache() + self.assertIsNone(my_dtensor._cached_spec, + "Cache should be None after invalidation") + + # Third access + placements3 = my_dtensor.placements + self.assertIsNotNone(my_dtensor._cached_spec, + "Cache should be recreated after invalidation") + self.assertEqual(placements3, (Replicate(),), + "New cache should have correct placements") + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 24f18d3bdcda..017fed5294fa 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -97,3 +97,4 @@ run_test "$_TEST_DIR/quantized_ops/test_dot_general.py" run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py" run_test "$_TEST_DIR/test_data_type.py" run_test "$_TEST_DIR/test_compilation_cache_utils.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_placements.py" diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index a20d530f3faa..c6a9a5d4f58c 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -9,7 +9,7 @@ import torch_xla.runtime as xr from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.placement_types import Shard, Replicate +from torch.distributed.tensor.placement_types import Placement, Shard, Replicate from torch.utils._pytree import tree_map_only @@ -241,6 +241,25 @@ def _spec(self): mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) return self._cached_spec + @property + def placements(self) -> tuple[Placement, ...]: + """ + Get the placements of this XLAShardedTensor. + + Returns: + tuple[Placement, ...]: The placements that describe how this tensor is distributed. + + Raises: + ValueError: If _spec cannot be created due to missing mesh_shape or partition_spec. + """ + try: + return self._spec.placements + except: + raise ValueError( + "Placements not available: XLAShardedTensor requires mesh_shape and " + "partition_spec to be set. Use mark_sharding() to properly initialize sharding information." + ) + def invalidate_spec_cache(self): """Invalidate the cached DTensorSpec.""" self._cached_spec = None From 241cd47f1499df53a82d6a94c2dcf39dd261eaa4 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Wed, 30 Jul 2025 17:20:53 -0700 Subject: [PATCH 023/133] Update artifacts_builds.tf for 2.8.0-rc2 (#9522) --- infra/tpu-pytorch-releases/artifacts_builds.tf | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 3993b9ed42eb..8722a205475b 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -2,9 +2,9 @@ # Define common configuration parameters for 2.8 release and nightly locals { tpu_python_versions = ["3.9", "3.10", "3.11", "3.12", "3.13"] - release_git_tag = "v2.8.0-rc1" - release_package_version = "2.8.0-rc1" - release_pytorch_git_rev = "v2.8.0-rc1" + release_git_tag = "v2.8.0-rc2" + release_package_version = "2.8.0-rc2" + release_pytorch_git_rev = "v2.8.0-rc8" nightly_package_version = "2.9.0" cuda_versions = { "nightly": [], From c807ebcbce97b9369323062975727027c5fec51e Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 31 Jul 2025 15:26:08 -0700 Subject: [PATCH 024/133] Update artifacts_builds.tf for 2.8.0-rc3 wheel (#9527) --- infra/tpu-pytorch-releases/artifacts_builds.tf | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 8722a205475b..f17ddf5ec2f1 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -2,8 +2,8 @@ # Define common configuration parameters for 2.8 release and nightly locals { tpu_python_versions = ["3.9", "3.10", "3.11", "3.12", "3.13"] - release_git_tag = "v2.8.0-rc2" - release_package_version = "2.8.0-rc2" + release_git_tag = "v2.8.0-rc3" + release_package_version = "2.8.0-rc3" release_pytorch_git_rev = "v2.8.0-rc8" nightly_package_version = "2.9.0" cuda_versions = { From 83d4253ef21aeabf8c67c3640cb664633f219ca5 Mon Sep 17 00:00:00 2001 From: qihqi Date: Fri, 1 Aug 2025 01:13:22 -0700 Subject: [PATCH 025/133] make jax as an optional dependency (#9521) --- .circleci/common.sh | 1 + .github/workflows/_test.yml | 1 + .github/workflows/_tpu_ci.yml | 3 +- CONTRIBUTING.md | 3 ++ README.md | 3 ++ setup.py | 8 +--- test/tpu/xla_test_job.yaml | 1 + torch_xla/_dynamo/dynamo_backend2.py | 2 +- torch_xla/_internal/jax_workarounds.py | 15 +++++- torch_xla/core/xla_builder.py | 55 ++++++++++++++++------ torch_xla/debug/profiler.py | 12 +++-- torch_xla/distributed/spmd/xla_sharding.py | 8 ++-- torch_xla/experimental/assume_pure.py | 6 +++ 13 files changed, 88 insertions(+), 30 deletions(-) diff --git a/.circleci/common.sh b/.circleci/common.sh index 1dd73ff836e2..df145603db42 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -112,6 +112,7 @@ function build_torch_xla() { # Need to uncomment the line below. # Currently it fails upstream XLA CI. # pip install plugins/cuda -v + pip install 'torch_xla[pallas]' popd } diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 4f1cb7899cc1..5dca5764dd52 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -140,6 +140,7 @@ jobs: set -x pip install expecttest unittest-xml-reporting + pip install torch_xla[pallas] if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then pip install -r pytorch/xla/benchmarks/requirements.txt diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index cd7be780a273..662c5a24ee25 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -51,7 +51,8 @@ jobs: pip install --upgrade pip pip install fsspec pip install rich - # libtpu is needed for pallas tests. + # jax and libtpu is needed for pallas tests. + pip install torch_xla[pallas] pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html pip install --upgrade protobuf - name: Run Tests (${{ matrix.test_script }}) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 25357dcfc2d3..d01f657c7609 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -160,6 +160,9 @@ commands on your Linux machine directly, outside of the container. pip install torch_xla[tpu] \ -f https://storage.googleapis.com/libtpu-wheels/index.html \ -f https://storage.googleapis.com/libtpu-releases/index.html + + # Optional: if you're using custom kernels, install pallas dependencies + pip install torch_xla[pallas] ``` 1. If you are running on a TPU VM, ensure `torch` and `torch_xla` were built and diff --git a/README.md b/README.md index 446453705c3d..8e16c0d99ae4 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,9 @@ Note: Builds are available for Python 3.8 to 3.11; please use one of the support # conda create -n py311 python=3.11 pip install torch==2.7.0 'torch_xla[tpu]==2.7.0' + +# Optional: if you're using custom kernels, install pallas dependencies +pip install torch_xla[pallas] ``` **As of 07/16/2025 and starting from Pytorch/XLA 2.8 release, PyTorch/XLA will provide nightly and release wheels for Python 3.11 to 3.13** diff --git a/setup.py b/setup.py index 867afc1b9052..11824cd08a47 100644 --- a/setup.py +++ b/setup.py @@ -480,8 +480,6 @@ def _get_jax_install_requirements(): # importlib.metadata backport required for PJRT plugin discovery prior # to Python 3.10 'importlib_metadata>=4.6;python_version<"3.10"', - # Some torch operations are lowered to HLO via JAX. - *_get_jax_install_requirements(), ], package_data={ 'torch_xla': ['lib/*.so*',], @@ -503,10 +501,8 @@ def _get_jax_install_requirements(): f'libtpu=={_libtpu_version}', 'tpu-info', ], - # As of https://github.com/pytorch/xla/pull/8895, jax is always a dependency of torch_xla. - # However, this no-op extras_require entrypoint is left here for backwards compatibility. - # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - 'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'], + # pip install torch_xla[pallas] + 'pallas': [*_get_jax_install_requirements(),] }, cmdclass={ 'build_ext': BuildBazelExtension, diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index e85b05dfacfc..e7f5258b9dde 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -43,6 +43,7 @@ spec: - | pip install expecttest==0.1.6 pip install rich + pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html cd /src/pytorch/xla volumeMounts: diff --git a/torch_xla/_dynamo/dynamo_backend2.py b/torch_xla/_dynamo/dynamo_backend2.py index 7f59a1726d4e..eb9bc1e40656 100644 --- a/torch_xla/_dynamo/dynamo_backend2.py +++ b/torch_xla/_dynamo/dynamo_backend2.py @@ -28,7 +28,7 @@ def _dynamo_backend(model: torch.fx.GraphModule, sample_args: Any): import torchax.interop from torchax.export import JaxInterpreter import jax - except ImportError: + except (ImportError, ModuleNotFoundError): print('To use this dynamo backend, please install torchax') raise diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index 715ac68f1839..04f37f8c0a00 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from typing import Callable, Any import functools +import logging # TODO(https://github.com/pytorch/xla/issues/8793): Get rid of this hack. @@ -53,5 +54,17 @@ def maybe_get_torchax(): import torchax.interop import torchax.ops.mappings return torchax - except ImportError: + except (ModuleNotFoundError, ImportError): return None + + +def maybe_get_jax(): + try: + jax_import_guard() + with jax_env_context(): + import jax + return jax + except (ModuleNotFoundError, ImportError): + logging.warn('You are trying to use a feature that requires jax/pallas.' + 'You can install Jax/Pallas via pip install torch_xla[pallas]') + return None \ No newline at end of file diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py index 2246377c91dc..928200d15aee 100644 --- a/torch_xla/core/xla_builder.py +++ b/torch_xla/core/xla_builder.py @@ -3,8 +3,11 @@ from weakref import WeakKeyDictionary import torch import torch_xla -from torch.utils._pytree import tree_flatten -from torch_xla._internal.jax_workarounds import jax_env_context, jax_import_guard, requires_jax, maybe_get_torchax +from torch_xla._internal.jax_workarounds import (jax_env_context, + jax_import_guard, requires_jax, + maybe_get_torchax, + maybe_get_jax) +from torch.utils import _pytree as pytree import torch_xla.debug.profiler as xp import abc @@ -883,9 +886,8 @@ def __init__(self, orig_func): def preprocess(self, args, kwargs=None): with jax_env_context(): - import jax kwargs = kwargs or {} - flattened_inputs, spec = jax.tree.flatten((args, kwargs)) + flattened_inputs, spec = self.flatten((args, kwargs)) tensors = tuple( a for a in flattened_inputs if isinstance(a, torch.Tensor)) self.non_tensors = tuple( @@ -899,7 +901,6 @@ def preprocess(self, args, kwargs=None): def flat_call(self, flat_input): with jax_env_context(): - import jax assert self.in_spec is not None, 'flat call only makes sense after preprocess is called' # Put the tensor input and the non tensor input together @@ -909,19 +910,25 @@ def flat_call(self, flat_input): if new_flattened[i] is self._sentinel: new_flattened[i] = next(tensor_args_iter) - args, kwargs = jax.tree.unflatten(self.in_spec, new_flattened) + args, kwargs = self.unflatten(new_flattened, self.in_spec) res = self.orig_func(*args, **kwargs) - flattened_out, spec = jax.tree.flatten(res) + flattened_out, spec = self.flatten(res) self.out_spec = spec return flattened_out def postprocess(self, res_flattened): with jax_env_context(): - import jax assert self.out_spec is not None, 'post process only makes sense after flat_call is called' - res = jax.tree.unflatten(self.out_spec, res_flattened) + res = self.unflatten(res_flattened, self.out_spec) return res + # Methods to allow subclass to customize how to flatten/unflatten + def flatten(self, inputs): + return pytree.tree_flatten(inputs) + + def unflatten(self, flattened, spec): + return pytree.tree_unflatten(flattened, spec) + class CompiledCallableWithCache(abc.ABC): """This class is meant to be subclassed. @@ -974,6 +981,22 @@ def preprocess(self, args, kwargs=None): for a in self.non_tensors) return res + def flatten(self, inputs): + # use jax pytree because it can also handle vjp stuff that + # pytorch pytree cannot + jax = maybe_get_jax() + assert jax is not None, 'Jax dependency is required for calling Jax function' + res, spec = jax.tree.flatten(inputs) + return res, spec + + def unflatten(self, flattened, spec): + # use jax pytree because it can also handle vjp stuff that + # pytorch pytree cannot + jax = maybe_get_jax() + assert jax is not None, 'Jax dependency is required for calling Jax function' + res = jax.tree.unflatten(spec, flattened) + return res + class JaxCallable(CompiledCallableWithCache): @@ -981,8 +1004,11 @@ def __init__(self, jax_func): super().__init__(JaxFlattenedInputFunc(jax_func)) def specialize(self, sample_flat_args): - import jax + jax = maybe_get_jax() tx = maybe_get_torchax() + if jax is None or tx is None: + raise AssertionError('Jax is required for this feature') + sample_flat_args = tuple( jax.ShapeDtypeStruct(a.shape, tx.ops.mappings.t2j_dtype(a.dtype) ) if a is not None else None @@ -1090,11 +1116,12 @@ def call_jax(jax_func, works. If you get tracing overhead, check if `jax_func` is being redefined all the time. A common mistake is defining `jax_func` as a local function, e.g. during a training step. """ - import jax - from jax._src import config - + jax = maybe_get_jax() tx = maybe_get_torchax() - flattened, _ = jax.tree.flatten((args, kwargs)) + if jax is None or tx is None: + raise AssertionError('Jax is required for this feature') + from jax._src import config + flattened, _ = pytree.tree_flatten((args, kwargs)) kwargs = kwargs or {} if tx is not None and any(isinstance(a, tx.tensor.Tensor) for a in flattened): return tx.interop.call_jax(jax_func, *args, **kwargs) diff --git a/torch_xla/debug/profiler.py b/torch_xla/debug/profiler.py index bdf3632a799b..4e046248b42f 100644 --- a/torch_xla/debug/profiler.py +++ b/torch_xla/debug/profiler.py @@ -5,6 +5,7 @@ import torch_xla import torch_xla.core.xla_model as xm +from torch_xla._internal.jax_workarounds import maybe_get_jax _TRACER_MARKED_STEP: bool = False @@ -128,13 +129,16 @@ def __enter__(self): self.scope = torch_xla._XLAC.profiler.scope_pusher(self.name) super().__enter__() + self._jax_scope = None # Also enter the JAX named scope, to support torchax lowering. - import jax - self._jax_scope = jax.named_scope(self.name) - self._jax_scope.__enter__() + if jax := maybe_get_jax(): + self._jax_scope = jax.named_scope(self.name) + self._jax_scope.__enter__() def __exit__(self, type, value, traceback): - self._jax_scope.__exit__(type, value, traceback) + if self._jax_scope is not None: + self._jax_scope.__exit__(type, value, traceback) + self._jax_scope = None if getattr(self, 'scope', None): del self.scope super().__exit__(type, value, traceback) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 5f4d4378e7d2..c010fd4c3523 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -13,7 +13,7 @@ from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard import torch_xla.runtime as xr import torch_xla.debug.profiler as xp -from torch_xla._internal.jax_workarounds import requires_jax, maybe_get_torchax +from torch_xla._internal.jax_workarounds import requires_jax, maybe_get_torchax, maybe_get_jax import numpy as np import functools @@ -185,7 +185,8 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]: def get_jax_mesh(self): # Construct a JAX mesh object with the same device ids shape and ordering # from torch_xla device mesh. - import jax + jax = maybe_get_jax() + assert jax is not None import numpy as np from jax._src import mesh as mesh_lib @@ -645,7 +646,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." tx = maybe_get_torchax() - if tx is not None and isinstance(t, tx.tensor.Tensor): + jax = maybe_get_jax() + if (jax is not None) and (tx is not None) and isinstance(t, tx.tensor.Tensor): from jax.sharding import PartitionSpec as P, NamedSharding jmesh = mesh.get_jax_mesh() t.shard_(NamedSharding(jmesh, P(*partition_spec))) diff --git a/torch_xla/experimental/assume_pure.py b/torch_xla/experimental/assume_pure.py index cef830ed3112..3eb145c868aa 100644 --- a/torch_xla/experimental/assume_pure.py +++ b/torch_xla/experimental/assume_pure.py @@ -7,6 +7,7 @@ import torch_xla from torch_xla._internal.jax_workarounds import requires_jax import torch_xla.core.xla_builder as xb +from torch_xla._internal.jax_workarounds import maybe_get_jax, maybe_get_torchax _XLA_COMPUTATION_CACHE = {} @@ -57,6 +58,11 @@ def add_randn(a): we can call add_randn_p(a, rng_seed=0) to get one result and add_randn_p(a, rng_seed=0) to get another result. """ + tx = maybe_get_torchax() + jax = maybe_get_jax() + if tx is None or jax is None: + raise AssertionError('Jax is required for this feature') + from torchax.interop import jax_view import torchax if add_rng_seed_argument: From d487007dad18524210a3da87c43b7535d8d996d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Melissa=20Weber=20Mendon=C3=A7a?= Date: Fri, 1 Aug 2025 13:46:08 -0300 Subject: [PATCH 026/133] Reorganize PyTorch/XLA Overview page (#9498) --- docs/source/index.rst | 4 + docs/source/learn/xla-advanced.md | 124 +++++++ docs/source/learn/xla-examples.md | 253 +++++++++++++ docs/source/learn/xla-overview.md | 566 +++-------------------------- docs/source/learn/xla-profiling.md | 113 ++++++ 5 files changed, 547 insertions(+), 513 deletions(-) create mode 100644 docs/source/learn/xla-advanced.md create mode 100644 docs/source/learn/xla-examples.md create mode 100644 docs/source/learn/xla-profiling.md diff --git a/docs/source/index.rst b/docs/source/index.rst index 75f52d4ed34e..f3271936ef5d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,6 +38,10 @@ Tutorials learn/pytorch-on-xla-devices learn/xla-overview + learn/xla-quickstart + learn/xla-examples + learn/xla-profiling + learn/xla-advanced .. toctree:: :glob: diff --git a/docs/source/learn/xla-advanced.md b/docs/source/learn/xla-advanced.md new file mode 100644 index 000000000000..a39a8128276e --- /dev/null +++ b/docs/source/learn/xla-advanced.md @@ -0,0 +1,124 @@ +# Advanced Topics in PyTorch XLA + +## Compilation, caching and execution + +HLO is fed to the XLA compiler +for compilation and optimization. Compilation is then cached by PyTorch +XLA to be reused later if/when needed. The compilation of the graph is +done on the host (CPU), which is the machine that runs the Python code. +If there are multiple XLA devices, the host compiles the code for each +of the devices separately except when using SPMD (single-program, +multiple-data) parallelization. For example, v4-8 has one host machine and [four +devices](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4). +In this case the host compiles the code for each of the four devices +separately. In case of pod slices, when there are multiple hosts, each +host does the compilation for XLA devices it is attached to. If SPMD is +used, then the code is compiled only once (for given shapes and +computations) on each host for all the devices. + +## Synchronous execution and blocking + +The *synchronous* operations in Pytorch XLA, like printing, logging, +checkpointing or callbacks block tracing and result in slower execution. +In the case when an operation requires a specific value of an XLA +tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value +of that tensor is available to the host. Note that only the part of the +graph responsible for computing that tensor value is executed. These +operations do not cut the IR graph, but they trigger host-device +communication through `TransferFromDevice`, which results in slower +performance. + +## Barriers + +A *barrier* is a special instruction that tells XLA to execute the IR +graph and materialize the tensors. This means that the PyTorch XLA +tensors will be evaluated, and the results will be available to the +host. The user-exposed barrier in Pytorch XLA is +[torch_xla.sync()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), +which breaks the IR graph and results in code execution on the XLA +devices. One of the key properties of `torch_xla.sync()` is that unlike +synchronous operations it does not block the further tracing while the +device is executing the graph. However, it does block access to the +values of the tensors that are being materialized. + +The example in the LazyTensor guide illustrates what happens in a simple +case of adding two tensors. Now, suppose we have a for loop that adds +XLA tensors and uses the value later: + +``` python +for x, y in tensors_on_device: + z += x + y +``` + +Without a barrier, the PyTorch tracing will result in a single graph that +wraps the addition of tensors `len(tensors_on_device)` times. This is +because the `for` loop is not captured by the tracing, so each iteration +of the loop will create a new subgraph corresponding to the computation +of `z += x+y` and add it to the graph. Here is an example when +`len(tensors_on_device)=3`. + +![img](../_static/img/IRgraph_no_markstep.png) + +However, introducing a barrier at the end of the loop will result in a +smaller graph that will be compiled once during the first pass inside +the `for` loop and will be reused for the next +`len(tensors_on_device)-1` iterations. The barrier will signal to the +tracing that the graph traced so far can be submitted for execution, and +if that graph has been seen before, a cached compiled program will be +reused. + +``` python +for x, y in tensors_on_device: + z += x + y + torch_xla.sync() +``` + +In this case there will be a small graph that is used +`len(tensors_on_device)=3` times. + +![img](../_static/img/IRgraph_markstep.png) + +It is important to highlight that in PyTorch XLA Python code inside for +loops is traced and a new graph is constructed for each iteration if +there is a barrier at the end. This can be a significant performance +bottleneck. + +## Graphs + +The XLA graphs can be reused when the same computation happens on the +same shapes of tensors. If the shapes of the inputs or intermediate +tensors change, the XLA compiler will recompile a new graph with +the new tensor shapes. If you have dynamic shapes or if +your code does not reuse tensor graphs, the XLA compiler will spend a +significant amount of time optimizing and fusing operations which will not be +used again. You can pad the inputs into a fixed shape to help avoid dynamic +shapes. + +The trade-off between graph size and compilation time is also important +to consider. If there is one large IR graph, the XLA compiler can spend +a lot of time on optimization and fusion of the ops. This can result in +a very long compilation time. However, the later execution may be much +faster, due to the optimizations that were performed during compilation. + +Sometimes it is worth breaking the IR graph with `torch_xla.sync()`. As +explained above, this will result in a smaller graph that can be reused +later. However making graphs smaller can reduce optimizations that +otherwise could be done by the XLA compiler. + +## Data Loading + +You can use +[MPDeviceLoader](https://github.com/pytorch/xla/blob/a1f822e2627a5639464273241821852677401026/torch_xla/distributed/parallel_loader.py#L186). +to preload data onto your XLA device to improve performance. `MPDeviceLoader` +uses `torch_xla.sync()` to automatically break the iterations over batches of +data and send them for execution. Note that if you are not using +`MPDeviceLoader`, you might need to set `barrier=True` in the `optimizer_step()` +to enable `torch_xla.sync()` if running a training job or explicitly adding +`torch_xla.sync()`. + +**Note:** + +0 and 1 are magic numbers in XLA and treated as constants in the +HLO. If your code uses a random number generator that can generate these values, +the XLA compiler will compile the code that uses each value separately. This can +be disabled with `XLA_NO_SPECIAL_SCALARS=1` environment variable. diff --git a/docs/source/learn/xla-examples.md b/docs/source/learn/xla-examples.md new file mode 100644 index 000000000000..e5375bb6ae41 --- /dev/null +++ b/docs/source/learn/xla-examples.md @@ -0,0 +1,253 @@ +# Converting code to PyTorch XLA + +General guidelines to modify your code: + +- Replace `cuda` with `torch_xla.device()` +- Remove code that would access the XLA tensor values +- Wrap data loader with MPDeviceLoader +- Profile to further optimize the code + +Remember: each case is unique so you might need to do something +different for each case. + +## Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device + +To get a better understanding of the code changes needed to convert PyTorch code +that runs on GPUs to run on TPUs, let's look at the [inference +code](https://github.com/pytorch-tpu/stable-diffusion/blob/main/scripts/txt2img.py) +from a PyTorch implementation of the stable diffusion model. You can run the +script from the command line: + +``` bash + python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" +``` + +To see a diff of the modifications explained below, see +[ldm/models/diffusion/ddim.py](https://github.com/pytorch-tpu/stable-diffusion/commit/57f398eb784387e244dc5fb78421aa5261abd1ef). +Let's go over them step by step. As in the general guidelines above, +start with changes related to `cuda` devices. This inference code is +written to run on GPUs and `cuda` can be found in multiple places. Start +making changes by removing `model.cuda()` from [this +line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L64), +and `precision_scope` from +[here](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L290). +Additionally, replace the `cuda` device in [this +line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L248) +with the `xla` device similar to the code below: + +``` python + import torch_xla.core.xla_model as xm + self.device = torch_xla.device() +``` + +Next, this particular configuration of the model is using +`FrozenCLIPEmbedder`, therefore we will modify this +[line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/modules/encoders/modules.py#L143) +as well. For simplicity we will directly define the `device` in this +tutorial, but you can pass the `device` value to the function as well. + +Another place in the code that has cuda specific code is [DDIM scheduler](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/models/diffusion/ddim.py#L12). +Add `import torch_xla.core.xla_model as xm` on top of the file then +replace +[these](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/models/diffusion/ddim.py#L21-L22) +lines + +``` python +if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) +``` + +with + +``` python +device = torch_xla.device() +attr = attr.to(torch.device(device)) +``` + +Next, you can reduce device (TPU) and host (CPU) communication by +removing print statements, disabling progress bars, and reducing or +removing callbacks and logging. These operations require the device to +stop executing, falling back to the CPU, executing the +logging/callbacks, and then returning to the device. This can be a +significant performance bottleneck, especially on large models. + +After making these changes, the code will run on TPUs. However, the +performance will not be optimized. This is because the XLA compiler tries to +build a single (huge) graph that wraps the number of inference steps (in +this case, 50) as there is no barrier inside the for loop. It is +difficult for the compiler to optimize the graph, and this leads to +significant performance degradation. As discussed above, breaking the +for loop with a call to `torch_xla.sync()` will result in a smaller +graph that is easier for the compiler to optimize. This allows +the compiler to reuse the graph from the previous step, which can +improve performance. + +Now the +[code](https://github.com/pytorch-tpu/stable-diffusion/blob/ss-inference/scripts/txt2img.py) +is ready to run on TPUs in a reasonable time. More optimization and +analysis can be done by [capturing a +profile](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) +and investigating further. However, this is not covered here. + +Note: if you are running on v4-8 TPU, then you have 4 available XLA +(TPU) devices. Running the code as above will only use one XLA device. +In order to run on all 4 devices, use the `torch_xla.launch()` function. +We will discuss a `torch_xla.launch` in the next example. + +## Example 2. HF Stable Diffusion Inference + +Now, consider using [Stable Diffusion +Inference](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) +in the HuggingFace diffusers library for both the SD-XL and 2.1 versions +of the model. You can find the changes described below in the [diffusers repo](https://github.com/pytorch-tpu/diffusers). +Clone the repo and run the inference script using the following command on your +TPU VM: + +``` bash +(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git +(vm)$ cd diffusers/examples/text_to_image/ +(vm)$ python3 inference_tpu_single_device.py +``` + +## Running on a Single TPU device + +This section describes how to update the +[text_to_image inference example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) +to run on TPUs. + +The original code uses Lora for inference, but this tutorial will not +use it. Instead, we will set the `model_id` argument to +`stabilityai/stable-diffusion-xl-base-0.9` when initializing the +pipeline. We will also use the default scheduler +(DPMSolverMultistepScheduler). However, similar changes can be made to +the other schedulers as well. + +``` bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . # pip install -e . + +cd examples/text_to_image/ +pip install -r requirements.txt +pip install invisible_watermark transformers accelerate safetensors +``` + +(If `accelerate` is not found, log out, log back in.) + +Log in to HF and agree to the [sd-xl 0.9 +license](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) +on the model card. Next, go to +[account→settings→access](https://huggingface.co/settings/tokens) and generate a +new token. Copy the token and run the following command +with that specific token value on your vm + +``` bash +(VM)$ huggingface-cli login --token _your_copied_token__ +``` + +The [HuggingFace readme](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9#sd-xl-09-base-model-card) +provides PyTorch code that is written to run on +GPUs. To run it on TPUs, the first step is to change the CUDA device to +an XLA device. This can be done by replacing the line `pipe.to("cuda")` +with the following lines: + +``` python +import torch_xla.core.xla_model as xm +device = torch_xla.device() +pipe.to(device) +``` + +The first time you run an inference with XLA, the compiler builds a graph of the +computations, and optimizes this graph for the specific hardware the code is +running on. Once the graph has been compiled, is can be reused for subsequent +calls, which will be much faster. For example, compilation time for stable +diffusion XL model inference from HuggingFace can take about an hour to compile, +whereas the actual inference may take only 5 seconds, depending on the batch +size. Likewise, a GPT-2 model can take about 10-15 mins to compile, after +which the training epoch time becomes much faster. + +If you are running inference multiple times, you will start to see the +advantages of XLA after the graph is compiled. For example, if you +run inference on a list of 10 prompts, the first inference (maybe +two[^1]) may take a long time to compile, but the remaining inference +steps will be much faster. This is because XLA will reuse the graph that +it compiled for the first inference. + +If you try to run the code without making any additional changes, you +will notice that the compilation time is very long (\>6 hours). This is +because the XLA compiler tries to build a single graph for all of the +scheduler steps at once similar to what we have discussed in the +previous example. To make the code run faster, we need to break the +graph up into smaller pieces with `torch_xla.sync()` and reuse them in the +next steps. This happens inside the `pipe.__call__` +[function](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L559) +in [these +lines](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L805-L839). +Disabling the progress bar, removing callbacks and adding +`torch_xla.sync()` at the end of the for loop speeds up the code +significantly. Changes are provided in this +[commit](https://github.com/huggingface/diffusers/compare/main...pytorch-tpu:diffusers:main). + +Additionally, the `self.scheduler.step()` function, which by default +uses the `DPMSolverMultistepScheduler` scheduler, has a few issues that +are described in the [PyTorch XLA +caveats](https://pytorch.org/xla/release/2.0/index.html#known-performance-caveats). +The `.nonzero()` and `.item()` calls in this function send requests to +the CPU for tensor evaluation, which trigger device-host communication. +This is not desirable, as it can slow down the code. In this particular +case, we can avoid these calls by passing the index to the function +directly. This prevents unnecessary device-host communitation. Changes are available +in +[this](https://github.com/pytorch-tpu/diffusers/commit/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d) +commit. The code now is ready to be run on TPUs. + +## Running on Multiple TPU Devices + +To use multiple TPU devices, use the `torch_xla.launch` function +to run the function on multiple devices and sync when necessary. +The `torch_xla.launch` function will start processes on multiple TPU +devices and sync them when needed. This can be done by passing the +`index` argument to the function that runs on a single device. For +example, + +``` python +import torch_xla + +def my_function(index): + # function that runs on a single device + +torch_xla.launch(my_function, args=(0,)) +``` + +In this example, the `my_function` function will be run on 4 TPU +devices (for a v4-8 TPU slice). Each device is assigned an index from 0 to 3. +By default, `launch()` will run the function on all +TPU devices. If you want a single process, set `debug_single_process=True`: +`launch(..., debug_single_process=True)`. + +[This +file](https://github.com/ssusie/diffusers/blob/main/examples/text_to_image/inference_tpu_multidevice.py) +illustrates how xmp.spawn can be used to run stable diffusion 2.1 +version on multiple TPU devices. For this example, changes were made to the +[pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) +file. + +## Running on Pods + +Once you have the code for running on a single host device, there is no +further change needed. You can create the TPU pod, for example, by +following these +[instructions](https://cloud.google.com/tpu/docs/pytorch-pods#create-tpu-vm). +Then run your script with + +``` bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ + --zone=${ZONE} \ + --worker=all \ + --command="python3 your_script.py" +``` + +## Reference implementations + +The [AI-Hypercomputer/tpu-recipes](https://github.com/AI-Hypercomputer/tpu-recipes) +repo contains examples for training and serving many LLM and diffusion models. diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md index f6b0761fd69a..970eff5e3db3 100644 --- a/docs/source/learn/xla-overview.md +++ b/docs/source/learn/xla-overview.md @@ -1,142 +1,63 @@ # Pytorch/XLA Overview -This section provides a brief overview of the basic details of PyTorch -XLA, which should help readers better understand the required -modifications and optimizations of code. - -Unlike regular PyTorch, which executes code line by line and does not -block execution until the value of a PyTorch tensor is fetched, PyTorch -XLA works differently. It iterates through the python code and records -the operations on (PyTorch) XLA tensors in an intermediate -representation (IR) graph until it encounters a barrier (discussed -below). This process of generating the IR graph is referred to as -tracing (LazyTensor tracing or code tracing). PyTorch XLA then converts -the IR graph to a lower-level machine-readable format called HLO -(High-Level Opcodes). HLO is a representation of a computation that is -specific to the XLA compiler and allows it to generate efficient code -for the hardware that it is running on. HLO is fed to the XLA compiler -for compilation and optimization. Compilation is then cached by PyTorch -XLA to be reused later if/when needed. The compilation of the graph is -done on the host (CPU), which is the machine that runs the Python code. -If there are multiple XLA devices, the host compiles the code for each -of the devices separately except when using SPMD (single-program, -multiple-data). For example, v4-8 has one host machine and [four -devices](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4). -In this case the host compiles the code for each of the four devices -separately. In case of pod slices, when there are multiple hosts, each -host does the compilation for XLA devices it is attached to. If SPMD is -used, then the code is compiled only once (for given shapes and -computations) on each host for all the devices. +PyTorch/XLA is an open-source Python package that enables PyTorch to run on XLA +(Accelerated Linear Algebra) compatible devices, with a primary focus on +**Google Cloud TPUs** and also supporting **XLA-compatible GPUs**. It allows +developers and researchers to leverage the massive parallel processing +capabilities of these accelerators for training and inferencing large-scale AI +models with minimal code changes from their existing PyTorch workflows. + +At its core, PyTorch/XLA acts as a bridge between the familiar PyTorch Python +frontend and the XLA compiler. When you run PyTorch operations on XLA +devices using this library, you get the following key features: + +1. **Lazy Evaluation**: Operations are not executed immediately. Instead, + PyTorch/XLA records these operations in an intermediate representation (IR) + graph. The process of generating the IR graph is often referred to as + "tracing" (LazyTensor tracing or code tracing). Sometimes this is also called + lazy evaluation and it can lead to significant + [performance improvements](https://arxiv.org/pdf/2102.13267.pdf). +2. **Graph Compilation**: When results are actually needed (e.g., printing a + tensor, saving a checkpoint, or at an explicit synchronization point like + `torch_xla.sync()`), the accumulated IR graph is converted into a lower-level + machine-readable format called HLO (High-Level Opcodes). HLO is a + representation of a computation that is specific to the XLA compiler and + allows it to generate efficient code for the hardware that it is running on. +3. **XLA Optimization**: The XLA compiler takes this HLO, performs a series of + optimizations (like operator fusion, memory layout optimization, and + parallelization), and compiles it into highly efficient machine code tailored + for the specific XLA device (e.g., TPU). +4. **Execution**: The compiled code is then executed on the XLA device(s). + Compiled graphs are cached, so subsequent executions with the same + computation graph and input shapes can reuse the optimized binary, + significantly speeding up repeated operations typical in training loops. ![img](../_static/img/pytorchXLA_flow.svg) -For more details and examples, please refer to the [LazyTensor -guide](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/). - -The operations in the IR graph are executed only when values of tensors -are needed. This is referred to as evaluation or materialization of -tensors. Sometimes this is also called lazy evaluation and it can lead -to significant [performance -improvements](https://arxiv.org/pdf/2102.13267.pdf). - -The *synchronous* operations in Pytorch XLA, like printing, logging, -checkpointing or callbacks block tracing and result in slower execution. -In the case when an operation requires a specific value of an XLA -tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value -of that tensor is available to the host. Note that only the part of the -graph responsible for computing that tensor value is executed. These -operations do not cut the IR graph, but they trigger host-device -communication through `TransferFromDevice`, which results in slower -performance. - -A *barrier* is a special instruction that tells XLA to execute the IR -graph and materialize the tensors. This means that the PyTorch XLA -tensors will be evaluated, and the results will be available to the -host. The user-exposed barrier in Pytorch XLA is -[torch_xla.sync()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), -which breaks the IR graph and results in code execution on the XLA -devices. One of the key properties of `torch_xla.sync()` is that unlike -synchronous operations it does not block the further tracing while the -device is executing the graph. However, it does block access to the -values of the tensors that are being materialized. - -The example in the LazyTensor guide illustrates what happens in a simple -case of adding two tensors. Now, suppose we have a for loop that adds -XLA tensors and uses the value later: - -``` python -for x, y in tensors_on_device: - z += x + y -``` - -Without a barrier, the Python tracing will result in a single graph that -wraps the addition of tensors `len(tensors_on_device)` times. This is -because the `for` loop is not captured by the tracing, so each iteration -of the loop will create a new subgraph corresponding to the computation -of `z += x+y` and add it to the graph. Here is an example when -`len(tensors_on_device)=3`. - -![img](../_static/img/IRgraph_no_markstep.png) - -However, introducing a barrier at the end of the loop will result in a -smaller graph that will be compiled once during the first pass inside -the `for` loop and will be reused for the next -`len(tensors_on_device)-1` iterations. The barrier will signal to the -tracing that the graph traced so far can be submitted for execution, and -if that graph has been seen before, a cached compiled program will be -reused. - -``` python -for x, y in tensors_on_device: - z += x + y - torch_xla.sync() -``` - -In this case there will be a small graph that is used -`len(tensors_on_device)=3` times. - -![img](../_static/img/IRgraph_markstep.png) +This process allows PyTorch/XLA to provide significant performance benefits, +especially for large models and distributed training scenarios. For a deeper +dive into the lazy tensor system, see our +[LazyTensor guide](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/). -It is important to highlight that in PyTorch XLA Python code inside for -loops is traced and a new graph is constructed for each iteration if -there is a barrier at the end. This can be a significant performance -bottleneck. +## **Why Use PyTorch/XLA?** -The XLA graphs can be reused when the same computation happens on the -same shapes of tensors. If the shapes of the inputs or intermediate -tensors change, then the XLA compiler will recompile a new graph with -the new tensor shapes. This means that if you have dynamic shapes or if -your code does not reuse tensor graphs, running your model on XLA will -not be suitable for that use case. Padding the input into a fixed shape -can be an option to help avoid dynamic shapes. Otherwise, a significant -amount of time will be spent by the compiler on optimizing and fusing -operations which will not be used again. +* **High Performance on TPUs**: PyTorch/XLA is optimized to deliver exceptional performance for training and inference on Google Cloud TPUs, which are custom-designed AI accelerators. +* **Scalability**: Seamlessly scale your models from a single device to large TPU Pods with minimal code changes, enabling you to tackle more ambitious projects. +* **Familiar PyTorch Experience**: Continue using the PyTorch APIs and ecosystem you know and love. PyTorch/XLA aims to make the transition to XLA devices as smooth as possible, often requiring only minor modifications to existing PyTorch code. +* **Cost-Efficiency**: TPUs offer a compelling price/performance ratio for many AI workloads. PyTorch/XLA helps you take advantage of this efficiency. +* **Versatility**: Accelerate a wide range of AI workloads, including chatbots, code generation, media content generation, vision services, and recommendation engines. +* **Support for Leading Frameworks**: While focused on PyTorch, XLA itself is a compiler backend used by other major frameworks like JAX and TensorFlow. -The trade-off between graph size and compilation time is also important -to consider. If there is one large IR graph, the XLA compiler can spend -a lot of time on optimization and fusion of the ops. This can result in -a very long compilation time. However, the later execution may be much -faster, due to the optimizations that were performed during compilation. +## **Target Hardware** -Sometimes it is worth breaking the IR graph with `torch_xla.sync()`. As -explained above, this will result in a smaller graph that can be reused -later. However making graphs smaller can reduce optimizations that -otherwise could be done by the XLA compiler. +While PyTorch/XLA can theoretically run on any XLA-compatible backend, its primary development and optimization focus is on: -Another important point to consider is -[MPDeviceLoader](https://github.com/pytorch/xla/blob/a1f822e2627a5639464273241821852677401026/torch_xla/distributed/parallel_loader.py#L186). -Once your code is running on an XLA device, consider wrapping the torch -dataloader with XLA `MPDeviceLoader` which preloads data to the device -to improve performance and includes `torch_xla.sync()` in it. The latter -automatically breaks the iterations over batches of data and sends them -for execution. Note, if you are not using MPDeviceLoader, you might need -to set `barrier=True` in the `optimizer_step()` to enable -`torch_xla.sync()` if running a training job or explicitly adding -`torch_xla.sync()`. +* **Google Cloud TPUs**: Including various generations like TPU v5 and v6. [Learn more about TPUs](https://cloud.google.com/tpu/docs/intro-to-tpu). +* **GPUs via XLA**: PyTorch/XLA also supports running on NVIDIA GPUs through the OpenXLA PJRT plugin, providing an alternative execution path. [Learn more about GPUs on Google Cloud](https://cloud.google.com/compute/docs/gpus). ## TPU Setup -Create TPU with base image to use nightly wheels or from the stable +Create a TPU with the base image to use nightly wheels or from the stable release by specifying the `RUNTIME_VERSION`. ``` bash @@ -153,7 +74,7 @@ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --subnetwork=tpusubnet ``` -If you have a single host VM (e.g. v4-8), you can ssh to your vm and run +If you have a single host VM (e.g. v4-8), you can ssh to your vm and run the following commands from the vm directly. Otherwise, in case of TPU pods, you can use `--worker=all --command=""` similar to @@ -169,395 +90,14 @@ libraries ``` bash pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl -​​pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl + pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl sudo apt-get install libopenblas-dev -y sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific ``` -## Reference implementations - -The [AI-Hypercomputer/tpu-recipies](https://github.com/AI-Hypercomputer/tpu-recipes) -repo. contains examples for training and serving many LLM and diffusion models. - -## Converting code to PyTorch XLA - -General guidelines to modify your code: - -- Replace `cuda` with `torch_xla.device()` -- Remove progress bar, printing that would access the XLA tensor - values -- Reduce logging and callbacks that would access the XLA tensor values -- Wrap data loader with MPDeviceLoader -- Profile to further optimize the code - -Remember: each case is unique so you might need to do something -different for each case. - -### Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device - -As a first example consider the [inference -code](https://github.com/pytorch-tpu/stable-diffusion/blob/main/scripts/txt2img.py) -of the stable diffusion model in PyTorch Lightning which can be run from -command line as - -``` bash - python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" -``` - -For your reference, the diff of modifications described below can be -found -[here](https://github.com/pytorch-tpu/stable-diffusion/commit/57f398eb784387e244dc5fb78421aa5261abd1ef). -Let's go over them step by step. As in the general guideline above, -start with changes related to `cuda` device. This inference code is -written to run on GPUs and `cuda` can be found in multiple places. Start -making changes by removing `model.cuda()` from [this -line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L64), -and `precision_scope` from -[here](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L290). -Additionally, replace the `cuda` device in [this -line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L248) -with the `xla` device similar to the code below: - -Next, this particular configuration of the model is using -`FrozenCLIPEmbedder`, therefore we will modify this -[line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/modules/encoders/modules.py#L143) -as well. For simplicity we will directly define the `device` in this -tutorial, but you can pass the `device` value to the function as well. - -``` python - import torch_xla.core.xla_model as xm - self.device = torch_xla.device() -``` - -Another place in the code that has cuda specific code is DDIM scheduler. -Add `import torch_xla.core.xla_model as xm` on top of the file then -replace -[these](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/models/diffusion/ddim.py#L21-L22) -lines - -``` python -if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) -``` - -with - -``` python -device = torch_xla.device() -attr = attr.to(torch.device(device)) -``` - -Next, you can reduce device (TPU) and host (CPU) communication by -removing print statements, disabling progress bars, and reducing or -removing callbacks and logging. These operations require the device to -stop executing, falling back to the CPU, executing the -logging/callbacks, and then returning to the device. This can be a -significant performance bottleneck, especially on large models. - -After making these changes, the code will run on TPUs. However, the -performance will be very slow. This is because the XLA compiler tries to -build a single (huge) graph that wraps the number of inference steps (in -this case, 50) as there is no barrier inside the for loop. It is -difficult for the compiler to optimize the graph, and this leads to -significant performance degradation. As discussed above, breaking the -for loop with the barrier (torch_xla.sync()) will result in a smaller -graph that is easier for the compiler to optimize. This will also allow -the compiler to reuse the graph from the previous step, which can -improve performance. - -Now the -[code](https://github.com/pytorch-tpu/stable-diffusion/blob/ss-inference/scripts/txt2img.py) -is ready to run on TPUs in a reasonable time. More optimization and -analysis can be done by [capturing a -profile](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) -and investigating further. However, this is not covered here. - -Note: if you are running on v4-8 TPU, then you have 4 available XLA -(TPU) devices. Running the code as above will only use one XLA device. -In order to run on all 4 devices you need to use `torch_xla.launch()` -function to spawn the code on all the devices. We will discuss a -`torch_xla.launch` in the next example. - -### Example 2. HF Stable Diffusion Inference - -Now, consider using [Stable Diffusion -Inference](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) -in the HuggingFace diffusers library for both the SD-XL and 2.1 versions -of the model. For your reference, the changes described below can be -found in this [repo](https://github.com/pytorch-tpu/diffusers). You can -clone the repo and run the inference using the following command on your -TPU VM: - -``` bash -(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git -(vm)$ cd diffusers/examples/text_to_image/ -(vm)$ python3 inference_tpu_single_device.py -``` - -### Running on a Single TPU device - -This section describes the changes that need to be made to the -[text_to_image inference -example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) -code to run it on TPUs. - -The original code uses Lora for inference, but this tutorial will not -use it. Instead, we will set the `model_id` argument to -`stabilityai/stable-diffusion-xl-base-0.9` when initializing the -pipeline. We will also use the default scheduler -(DPMSolverMultistepScheduler). However, similar changes can be made to -the other schedulers as well. - -``` bash -git clone https://github.com/huggingface/diffusers -cd diffusers -pip install . # pip install -e . - -cd examples/text_to_image/ -pip install -r requirements.txt -pip install invisible_watermark transformers accelerate safetensors -``` - -(If `accelerate` is not found, log out, log back in.) - -Log in to HF and agree to the [sd-xl 0.9 -license](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) -on the model card. Next, go to -[account→settings→access](https://huggingface.co/settings/tokens) token -and generate a new token. Copy the token and run the following command -with that specific token value on your vm - -``` bash -(vm)$ huggingface-cli login --token _your_copied_token__ -``` - -The HuggingFace readme provides PyTorch code that is written to run on -GPUs. To run it on TPUs, the first step is to change the CUDA device to -an XLA device. This can be done by replacing the line `pipe.to("cuda")` -with the following lines: - -``` python -import torch_xla.core.xla_model as xm -device = torch_xla.device() -pipe.to(device) -``` - -Additionally, it is important to note that the first time you run -inference with XLA, it will take a long time to compile. For example, -compilation time for stable diffusion XL model inference from -HuggingFace can take about an hour to compile, whereas the actual -inference may take only 5 seconds, depending on the batch size. -Likewise, a GPT-2 model can take about 10-15 mins to compile, after -which the training epoch time becomes much faster. This is because XLA -builds a graph of the computation that will be performed, and then -optimizes this graph for the specific hardware that it is running on. -However, once the graph has been compiled, it can be reused for -subsequent inferences, which will be much faster. Therefore, if you are -only running inference once, you may not benefit from using XLA. -However, if you are running inference multiple times, or if you are -running inference on a list of prompts, you will start to see the -advantages of XLA after the first few inferences. For example, if you -run inference on a list of 10 prompts, the first inference (maybe -two[^1]) may take a long time to compile, but the remaining inference -steps will be much faster. This is because XLA will reuse the graph that -it compiled for the first inference. - -If you try to run the code without making any additional changes, you -will notice that the compilation time is very long (\>6 hours). This is -because the XLA compiler tries to build a single graph for all of the -scheduler steps at once similar to what we have discussed in the -previous example. To make the code run faster, we need to break the -graph up into smaller pieces with `torch_xla.sync()` and reuse them in the -next steps. This happens inside the `pipe.__call__` -[function](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L559) -in [these -lines](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L805-L839). -Disabling the progress bar, removing callbacks and adding -`torch_xla.sync()` at the end of the for loop speeds up the code -significantly. Changes are provided in this -[commit](https://github.com/huggingface/diffusers/compare/main...pytorch-tpu:diffusers:main). - -Additionally, the `self.scheduler.step()` function, which by default -uses the `DPMSolverMultistepScheduler` scheduler, has a few issues that -are described in the [PyTorch XLA -caveats](https://pytorch.org/xla/release/2.0/index.html#known-performance-caveats). -The `.nonzero()` and `.item()` calls in this function send requests to -the CPU for tensor evaluation, which trigger device-host communication. -This is not desirable, as it can slow down the code. In this particular -case, we can avoid these calls by passing the index to the function -directly. This will prevent the function from sending requests to the -CPU, and will improve the performance of the code. Changes are available -in -[this](https://github.com/pytorch-tpu/diffusers/commit/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d) -commit. The code now is ready to be run on TPUs. - -## Profiling and performance analysis - -To further investigate the performance of the model, we can profile it -using the profiling -[guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). -As a rule of thumb, the profiling script should be run with the maximum -batch size that fits into the memory for [optimal memory -usage](https://cloud.google.com/tpu/docs/performance-guide). It also -helps to overlap tracing of the code with device execution which leads -to more optimal device usage. The duration of profiling should be long -enough to capture at least one step. Good performance of the model on -TPUs means that device-host communication is minimized and the device is -constantly running processes with no idle time. - -Starting a server in the `inference_tpu_*.py` file and running -`capture_profile.py` script as described in the guide will give us -information on processes that run on the devices. Currently, only one -XLA device is profiled. To better understand the TPU idle time (gaps in -the profile), profiling traces (`xp.Trace()`) should be added to the -code. The `xp.Trace()` measures the time it takes to trace the python -code on the host machine wrapped with the trace. For this example, -`xp.Trace()` traces were added inside the -[pipeline](https://github.com/ssusie/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) -and the [U-net -model](https://github.com/ssusie/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) -to measure the time to run specific sections of the code on the host -(CPU). - -If the gaps in the profile are due to Python code tracing that happens -on the host, then this might be a bottleneck and there is no further -straightforward optimization that can be done. Otherwise, the code -should be analyzed further to understand the caveats and improve the -performance further. Note that you cannot `xp.Trace()` wrap portions of -the code where `torch_xla.sync()` is called. - -To illustrate this we can look at already captured profiles that were -uploaded to tensorboard following the profiling guide. - -Starting from Stable Diffusion model version 2.1 - -If we capture a profile without inserting any traces, we will see the -following: - -![Alt text](../_static/img/image.png) - -The single TPU device on v4-8, which has two cores, appears to be busy. -There are no significant gaps in their usage, except for a small one in -the middle. If we scroll up to try to find which process is occupying -the host machine, we will not find any information. Therefore, we will -add `xp.traces` to the pipeline -[file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) -as well as the U-net -[function](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py). -The latter may not be useful for this particular use case, but it does -demonstrate how traces can be added in different places and how their -information is displayed in TensorBoard. - -If we add traces and re-capture the profile with the largest batch size -that can fit on the device (32 in this case), we will see that the gap -in the device is caused by a Python process that is running on the host -machine. - -![Alt text](../_static/img/image-1.png) - -We can use the appropriate tool to zoom in on the timeline and see which -process is running during that period. This is when the Python code -tracing happens on the host, and we cannot improve the tracing further -at this point. - -Now, let's examine the XL version of the model and do the same thing. We -will add traces to the pipeline -[file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) -in the same way that we did for the 2.1 version and capture a profile. - -![Alt text](../_static/img/image-4.png) - -This time, in addition to the large gap in the middle, which is caused -by the `pipe_watermark` tracing, there are many small gaps between the -inference steps within [this -loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830). - -First look closer into the large gap that is caused by `pipe_watermark`. -The gap is preceded with `TransferFromDevice` which indicates that -something is happening on the host machine that is waiting for -computation to finish before proceeding. Looking into watermark -[code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), -we can see that tensors are transferred to cpu and converted to numpy -arrays in order to be processed with `cv2` and `pywt` libraries later. -Since this part is not straightforward to optimize, we will leave this -as is. - -Now if we zoom in on the loop, we can see that the graph within the loop -is broken into smaller parts because the `TransferFromDevice` operation -happens. - -![Alt text](../_static/img/image-2.png) - -If we investigate the U-Net function and the scheduler, we can see that -the U-Net code does not contain any optimization targets for -PyTorch/XLA. However, there are `.item()` and `.nonzero()` calls inside -the -[scheduler.step](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L371). -We can -[rewrite](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/schedulers/scheduling_euler_discrete.py#L310) -the function to avoid those calls. If we fix this issue and rerun a -profile, we will not see much difference. However, since we have reduced -the device-host communication that was introducing smaller graphs, we -allowed the compiler to optimize the code better. The function -[scale_model_input](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L205) -has similar issues, and we can fix these by making the changes we made -above to the `step` function. Overall, since many of the gaps are caused -from python level code tracing and graph building, these gaps are not -possible to optimize with the current version of PyTorch XLA, but we may -see improvements in the future when dynamo is enabled in PyTorch XLA. - -## Running on Multiple TPU Devices - -To use multiple TPU devices, you can use the `torch_xla.launch` function -to spawn the function you ran on a single device to multiple devices. -The `torch_xla.launch` function will start processes on multiple TPU -devices and sync them when needed. This can be done by passing the -`index` argument to the function that runs on a single device. For -example, - -``` python -import torch_xla - -def my_function(index): - # function that runs on a single device - -torch_xla.launch(my_function, args=(0,)) -``` - -In this example, the `my_function` function will be spawned on 4 TPU -devices on v4-8, with each device being assigned an index from 0 to 3. -Note that by default, the launch() function will spawn preocesses on all -TPU devices. If you only want to run single process, set the argument -`launch(..., debug_single_process=True)`. - -[This -file](https://github.com/ssusie/diffusers/blob/main/examples/text_to_image/inference_tpu_multidevice.py) -illustrates how xmp.spawn can be used to run stable diffusion 2.1 -version on multiple TPU devices. For this version similar to the above -changes were made to the -[pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) -file. - -## Running on Pods - -Once you have the code for running on a single host device, there is no -further change needed. You can create the TPU pod, for example, by -following these -[instructions](https://cloud.google.com/tpu/docs/pytorch-pods#create-tpu-vm). -Then run your script with - -``` bash -gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ - --zone=${ZONE} \ - --worker=all \ - --command="python3 your_script.py" -``` - -**Note:** +## Next Steps -0 and 1 are magic numbers in XLA and treated as constants in the -HLO. So if there is a random number generator in the code that can -generate these values, the code will compile for each value -separately. This can be disabled with `XLA_NO_SPECIAL_SCALARS=1` -environment variable. +- [Examples](./xla-examples.md): Explore example code for training and inference on TPUs. +- [Profiling and Performance](./xla-profiling.md): Learn how to profile and optimize your PyTorch/XLA applications. +- [Advanced Topics](./xla-advanced.md): Dive deeper into advanced concepts like graph optimization, data loading, and distributed training with PyTorch/XLA. diff --git a/docs/source/learn/xla-profiling.md b/docs/source/learn/xla-profiling.md new file mode 100644 index 000000000000..4ca1b3726b26 --- /dev/null +++ b/docs/source/learn/xla-profiling.md @@ -0,0 +1,113 @@ +# Profiling and performance analysis + +To investigate model performance, you can profile it using the profiling +[guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). +As a rule of thumb, the profiling script should be run with the maximum +batch size that fits into the memory for [optimal memory +usage](https://cloud.google.com/tpu/docs/performance-guide). It also +helps to overlap tracing of the code with device execution which leads +to optimal device usage. The profile duration should be long +enough to capture at least one step. Good performance of the model on +TPUs means that device-host communication is minimized and the device is +constantly running processes with minimal idle time. + +Starting a server in the `inference_tpu_*.py` file and running +`capture_profile.py` script as described in the guide will give us +information on processes that run on the devices. Currently, only one +XLA device is profiled. To better understand the TPU idle time (gaps in +the profile), add profiling traces (`xp.Trace()`) to the +code. The `xp.Trace()` measures the time it takes to trace the python +code on the host machine wrapped with the trace. For this example, +`xp.Trace()` traces were added inside the +[pipeline](https://github.com/ssusie/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) +and the [U-net +model](https://github.com/ssusie/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) +to measure the time to run specific sections of the code on the host +(CPU). + +If the gaps in the profile are due to Python code tracing that happens +on the host, then this might be a bottleneck and there is no further +straightforward optimization that can be done. Otherwise, the code +should be analyzed further to understand the caveats and improve the +performance further. Note that you cannot `xp.Trace()` wrap portions of +the code where `torch_xla.sync()` is called. + +To illustrate this we can look at already captured profiles that were +uploaded to tensorboard following the profiling guide. + +Starting from Stable Diffusion model version 2.1 + +If we capture a profile without inserting any traces, we will see the +following: + +![Alt text](../_static/img/image.png) + +The single TPU device on v4-8, which has two cores, appears to be busy. +There are no significant gaps in their usage, except for a small one in +the middle. If we scroll up to try to find which process is occupying +the host machine, we will not find any information. Therefore, we will +add `xp.traces` to the pipeline +[file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) +as well as the U-net +[function](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py). +The latter may not be useful for this particular use case, but it does +demonstrate how traces can be added in different places and how their +information is displayed in TensorBoard. + +If we add traces and re-capture the profile with the largest batch size +that can fit on the device (32 in this case), we will see that the gap +in the device is caused by a Python process that is running on the host +machine. + +![Alt text](../_static/img/image-1.png) + +We can use the appropriate tool to zoom in on the timeline and see which +process is running during that period. This is when the Python code +tracing happens on the host, and we cannot improve the tracing further +at this point. + +Now, let's examine the XL version of the model and do the same thing. We +will add traces to the pipeline +[file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) +in the same way that we did for the 2.1 version and capture a profile. + +![Alt text](../_static/img/image-4.png) + +This time, in addition to the large gap in the middle, which is caused +by the `pipe_watermark` tracing, there are many small gaps between the +inference steps within [this +loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830). + +First look closer into the large gap that is caused by `pipe_watermark`. +The gap is preceded with `TransferFromDevice` which indicates that +something is happening on the host machine that is waiting for +computation to finish before proceeding. Looking into watermark +[code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), +we can see that tensors are transferred to cpu and converted to numpy +arrays in order to be processed with `cv2` and `pywt` libraries later. +Since this part is not straightforward to optimize, we will leave this +as is. + +Now if we zoom in on the loop, we can see that the graph within the loop +is broken into smaller parts because the `TransferFromDevice` operation +happens. + +![Alt text](../_static/img/image-2.png) + +If we investigate the U-Net function and the scheduler, we can see that +the U-Net code does not contain any optimization targets for +PyTorch/XLA. However, there are `.item()` and `.nonzero()` calls inside +the +[scheduler.step](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L371). +We can +[rewrite](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/schedulers/scheduling_euler_discrete.py#L310) +the function to avoid those calls. If we fix this issue and rerun a +profile, we will not see much difference. However, since we have reduced +the device-host communication that was introducing smaller graphs, we +allowed the compiler to optimize the code better. The function +[scale_model_input](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L205) +has similar issues, and we can fix these by making the changes we made +above to the `step` function. Overall, since many of the gaps are caused +from python level code tracing and graph building, these gaps are not +possible to optimize with the current version of PyTorch XLA, but we may +see improvements in the future when dynamo is enabled in PyTorch XLA. From 7a48185df1b7f219096d4a5f983e054e1cf4c41d Mon Sep 17 00:00:00 2001 From: XiongfeiWei Date: Fri, 1 Aug 2025 10:13:57 -0700 Subject: [PATCH 027/133] Support torch.nn.functional.one_hot (#9523) --- torchax/torchax/ops/jtorch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index b53f27d462d2..ac2042a7511e 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -179,6 +179,13 @@ def wrap_flash_attention(query, key, value): return wrap_flash_attention(query, key, value) +@register_function(torch.nn.functional.one_hot) +def one_hot(tensor, num_classes=-1): + if num_classes == -1: + num_classes = jnp.max(tensor) + 1 + return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64) + + @register_function(torch.nn.functional.pad) def pad(tensor, pad, mode="constant", value=None): # For padding modes that have different names between Torch and NumPy, this From 0ad39c220596669c93113e59bb8c90a6bf853bd3 Mon Sep 17 00:00:00 2001 From: Rui <179625410+rpsilva-aws@users.noreply.github.com> Date: Fri, 1 Aug 2025 11:01:54 -0700 Subject: [PATCH 028/133] Introduce PlatformVersion bindings (#9513) --- torch_xla/csrc/init_python_bindings.cpp | 4 ++++ torch_xla/csrc/runtime/computation_client.h | 2 ++ torch_xla/csrc/runtime/ifrt_computation_client.cpp | 4 ++++ torch_xla/csrc/runtime/ifrt_computation_client.h | 2 ++ torch_xla/csrc/runtime/pjrt_computation_client.cpp | 4 ++++ torch_xla/csrc/runtime/pjrt_computation_client.h | 2 ++ 6 files changed, 18 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index da2701bb21db..8d23ae7f9862 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1714,6 +1714,10 @@ void InitXlaModuleBindings(py::module m) { return runtime::GetComputationClientOrDie()->GetLocalDevices(); } }) + .def("_xla_get_platform_version", + []() { + return runtime::GetComputationClientOrDie()->GetPlatformVersion(); + }) .def("_get_stream_for_cuda_device", [](const int device_id) { return runtime::GetComputationClientOrDie()->GetCudaStreamForDevice( diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index c5b550fb6846..05478dc6cb42 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -380,6 +380,8 @@ class ComputationClient { virtual size_t GetNumDevices() const = 0; + virtual std::string_view GetPlatformVersion() const = 0; + virtual std::vector GetLocalDevices() const = 0; virtual std::vector GetAllDevices() const = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index 5538cb4a5e22..d6337503508e 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -656,6 +656,10 @@ std::vector IfrtComputationClient::GetAllDevices() const { return IfrtDevicesToString(client_->devices()); } +std::string_view IfrtComputationClient::GetPlatformVersion() const { + return client_->platform_version(); +} + int IfrtComputationClient::GetNumProcesses() const { int max_process_index = client_->process_index(); for (auto* device : client_->devices()) { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index ab24d1ae357b..e1bcc751bbf3 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -118,6 +118,8 @@ class IfrtComputationClient : public ComputationClient { std::vector GetAllDevices() const override; + std::string_view GetPlatformVersion() const override; + int GetProcessIndex() const override { return client_->process_index(); }; int GetNumProcesses() const override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index d57dbf9be6ce..98ce8520da32 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -961,6 +961,10 @@ std::vector PjRtComputationClient::GetAllDevices() const { return PjRtDevicesToString(client_->devices()); } +std::string_view PjRtComputationClient::GetPlatformVersion() const { + return client_->platform_version(); +} + int PjRtComputationClient::GetNumProcesses() const { int max_process_index = client_->process_index(); for (auto* device : client_->devices()) { diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 9a93d2864f4e..3c13d3489cae 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -133,6 +133,8 @@ class PjRtComputationClient : public ComputationClient { std::vector GetAllDevices() const override; + std::string_view GetPlatformVersion() const override; + torch::lazy::hash_t HashCompilationEnv() override; int GetProcessIndex() const override { return client_->process_index(); }; From ebefc8f13367584c15367874ff0c5f4aa4dba1b9 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 1 Aug 2025 11:46:40 -0700 Subject: [PATCH 029/133] Update artifacts_builds.tf for 2.8.0-rc4 (#9532) --- infra/tpu-pytorch-releases/artifacts_builds.tf | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index f17ddf5ec2f1..99069a5a9ad1 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -2,8 +2,8 @@ # Define common configuration parameters for 2.8 release and nightly locals { tpu_python_versions = ["3.9", "3.10", "3.11", "3.12", "3.13"] - release_git_tag = "v2.8.0-rc3" - release_package_version = "2.8.0-rc3" + release_git_tag = "v2.8.0-rc4" + release_package_version = "2.8.0-rc4" release_pytorch_git_rev = "v2.8.0-rc8" nightly_package_version = "2.9.0" cuda_versions = { From adf305f321ab631148b6391b3191ced36fa50644 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 1 Aug 2025 13:08:45 -0700 Subject: [PATCH 030/133] Fix pip install torch_xla[pallas] (#9531) --- .github/workflows/_test.yml | 2 +- .github/workflows/_tpu_ci.yml | 4 ++-- README.md | 2 +- test/tpu/xla_test_job.yaml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 5dca5764dd52..413a5aef8322 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -140,7 +140,7 @@ jobs: set -x pip install expecttest unittest-xml-reporting - pip install torch_xla[pallas] + pip install 'torch_xla[pallas]' if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then pip install -r pytorch/xla/benchmarks/requirements.txt diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index 662c5a24ee25..82dd7c748c1c 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -52,8 +52,8 @@ jobs: pip install fsspec pip install rich # jax and libtpu is needed for pallas tests. - pip install torch_xla[pallas] - pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html + pip install 'torch_xla[pallas]' + pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html pip install --upgrade protobuf - name: Run Tests (${{ matrix.test_script }}) if: inputs.has_code_changes == 'true' diff --git a/README.md b/README.md index 8e16c0d99ae4..84bac8db9743 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Note: Builds are available for Python 3.8 to 3.11; please use one of the support pip install torch==2.7.0 'torch_xla[tpu]==2.7.0' # Optional: if you're using custom kernels, install pallas dependencies -pip install torch_xla[pallas] +pip install 'torch_xla[pallas]' ``` **As of 07/16/2025 and starting from Pytorch/XLA 2.8 release, PyTorch/XLA will provide nightly and release wheels for Python 3.11 to 3.13** diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index e7f5258b9dde..30c1945eb269 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -43,7 +43,7 @@ spec: - | pip install expecttest==0.1.6 pip install rich - pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + pip install 'torch_xla[pallas]' cd /src/pytorch/xla volumeMounts: From d3d91a84665e7e134aa53f010e60e76224867d47 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 1 Aug 2025 16:50:58 -0700 Subject: [PATCH 031/133] Remove cuda builds for release wheels (#9533) --- infra/tpu-pytorch-releases/artifacts_builds.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 99069a5a9ad1..9ede0110ff3b 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -8,7 +8,7 @@ locals { nightly_package_version = "2.9.0" cuda_versions = { "nightly": [], - "r2.8": ["12.1", "12.6"] # Note: PyTorch 2.8 release supports 11.8, 12.6, 12.8 + "r2.8": [] # Note: PyTorch 2.8 release doesn't have CUDA support } # Built once a day from master From 9995e971d70a1a0727d76dca270bf7ea4544f817 Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Fri, 1 Aug 2025 16:52:04 -0700 Subject: [PATCH 032/133] Optimize KV cache dequantization performance (#9528) --- .../ragged_paged_attention_v2.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 61a6411b31f0..8e124f2c41ca 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -1954,12 +1954,14 @@ def masked_store(ref, val, start, end, group=1): # kv lens will be contracting dim, we should mask out the NaNs. kv_mask = ( lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start) - k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype) - v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype) - - qk = ( - jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * - sm_scale) + k = jnp.where(kv_mask, k, 0) + v = jnp.where(kv_mask, v, 0) + + qk = jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) + if k_scale is not None: + qk *= sm_scale * k_scale + else: + qk *= sm_scale store_start = jnp.maximum(q_start - q_len_start, 0) store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) @@ -2007,6 +2009,8 @@ def init_scratch_ref(): m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32) + if v_scale is not None: + qkv *= v_scale lm_store_shape = head_m_ref.shape m_curr = jnp.broadcast_to(m_curr, lm_store_shape) l_curr = jnp.broadcast_to( @@ -2088,14 +2092,6 @@ def prefetch_next_kv_blk(): for step_idx in range(kv_load_step): k = k_list[step_idx] v = v_list[step_idx] - if k_scale is not None: - # NOTE: Conversion between arbitrary data types is not supported. - # That's why it is converted to float32 first. - k = k.astype(jnp.float32) * k_scale - k = k.astype(q_ref.dtype) - if v_scale is not None: - v = v.astype(jnp.float32) * v_scale - v = v.astype(q_ref.dtype) kv_head_idx = kv_head_chunk_idx + step_idx q_head_idx = kv_head_idx * num_q_heads_per_kv_head # TODO(jevinjiang): extra handlig for packed type that can start at From 2ccd5dc927534270f59b415abd32ea8e9cdf3de7 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Sat, 2 Aug 2025 00:19:39 +0000 Subject: [PATCH 033/133] Add gemini edited docstring --- torchax/torchax/__init__.py | 79 ++++++++++++++++++++++++++++- torchax/torchax/interop.py | 99 +++++++++++++++++++++++++++++++++++-- 2 files changed, 172 insertions(+), 6 deletions(-) diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index fe4c1c8ff046..2cbcc3ad36e9 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -40,7 +40,30 @@ def default_env(): def extract_jax(mod: torch.nn.Module, env=None): - """Returns a pytree of jax.ndarray and a jax callable.""" + """Extracts the state of a `torch.nn.Module` into a JAX-compatible format. + + **Arguments:** + + * `mod` (`torch.nn.Module`): The PyTorch model to extract the state from. + * `env` (optional): The `torchax` environment to use. If not provided, the default environment is used. + + **Returns:** + + A tuple containing: + + * A `pytree` of `jax.ndarray` representing the model's state (parameters and buffers). + * A JAX-callable function that executes the model's forward pass. + + **Usage:** + + ```python + import torch + import torchax + + model = torch.nn.Linear(10, 20) + states, jax_func = torchax.extract_jax(model) + ``` + """ if env is None: env = default_env() states = dict(mod.named_buffers()) @@ -60,11 +83,31 @@ def jax_func(states, args, kwargs=None): def enable_globally(): + """Enables `torchax` globally, which intercepts PyTorch operations and routes them to the JAX backend. This is the primary entry point for using `torchax`. + + **Usage:** + + ```python + import torchax + + torchax.enable_globally() + ``` + """ env = default_env().enable_torch_modes() return env def disable_globally(): + """Disables the `torchax` backend. After calling this, PyTorch operations will revert to their default behavior. + + **Usage:** + + ```python + import torchax + + torchax.disable_globally() + ``` + """ global env default_env().disable_torch_modes() @@ -110,6 +153,40 @@ class CompileOptions: def compile(fn, options: Optional[CompileOptions] = None): + """Compiles a function or `torch.nn.Module` for optimized execution with JAX. + + **Arguments:** + + * `fn`: The function or `torch.nn.Module` to compile. + * `options` (`CompileOptions`, optional): A `CompileOptions` object to configure the compilation process. + + **`CompileOptions`:** + + * `methods_to_compile` (`List[str]`, default=`['forward']`): A list of methods to compile when `fn` is a `torch.nn.Module`. + * `jax_jit_kwargs` (`Dict[str, Any]`, default=`{}`): A dictionary of keyword arguments to pass to `jax.jit`. + * `mode` (`str`, default=`'jax'`): The compilation mode. Currently, only `'jax'` is supported. + + **Returns:** + + A compiled version of the input function or module. + + **Usage:** + + ```python + import torch + import torchax + + model = torch.nn.Linear(10, 20) + compiled_model = torchax.compile(model) + + # With options + options = torchax.CompileOptions( + methods_to_compile=['forward', 'encode'], + jax_jit_kwargs={'static_argnums': (0,)} + ) + compiled_model = torchax.compile(model, options) + ``` + """ options = options or CompileOptions() if options.mode == 'jax': from torchax import interop diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index a87efe9dfe74..d746e9d03ba2 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -56,6 +56,29 @@ def set_one(module, prefix): class JittableModule(torch.nn.Module): + """A wrapper class that makes a `torch.nn.Module` compatible with `jax.jit`. It separates the model's parameters and buffers, allowing them to be passed as arguments to a functional version of the model. + + **Arguments:** + + * `m` (`torch.nn.Module`): The PyTorch model to wrap. + * `extra_jit_args` (`dict`, optional): A dictionary of extra arguments to pass to `jax.jit`. + * `dedup_parameters` (`bool`, optional): If `True`, deduplicates parameters that are shared within the model. + + **Usage:** + + ```python + import torch + import torchax + from torchax.interop import JittableModule + + model = torch.nn.Linear(10, 20) + jittable_model = JittableModule(model) + + # The first call will compile the model + inputs = torch.randn(5, 10, device='jax') + outputs = jittable_model(inputs) + ``` + """ def __init__(self, m: torch.nn.Module, @@ -230,12 +253,17 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue, def j2t_autograd(fn, call_jax=call_jax): - """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. + """Given a JAX function, returns a PyTorch `autograd` function that is implemented with `jax.vjp`. This allows you to define custom gradients for your PyTorch operations using JAX. + + **Arguments:** + + * `fn`: The JAX function for which to create a PyTorch `autograd` function. + * `call_jax` (optional): The function to use for calling JAX functions from PyTorch. - It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate - activations). The wrapped function is then run via `call_jax` and integrated into - the PyTorch autograd framework by saving the residuals into the context object. - """ + **Returns:** + + A PyTorch function with custom gradients defined by the JAX function. + """ @wraps(fn) def inner(*args, **kwargs): @@ -333,11 +361,50 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None): def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False): + """A decorator that applies `jax.jit` to a PyTorch function. + + **Arguments:** + + * `torch_function`: The PyTorch function to be JIT-compiled. + * `kwargs_for_jax_jit` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.jit`. + * `fix_for_buffer_donation` (`bool`, optional): A flag to enable a workaround for buffer donation issues. + + **Returns:** + + A JIT-compiled version of the PyTorch function. + + **Usage:** + + ```python + import torch + import torchax + from torchax.interop import jax_jit + + @jax_jit + def my_function(x, y): + return torch.sin(x) + torch.cos(y) + + x = torch.randn(5, 10, device='jax') + y = torch.randn(5, 10, device='jax') + result = my_function(x, y) + ``` + """ return wrap_jax_jit( torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit) def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): + """Applies `jax.experimental.shard_map` to a PyTorch function, allowing for data parallelism across multiple devices. + + **Arguments:** + + * `torch_function`: The PyTorch function to be sharded. + * `kwargs_for_jax_shard_map` (`dict`, optional): A dictionary of keyword arguments to pass to `shard_map`. + + **Returns:** + + A sharded version of the PyTorch function. + """ return wrap_jax_jit( torch_function, jax_jit_func=shard_map, @@ -345,6 +412,17 @@ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): + """Applies `jax.value_and_grad` to a PyTorch function. + + **Arguments:** + + * `torch_function`: The PyTorch function. + * `kwargs_for_value_and_grad` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.value_and_grad`. + + **Returns:** + + A function that computes both the value and the gradient of the input `torch_function`. + """ return wrap_jax_jit( torch_function, jax_jit_func=jax.value_and_grad, @@ -352,5 +430,16 @@ def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): def gradient_checkpoint(torch_function, kwargs=None): + """Applies `jax.checkpoint` to a PyTorch function. This is useful for reducing memory usage during training by recomputing intermediate activations during the backward pass instead of storing them. + + **Arguments:** + + * `torch_function`: The PyTorch function to checkpoint. + * `kwargs` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.checkpoint`. + + **Returns:** + + A checkpointed version of the PyTorch function. + """ return wrap_jax_jit( torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs) From b6a5b82b9948b610fa4c304d0d869c82b8f17db1 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Mon, 4 Aug 2025 03:50:03 +0000 Subject: [PATCH 034/133] add more files --- torchax/torchax/CONTRIBUTING.md | 38 ------ torchax/torchax/amp.py | 13 +++ torchax/torchax/config.py | 30 +++++ torchax/torchax/decompositions.py | 23 ++-- torchax/torchax/device_module.py | 8 ++ torchax/torchax/export.py | 62 ++++++++-- torchax/torchax/flax.py | 25 ++++ torchax/torchax/mesh_util.py | 173 ++++++++++++++-------------- torchax/torchax/ops/mappings.py | 32 +++++ torchax/torchax/ops/op_base.py | 46 ++++++-- torchax/torchax/ops/ops_registry.py | 60 ++++++++++ torchax/torchax/tensor.py | 71 ++++++++---- torchax/torchax/tf_integration.py | 116 ++++++++++++++----- torchax/torchax/train.py | 55 ++++++--- torchax/torchax/view.py | 153 ++++++++++++------------ 15 files changed, 598 insertions(+), 307 deletions(-) delete mode 100644 torchax/torchax/CONTRIBUTING.md diff --git a/torchax/torchax/CONTRIBUTING.md b/torchax/torchax/CONTRIBUTING.md deleted file mode 100644 index c61462850652..000000000000 --- a/torchax/torchax/CONTRIBUTING.md +++ /dev/null @@ -1,38 +0,0 @@ -# Contributing to TorchXLA2 - -We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels. - -If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. - - -# Developer setup - -## Mac setup: -@qihqi - -I am able to develop directly on mac (m1) laptop for most of parts. Using steps -in README.md works. The condensed version for easy copy & paste: - -```bash -conda create --name python=3.10 -conda activate -pip install --upgrade "jax[cpu]" torch -pip install -r test_requirements.txt -pip install -e . -pytest test -``` - -### VSCode - -I use vscode on my Mac. I loosely followed instruction in -https://code.visualstudio.com/docs/python/python-tutorial -to setup a proper python environment. - -The plugins I installed (a subset of the ones listed above) are: -* VSCode's official Python plugin -* Ruff formatter -* Python Debugger - -I also changed Python interpreter to point at the one in my conda env. -That is all the changes I have. - diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py index ccbc63bead63..0fc2bb1b23b4 100644 --- a/torchax/torchax/amp.py +++ b/torchax/torchax/amp.py @@ -57,6 +57,19 @@ def is_float(a): @contextlib.contextmanager def autocast(device, dtype=torch.bfloat16, env=None): + """A context manager for automatic mixed precision (AMP). + + This context manager enables automatic mixed precision, which can improve + performance by using lower-precision data types for certain operations. + + **Arguments:** + + * `device`: The device to use for autocasting (e.g., "cuda", "cpu"). + * `dtype` (`torch.dtype`, optional): The lower-precision data type to use. + Defaults to `torch.bfloat16`. + * `env` (optional): The `torchax` environment. If not provided, the default + environment is used. + """ del device if env is None: import torchax diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py index f439c656287b..62fd98ea8585 100644 --- a/torchax/torchax/config.py +++ b/torchax/torchax/config.py @@ -3,6 +3,36 @@ @dataclasses.dataclass class Configuration: + """A dataclass for configuring the behavior of `torchax`. + + **Attributes:** + + * `debug_print_each_op` (`bool`): If `True`, prints each operation as it is + dispatched. + * `debug_accuracy_for_each_op` (`bool`): If `True`, checks the accuracy of + each operation by comparing its output with the equivalent PyTorch + operation on the CPU. + * `debug_mixed_tensor` (`bool`): If `True`, enables debugging for mixed + tensor operations. + * `debug_print_each_op_operands` (`bool`): If `True`, prints the operands of + each operation. + * `use_int32_for_index` (`bool`): If `True`, uses `int32` for indexing + operations. + * `allow_mixed_math_with_scalar_tensor` (`bool`): If `True`, allows mixed + math operations between `torchax.Tensor` and scalar `torch.Tensor`s. + * `force_materialize_views` (`bool`): If `True`, eagerly materializes `View` + objects into `torchax.Tensor`s. + * `use_dlpack_for_data_conversion` (`bool`): If `True`, uses DLPack for + converting between `jax.Array` and `torch.Tensor`. + * `use_tpu_flash_attention` (`bool`): If `True`, uses TPU-optimized flash + attention. + * `shmap_flash_attention` (`bool`): If `True`, uses `shard_map` for flash + attention. + * `treat_cuda_as_jax_device` (`bool`): If `True`, treats CUDA devices as JAX + devices. + * `internal_respect_torch_return_dtypes` (`bool`): If `True`, respects the + return data types of PyTorch operations. + """ debug_print_each_op: bool = False debug_accuracy_for_each_op: bool = False debug_mixed_tensor: bool = False diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index d1c1f463d88a..cbae0714bd12 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -1,10 +1,10 @@ -"""This file contains some decompositons that are not available in torch stable. +"""This file contains PyTorch operator decompositions that are not available in +the stable version of PyTorch. -Most likely from Content of -https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py -at main branch HEAD that we find useful here. - -Can also contain decompositions of a torch op in terms of other torch ops. +The decompositions are primarily sourced from the `main` branch of the PyTorch +repository and are included here to provide support for newer operators. This +module can also contain decompositions of a PyTorch op in terms of other +PyTorch ops. """ import functools @@ -104,6 +104,7 @@ def _reflection_or_replication_pad( def bernoulli(self, *, generator=None): + """Decomposition for the `bernoulli` operator.""" return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) @@ -111,11 +112,13 @@ def bernoulli(self, *, generator=None): def rand_like(self, **kwargs): + """Decomposition for the `rand_like` operator.""" dtype = kwargs.get("dtype", self.dtype) return torch.rand(self.shape, dtype=dtype) def channel_shuffle(self, groups): + """Decomposition for the `channel_shuffle` operator.""" batchsize, channels, height, width = self.shape channels_per_group = channels // groups self = self.reshape(batchsize, groups, channels_per_group, height, width) @@ -131,6 +134,7 @@ def channel_shuffle(self, groups): def bernoulli_float(self, p=0.5): + """Decomposition for the `bernoulli_` operator with a float probability.""" return self.bernoulli_(p) @@ -150,9 +154,10 @@ def _grid_sampler_3d( padding_mode: int = 0, align_corners: bool = False, ) -> Tensor: - """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 + """Decomposition for the `grid_sampler_3d` operator. - The above implement the 2d case. + This implementation is based on the 2D version in the PyTorch repository: + https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 """ _expand_grid = False torch._check( @@ -773,4 +778,4 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, MUTABLE_DECOMPOSITION = [ torch.ops.aten.bernoulli_.Tensor, torch.ops.aten.bernoulli_.float, -] +] \ No newline at end of file diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py index be028cfcc21d..41c98b5cad42 100644 --- a/torchax/torchax/device_module.py +++ b/torchax/torchax/device_module.py @@ -2,32 +2,40 @@ def _is_in_bad_fork(): + """Returns `False` as forking is not applicable in the same way as CUDA.""" return False def manual_seed_all(seed): + """A placeholder for API compatibility; does not affect JAX's PRNG.""" pass def device_count(): + """Returns `1` as JAX manages devices as a single logical device.""" return 1 def get_rng_state(): + """Returns an empty list for API compatibility.""" return [] def set_rng_state(new_state, device): + """A placeholder for API compatibility; does not affect JAX's PRNG.""" pass def is_available(): + """Returns `True` if JAX is available.""" return True def current_device(): + """Returns `0` as JAX manages devices as a single logical device.""" return 0 def get_amp_supported_dtype(): + """Returns the data types supported by AMP (Automatic Mixed Precision).""" return [torch.float16, torch.bfloat16] diff --git a/torchax/torchax/export.py b/torchax/torchax/export.py index 987fb92ba6ee..be2da17e1cb8 100644 --- a/torchax/torchax/export.py +++ b/torchax/torchax/export.py @@ -16,7 +16,13 @@ class JaxInterpreter(torch.fx.Interpreter): - """Experimental.""" + """An `fx.Interpreter` that executes a PyTorch FX graph using JAX. + + This interpreter traverses an FX graph and replaces PyTorch operations with + their corresponding JAX implementations from the `torchax` operator registry. + It is a key component in the process of exporting PyTorch models to JAX and + StableHLO. + """ def __init__(self, graph_module): super().__init__(graph_module) @@ -74,11 +80,24 @@ def _extract_states_from_exported_program(exported_model): def exported_program_to_jax(exported_program, export_raw: bool = False): - """returns a pytree of jax arrays(state), and + """Converts a `torch.export.ExportedProgram` to a JAX-compatible function and state. + + This function takes a PyTorch `ExportedProgram`, runs the necessary + decompositions, and returns a JAX-compatible function and the model's state + (parameters and buffers) as JAX arrays. + + **Arguments:** - a callable(func) that is jax function. + * `exported_program` (`torch.export.ExportedProgram`): The PyTorch + `ExportedProgram` to convert. + * `export_raw` (`bool`, optional): If `True`, returns the raw states and + function without converting them to JAX arrays. Defaults to `False`. - func(state, input) would be how you call it. + **Returns:** + + A tuple containing: + * A pytree of JAX arrays representing the model's state. + * A JAX-callable function that takes the state and inputs as arguments. """ if torch.__version__ >= '2.2': # torch version 2.1 didn't expose this yet @@ -115,8 +134,19 @@ def func(states, inputs): def extract_avals(exported): - """Return JAX Abstract Value shapes for all input parameters of the exported - program. This supports dynamic batch dimensions, including with constraints. + """Returns JAX abstract values (`ShapeDtypeStruct`) for all input parameters of the exported program. + + This function supports dynamic batch dimensions, including those with + constraints. + + **Arguments:** + + * `exported` (`torch.export.ExportedProgram`): The exported PyTorch program. + + **Returns:** + + A list of `jax.ShapeDtypeStruct` objects representing the abstract values of + the input parameters. """ def _to_aval(arg_meta, symbolic_shapes): @@ -232,12 +262,24 @@ def _build_symbolic_shape(sym, constraint, free_symbols): def exported_program_to_stablehlo(exported_program): - """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo + """Converts a `torch.export.ExportedProgram` to StableHLO. + + This function serves as a replacement for + `torch_xla.stablehlo.exported_program_to_stablehlo`. It supports dynamic + dimension sizes and generates explicit checks for Dynamo guards in the IR + using `shape_assertion` custom calls. + + **Arguments:** + + * `exported_program` (`torch.export.ExportedProgram`): The exported PyTorch + program. - Convert a program exported via torch.export to StableHLO. + **Returns:** - This supports dynamic dimension sizes and generates explicit checks for - dynamo guards in the IR using shape_assertion custom_call ops. + A tuple containing: + * The model's state (weights) as a pytree of JAX arrays. + * A `jax.export.Exported` object containing the StableHLO representation of + the model. """ weights, func = exported_program_to_jax(exported_program) jax_avals = extract_avals(exported_program) diff --git a/torchax/torchax/flax.py b/torchax/torchax/flax.py index 28542d79c90e..4f3dce83e587 100644 --- a/torchax/torchax/flax.py +++ b/torchax/torchax/flax.py @@ -6,8 +6,32 @@ class FlaxNNModule(torch.nn.Module): + """A `torch.nn.Module` that wraps a Flax module for interoperability. + + This class allows you to use a Flax module within a PyTorch model. It + initializes the Flax module, extracts its parameters, and wraps them in a + `torch.nn.ParameterDict` so they can be managed by PyTorch. The `forward` + pass then calls the Flax module's `apply` method with the appropriate + parameters. + + **Attributes:** + + * `_params` (`torch.nn.Module`): A nested `torch.nn.Module` that holds the + parameters of the Flax module. + * `_flax_module`: The original Flax module. + """ def __init__(self, env, flax_module, sample_args, sample_kwargs=None): + """Initializes the `FlaxNNModule`. + + **Args:** + + * `env`: The `torchax` environment. + * `flax_module`: The Flax module to wrap. + * `sample_args`: A tuple of sample arguments to initialize the Flax module. + * `sample_kwargs` (optional): A dictionary of sample keyword arguments to + initialize the Flax module. + """ super().__init__() prng = env.prng_key sample_kwargs = sample_kwargs or {} @@ -34,6 +58,7 @@ def _decode_nested_dict(self, child_module): return result def forward(self, *args, **kwargs): + """Performs the forward pass by calling the wrapped Flax module.""" nested_dict_params = self._decode_nested_dict(self._params) return tx.interop.call_jax(self._flax_module.apply, nested_dict_params, *args, **kwargs) diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py index 208d86a1bac6..e147546dbbe9 100644 --- a/torchax/torchax/mesh_util.py +++ b/torchax/torchax/mesh_util.py @@ -38,46 +38,52 @@ def _shard_first_multiple_of(axis_name, shape, multiple_of): class SingleAxisSharder: - """A callable object that generates PartitionSpecs for single-axis sharding. + """A callable object that generates `PartitionSpec`s for single-axis sharding. - This sharder strategy attempts to shard the *first* dimension of a tensor + This sharding strategy attempts to shard the *first* dimension of a tensor that is divisible by the specified `axis_size` along the given `axis_name`. - It's useful for simple 1D mesh sharding scenarios like FSDP where parameters + It's useful for simple 1D mesh sharding scenarios like FSDP, where parameters are typically sharded along one dimension. - Attributes: - axis_name: The name of the mesh axis to shard along. - axis_size: The size of the mesh axis (number of devices along that axis). + **Attributes:** + + * `axis_name` (`str`): The name of the mesh axis to shard along. + * `axis_size` (`int`): The size of the mesh axis (number of devices along + that axis). + * `replicate_unshardable` (`bool`): If `True`, tensors that cannot be sharded + will be replicated. """ def __init__(self, axis_name, axis_size, replicate_unshardable=False): - """Initializes the SingleAxisSharder. + """Initializes the `SingleAxisSharder`. + + **Args:** - Args: - axis_name: The name of the mesh axis (e.g., "fsdp", "data"). - axis_size: The number of devices along the specified mesh axis. - replicate_unshardable: indicate whether it should return replicated sharding - (P()) when none of the axis is divisible by the axis size. + * `axis_name` (`str`): The name of the mesh axis (e.g., "fsdp", "data"). + * `axis_size` (`int`): The number of devices along the specified mesh axis. + * `replicate_unshardable` (`bool`): If `True`, returns a replicated sharding + (`P()`) when no dimension is divisible by the axis size. """ self.axis_name = axis_name self.axis_size = axis_size self.replicate_unshardable = replicate_unshardable def __call__(self, name, shapedtype): - """Generates a PartitionSpec for a given tensor name and shaped type. - - Args: - name: The name of the tensor (e.g., parameter name). This argument is - provided for compatibility with more complex sharders but is not used - by this simple sharder. - shapedtype: An object with a `.shape` attribute describing the tensor's shape, - and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct - or a torch.Tensor) - - Returns: - A jax.sharding.PartitionSpec determined by finding the first dimension - in `shapedtype.shape` divisible by `self.axis_size` using the helper - `_shard_first_multiple_of`. + """Generates a `PartitionSpec` for a given tensor name and shaped type. + + **Args:** + + * `name` (`str`): The name of the tensor (e.g., parameter name). This + argument is provided for compatibility with more complex sharders but is + not used by this simple sharder. + * `shapedtype`: An object with a `.shape` attribute describing the + tensor's shape, and a `.dtype` attribute describing its dtype (e.g., + `jax.Array`, `jax.ShapeDtypeStruct`, or a `torch.Tensor`). + + **Returns:** + + A `jax.sharding.PartitionSpec` determined by finding the first dimension + in `shapedtype.shape` that is divisible by `self.axis_size`. """ del name sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape, @@ -91,36 +97,38 @@ def __call__(self, name, shapedtype): class Mesh: - """A helper class that wraps `jax.sharding.Mesh` object. - - The goal of this class is to provide helper methods that facilitate the - sharding of PyTorch tensors or models given a JAX device mesh configuration. - It simplifies initializing models directly into a sharded state. - - Attributes: - jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid - and axis names. - _sharder: The default sharding strategy callable (like SingleAxisSharder) - used to determine the PartitionSpec for each parameter if not overridden - during method calls. Can be None if no default is appropriate or set. + """A helper class that wraps a `jax.sharding.Mesh` object. + + This class provides helper methods for sharding PyTorch tensors and models + across a JAX device mesh, simplifying the process of initializing models + directly into a sharded state. + + **Attributes:** + + * `jax_mesh` (`jax.sharding.Mesh`): The underlying `jax.sharding.Mesh` object + that defines the device grid and axis names. + * `_sharder`: The default sharding strategy callable (like + `SingleAxisSharder`) used to determine the `PartitionSpec` for each + parameter if not overridden. """ @classmethod def fsdp_mesh(cls, axis_name="fsdp"): - """Creates a Mesh instance suitable for 1D FSDP-style sharding. + """Creates a `Mesh` instance suitable for 1D FSDP-style sharding. + + This method creates a 1D mesh that encompasses all available XLA devices and + assigns the specified `axis_name` to this dimension. It then creates a + `Mesh` instance with a `SingleAxisSharder` configured for this 1D mesh. - This named constructor creates a 1D mesh encompassing all available XLA - devices. It assigns the specified `axis_name` to this single dimension. - It then creates a `Mesh` instance using this JAX mesh and a - `SingleAxisSharder` configured appropriately for this 1D mesh. + **Args:** - Args: - axis_name: The name to assign to the single mesh axis (default: "fsdp"). - This name will be used by the default `SingleAxisSharder`. + * `axis_name` (`str`, optional): The name to assign to the single mesh + axis. Defaults to `"fsdp"`. - Returns: - A Mesh instance configured with a 1D JAX mesh across all devices and a - corresponding SingleAxisSharder. + **Returns:** + + A `Mesh` instance configured with a 1D JAX mesh and a corresponding + `SingleAxisSharder`. """ ndevice = jax.device_count() jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) @@ -128,19 +136,16 @@ def fsdp_mesh(cls, axis_name="fsdp"): return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) def __init__(self, jax_mesh, sharder=None): - """Initializes the Mesh helper. - - Args: - jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the - physical device grid and logical axis names. - sharder: An optional callable (e.g., an instance of SingleAxisSharder) - that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. - This serves as the default sharding strategy. - If None, and the provided `jax_mesh` has exactly one axis, a - `SingleAxisSharder` is created automatically for that single axis. - If None and the mesh has multiple axes, `_sharder` remains None, and - an `override_sharder` must be provided to methods like - `initialize_model_sharded`. + """Initializes the `Mesh` helper. + + **Args:** + + * `jax_mesh` (`jax.sharding.Mesh`): A pre-configured `jax.sharding.Mesh` + object that defines the physical device grid and logical axis names. + * `sharder` (optional): A callable that takes `(name, shapedtype)` and + returns a `jax.sharding.PartitionSpec`. This serves as the default + sharding strategy. If `None`, and the provided `jax_mesh` has exactly + one axis, a `SingleAxisSharder` is created automatically. """ self.jax_mesh = jax_mesh if sharder is None: @@ -156,35 +161,24 @@ def initialize_model_sharded(self, override_sharder=None): """Initializes a PyTorch model with its parameters sharded across the mesh. - This method orchestrates the initialization of a `torch.nn.Module` such - that its parameters are created directly on the target devices according - to the sharding specifications derived from the mesh and the chosen sharder. - It leverages `torchax.interop.jax_jit` to achieve this. + This method initializes a `torch.nn.Module` such that its parameters are + created directly on the target devices according to the sharding + specifications derived from the mesh and the chosen sharder. + + **Args:** - Args: - model_class: The PyTorch model class (a subclass of `torch.nn.Module`). - init_args: A tuple containing the positional arguments required by the + * `model_class`: The PyTorch model class (a subclass of `torch.nn.Module`). + * `init_args`: A tuple of positional arguments for the `model_class.__init__` + method. + * `init_kwargs` (optional): A dictionary of keyword arguments for the `model_class.__init__` method. - init_kwargs: An optional dictionary containing the keyword arguments for - the `model_class.__init__` method. Defaults to None (treated as {}). - override_sharder: An optional callable sharding strategy to use - specifically for this initialization. If provided, it takes precedence - over the mesh's default `_sharder`. It must accept `(name, shapedtype)` - and return a `PartitionSpec`. If None, the mesh's default `_sharder` - is used. - - Returns: - An instance of `model_class` whose parameters have been initialized and - are represented by sharded tensors distributed across the devices in the - `jax_mesh`. - - Raises: - ValueError: If no sharder is available (i.e., `override_sharder` is None - and the mesh's default `_sharder` is also None). - AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) - if it fails to determine a valid sharding for any parameter. - TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. - Other errors from JAX JIT compilation or PyTorch model initialization. + * `override_sharder` (optional): A callable sharding strategy to use for + this initialization, which takes precedence over the default sharder. + + **Returns:** + + An instance of `model_class` with its parameters initialized and sharded + across the devices in the `jax_mesh`. """ init_kwargs = init_kwargs or {} with torch.device("meta"), torchax.disable_temporarily(): @@ -211,6 +205,7 @@ def model_initializer(): return model def shard_model(self, model, override_sharder=None): + """Shards the parameters of an existing model across the mesh.""" sharder = override_sharder or self._sharder states = model.state_dict() output_shards = { diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py index 409a6d8350be..fd9fbabede95 100644 --- a/torchax/torchax/ops/mappings.py +++ b/torchax/torchax/ops/mappings.py @@ -8,6 +8,21 @@ def t2j(t, use_dlpack=True): + """Converts a `torch.Tensor` to a `jax.Array`. + + This function handles the conversion of a PyTorch tensor to a JAX array, + with an option to use DLPack for zero-copy conversion where possible. + + **Arguments:** + + * `t` (`torch.Tensor`): The PyTorch tensor to convert. + * `use_dlpack` (`bool`, optional): If `True`, attempts to use DLPack for + zero-copy conversion. Defaults to `True`. + + **Returns:** + + A `jax.Array` that is equivalent to the input tensor. + """ is_bool = False if t.dtype == torch.bool: is_bool = True @@ -43,6 +58,21 @@ def t2j(t, use_dlpack=True): def j2t(x, use_dlpack=True): + """Converts a `jax.Array` to a `torch.Tensor`. + + This function handles the conversion of a JAX array to a PyTorch tensor, + with an option to use DLPack for zero-copy conversion where possible. + + **Arguments:** + + * `x` (`jax.Array`): The JAX array to convert. + * `use_dlpack` (`bool`, optional): If `True`, attempts to use DLPack for + zero-copy conversion. Defaults to `True`. + + **Returns:** + + A `torch.Tensor` that is equivalent to the input array. + """ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): res = None if use_dlpack: @@ -126,6 +156,7 @@ def j2t(x, use_dlpack=True): def t2j_dtype(dtype): + """Converts a `torch.dtype` to a JAX dtype.""" if dtype not in TORCH_DTYPE_TO_JAX: raise RuntimeError( f'Attempting to convert unknown type: {dtype} to jax type,') @@ -133,6 +164,7 @@ def t2j_dtype(dtype): def j2t_dtype(dtype): + """Converts a JAX dtype to a `torch.dtype`.""" if dtype not in JAX_DTYPE_TO_TORCH: raise RuntimeError( f'Attempting to convert unknown type: {dtype} to torch type,') diff --git a/torchax/torchax/ops/op_base.py b/torchax/torchax/ops/op_base.py index d69e85ae50a6..9d1f6b585483 100644 --- a/torchax/torchax/ops/op_base.py +++ b/torchax/torchax/ops/op_base.py @@ -12,6 +12,21 @@ class InplaceOp: + """A wrapper for creating in-place versions of functional operators. + + This class takes a functional operator and creates an in-place version of it. + It handles the mutation of the input tensor, including the case where the + input is a `View`. + + **Attributes:** + + * `functional`: The functional operator to wrap. + * `replace` (`bool`): If `True`, the underlying `jax.Array` of the input + tensor is replaced with the new value. Otherwise, the new value is + copied into the input tensor. + * `position_to_mutate` (`int`): The position of the argument to be mutated. + * `is_jax_func` (`bool`): `True` if the functional operator is a JAX function. + """ def __init__(self, functional_op, @@ -51,6 +66,11 @@ def __call__(self, *args, **kwargs): class OutVariant: + """A wrapper for creating out-of-place versions of functional operators. + + This class takes a functional operator and creates an out-of-place version + that writes the result to the `out` keyword argument. + """ def __call__(self, *args, **kwargs): to_mutate = kwargs['out'] @@ -63,13 +83,16 @@ def __call__(self, *args, **kwargs): def convert_dtype(use_default_dtype: bool = True): - """Converts `dtype` kwarg of function from torch to JAX. + """A decorator that converts the `dtype` kwarg of a function from `torch.dtype` to a JAX dtype. + + **Args:** - Args: - use_default_dtype: Whether to use torch default dtype if none is provided. + * `use_default_dtype` (`bool`): If `True`, uses the default PyTorch dtype if + no `dtype` is provided. - Returns: - A decorator that wraps a JAX implementation of a torch function. + **Returns:** + + A decorator that wraps a JAX implementation of a PyTorch function. """ def decorator(func: types.TorchCallable): @@ -94,9 +117,10 @@ def wrapper(*args: P.args, def maybe_convert_constant_dtype(val: Optional[types.JaxValue], dtype: Optional[jnp.dtype]): - """Optionally converts scalar constant's dtype using `numpy` + """Optionally converts the dtype of a scalar constant using NumPy. - Use in cases where you require a constant and can't handle a traced array. + This function is useful in cases where you require a constant and cannot + handle a traced array. """ if val and dtype: if isinstance(val, jax.Array): @@ -108,7 +132,7 @@ def maybe_convert_constant_dtype(val: Optional[types.JaxValue], def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]): - """If the first argument is an int array, promote it to float32.""" + """A decorator that promotes the first integer input of a function to `float32`.""" @functools.wraps(f) def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): @@ -123,9 +147,11 @@ def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): def foreach_loop(seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0): - """Run `fn` for each element of 1D array `seq`. + """Applies a function to each element of a 1D array. - Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" + This function is similar to `functools.reduce`, but is implemented with + `jax.lax.fori_loop` for efficient execution on accelerators. + """ assert len(seq.shape) == 1 return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val) diff --git a/torchax/torchax/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py index aa0d61cbb491..0359879a747a 100644 --- a/torchax/torchax/ops/ops_registry.py +++ b/torchax/torchax/ops/ops_registry.py @@ -7,6 +7,22 @@ @dataclasses.dataclass class Operator: + """A dataclass that represents a `torchax` operator. + + This class holds the implementation of a PyTorch operator, along with + metadata that describes how it should be handled by the `torchax` dispatcher. + + **Attributes:** + + * `torch_op` (`TorchCallable`): The original PyTorch operator. + * `func` (`Union[TorchCallable, JaxCallable]`): The implementation of the + operator, which can be either a PyTorch callable or a JAX callable. + * `is_jax_function` (`bool`): `True` if the implementation is a JAX function. + * `is_user_defined` (`bool`): `True` if the operator is defined by the user. + * `needs_env` (`bool`): `True` if the operator needs access to the `torchax` + environment. + * `is_view_op` (`bool`): `True` if the operator is a view operation. + """ torch_op: TorchCallable func: Union[TorchCallable, JaxCallable] is_jax_function: bool @@ -25,6 +41,28 @@ def register_torch_dispatch_op(aten_op, is_user_defined=False, needs_env=False, is_view_op=False): + """Registers a `torch_dispatch` operator. + + This function is used to register an implementation for a PyTorch ATen + operator. + + **Arguments:** + + * `aten_op`: The ATen operator to register (e.g., `torch.ops.aten.add`). + * `impl_callable`: The implementation of the operator. + * `is_jax_function` (`bool`, optional): `True` if the implementation is a JAX + function. Defaults to `True`. + * `is_user_defined` (`bool`, optional): `True` if the operator is defined by + the user. Defaults to `False`. + * `needs_env` (`bool`, optional): `True` if the operator needs access to the + `torchax` environment. Defaults to `False`. + * `is_view_op` (`bool`, optional): `True` if the operator is a view + operation. Defaults to `False`. + + **Returns:** + + The implementation callable. + """ op = Operator( aten_op, impl_callable, @@ -44,6 +82,28 @@ def register_torch_function_op(torch_func, is_user_defined=False, needs_env=False, is_view_op=False): + """Registers a `torch_function` operator. + + This function is used to register an implementation for a `torch_function` + operator (e.g., `torch.add`). + + **Arguments:** + + * `torch_func`: The `torch_function` operator to register. + * `impl_callable`: The implementation of the operator. + * `is_jax_function` (`bool`, optional): `True` if the implementation is a JAX + function. Defaults to `True`. + * `is_user_defined` (`bool`, optional): `True` if the operator is defined by + the user. Defaults to `False`. + * `needs_env` (`bool`, optional): `True` if the operator needs access to the + `torchax` environment. Defaults to `False`. + * `is_view_op` (`bool`, optional): `True` if the operator is a view + operation. Defaults to `False`. + + **Returns:** + + The implementation callable. + """ op = Operator( torch_func, impl_callable, diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 3916fe6501b8..867f626e9b65 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -38,6 +38,18 @@ def log_nested(env, message): class Tensor(torch.Tensor): + """A `torch.Tensor` subclass that wraps a `jax.Array`. + + This class is the core of `torchax`, allowing PyTorch operations to be + dispatched to JAX. It holds a `jax.Array` internally and overrides + the necessary methods to ensure that operations are correctly routed + through the `torchax` dispatch mechanism. + + **Attributes:** + + * `_elem` (`jax.Array`): The underlying JAX array. + * `_env` (`Environment`): The `torchax` environment this tensor belongs to. + """ @staticmethod def __new__(cls, elem, env, requires_grad=False): @@ -113,17 +125,21 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 'call torchax.enable_globally() before.') def detach(self): + """Detaches the tensor from the computation graph.""" return Tensor(jax.lax.stop_gradient(self.jax()), self._env) def numpy(self) -> numpy.ndarray: + """Converts the tensor to a NumPy array.""" import numpy as np return np.array(self._elem) def jax(self) -> jax.Array: + """Returns the underlying `jax.Array`.""" return self._elem def torch(self) -> torch.Tensor: + """Converts the tensor to a standard `torch.Tensor`.""" return self._env.j2t_copy(self.jax()) @property @@ -153,18 +169,22 @@ def data(self, other): self._elem = other._elem def apply_jax(self, jax_function, *args, **kwargs): + """Applies a JAX function to the underlying `jax.Array`.""" # Call a jax function on _elem res = jax_function(self._elem, *args, **kwargs) return self._env.j2t_iso(res) def apply_jax_(self, jax_function, *args, **kwargs): + """Applies a JAX function in-place to the underlying `jax.Array`.""" self._elem = jax_function(self._elem, *args, **kwargs) return self def tolist(self): + """Converts the tensor to a list.""" return self._elem.tolist() def shard_(self, sharding): + """Applies a sharding constraint to the tensor in-place.""" self.apply_jax_(jax.lax.with_sharding_constraint, sharding) @@ -247,6 +267,8 @@ def __torch_function__(self, class XLADispatchMode(torch_dispatch.TorchDispatchMode): + """A `TorchDispatchMode` that intercepts PyTorch operations and dispatches them to the JAX backend through the `Environment`. + """ def __init__(self, env): self.env = env @@ -329,16 +351,12 @@ def __getattr__(self, name): class Environment(contextlib.ContextDecorator): - """This class holds a set of configurations and "globals" needed - - for executing torch program using jax. - Things included so far: + """Manages the execution environment for `torchax`. - op registry - PRNGKey - Configs - - Also helper functions to manipulate those. + This class holds the configuration, operator registry, PRNG key, and other + "global" state needed to execute PyTorch programs using the JAX backend. + It also provides helper functions for dispatching operations and converting + tensors between PyTorch and JAX representations. """ def __init__(self, configuration=None): @@ -370,6 +388,7 @@ def param(self): return self._property.content[-1] def manual_seed(self, key): + """Sets the seed for the JAX random number generator.""" jax_key = jax.random.PRNGKey(key) new_prop = self.param.override(prng=jax_key) self._property.content.append(new_prop) @@ -522,6 +541,7 @@ def _torch_Tensor_to(self, args, kwargs): return self._to_copy(the_tensor, dtype, device) def dispatch(self, func, types, args, kwargs): + """Dispatches a PyTorch operation to the appropriate JAX implementation.""" kwargs = kwargs or {} if func in TENSOR_CONSTRUCTORS: return self._handle_tensor_constructor(func, args, kwargs) @@ -600,11 +620,13 @@ def is_not_torchax_tensor(x): return res def enable_torch_modes(self): + """Enables the `torchax` dispatch modes.""" self._dispatch_mode.__enter__() self._function_mode.__enter__() self.enabled = True def disable_torch_modes(self, *exc): + """Disables the `torchax` dispatch modes.""" if not exc: exc = (None, None, None) self._function_mode.__exit__(*exc) @@ -634,10 +656,12 @@ def to_xla(self, torchvalues): return res def t2j_iso(self, torchtensors): - """Convert torchax Tensor to jax array. - - This function will not copy, will just unwrap the inner jax array out. - Note: iso is short for "isomorphic" + """Converts `torchax.Tensor`s to `jax.Array`s without copying. + + This function unwraps the underlying `jax.Array` from each `torchax.Tensor` + in the input pytree. + + Note: "iso" is short for "isomorphic". """ def to_jax(x): @@ -657,6 +681,7 @@ def to_jax(x): return res def v2t_iso(self, views): + """Converts `torchax.View`s to `torchax.Tensor`s without copying.""" def to_tensor(x): if isinstance(x, View): @@ -667,18 +692,18 @@ def to_tensor(x): return res def j2t_iso(self, jaxarray): - """Convert jax array to torchax Tensor. - - This function will not copy, will just wrap the jax array with a torchax Tensor - Note: iso is short for "isomorphic" + """Converts `jax.Array`s to `torchax.Tensor`s without copying. + + This function wraps each `jax.Array` in the input pytree with a + `torchax.Tensor`. + + Note: "iso" is short for "isomorphic". """ return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), jaxarray) def j2t_copy(self, args): - """Convert torch.Tensor in cpu to a jax array - - This might involves copying the data (depending if dlpack is enabled) + """Converts `jax.Array`s to `torch.Tensor`s on the CPU, potentially copying the data. """ return torch_pytree.tree_map_only( jax.Array, @@ -686,9 +711,7 @@ def j2t_copy(self, args): args) def t2j_copy(self, args): - """Convert jax array to torch.Tensor in cpu. - - This might involves copying the data (depending if dlpack is enabled) + """Converts `torch.Tensor`s to `jax.Array`s, potentially copying the data. """ return torch_pytree.tree_map_only( torch.Tensor, @@ -696,6 +719,7 @@ def t2j_copy(self, args): args) def override_op_definition(self, op_to_override, op_impl): + """Overrides the implementation of a PyTorch operator.""" self._ops[op_to_override] = ops_registry.Operator( op_to_override, op_impl, @@ -706,6 +730,7 @@ def override_op_definition(self, op_to_override, op_impl): @contextlib.contextmanager def override_property(self, **kwargs): + """A context manager to temporarily override properties of the environment.""" new_prop = self.param.override(**kwargs) self._property.content.append(new_prop) yield diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py index c9842089bfcf..0575e6f0b812 100644 --- a/torchax/torchax/tf_integration.py +++ b/torchax/torchax/tf_integration.py @@ -9,6 +9,21 @@ def exported_program_to_tf_function(ep, enable_xla=True): + """Converts a `torch.export.ExportedProgram` to a TensorFlow function. + + This function takes a PyTorch `ExportedProgram`, converts it to a JAX program, + and then wraps it as a TensorFlow function using `jax2tf`. + + **Arguments:** + + * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to convert. + * `enable_xla` (`bool`, optional): Whether to enable XLA compilation for the + converted TensorFlow function. Defaults to `True`. + + **Returns:** + + A TensorFlow function that is equivalent to the input `ExportedProgram`. + """ weights, jax_program = export.exported_program_to_jax(ep) wrapped = lambda *args: jax_program(weights, (args,)) avals = export.extract_avals(ep) @@ -30,6 +45,21 @@ def exported_program_to_tf_function(ep, enable_xla=True): def exported_program_to_tf_module(ep: torch.export.ExportedProgram, enable_xla=True) -> tf.Module: + """Converts a `torch.export.ExportedProgram` to a `tf.Module`. + + This function wraps the TensorFlow function created by + `exported_program_to_tf_function` in a `tf.Module` for easier use and saving. + + **Arguments:** + + * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to convert. + * `enable_xla` (`bool`, optional): Whether to enable XLA compilation. Defaults + to `True`. + + **Returns:** + + A `tf.Module` containing the converted TensorFlow function. + """ tfm = tf.Module() tfm.f = exported_program_to_tf_function(ep, enable_xla) return tfm @@ -42,22 +72,23 @@ def save_exported_program_as_tf_saved_model( function_alias: str = "", enable_xla=True, ): - """This function will export and save a pytorch ExportedProgram to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. + """Exports and saves a PyTorch `ExportedProgram` to the TensorFlow SavedModel format. + + The resulting SavedModel can be used for inference with TensorFlow Serving or + further converted to TFLite for on-device deployment. + + **Arguments:** + + * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to save. + * `saved_model_dir` (`os.PathLike`): The path to an empty directory where the + SavedModel will be stored. + * `serving_key` (`str`, optional): The serving key to use for the signature + definition. This is used by TensorFlow Serving to identify the function + to run. Defaults to `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. + * `function_alias` (`str`, optional): An alias for the function, which can be + used by other tools. + * `enable_xla` (`bool`, optional): Whether to enable XLA compilation. Defaults + to `True`. """ tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) signatures = { @@ -82,22 +113,22 @@ def save_torch_module_as_tf_saved_model( function_alias: str = "", enable_xla=True, ): - """This function will export and save a pytorch nn.Module to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. + """Exports and saves a `torch.nn.Module` to the TensorFlow SavedModel format. + + This function first exports the `torch.nn.Module` to an `ExportedProgram` + and then saves it as a SavedModel. + + **Arguments:** + + * `torch_model` (`torch.nn.Module`): The PyTorch model to export and save. + * `args` (`Tuple[Any]`): A tuple of arguments to trace the model with (i.e., + `torch_model(*args)` must be a valid call). + * `saved_model_dir` (`os.PathLike`): The path to an empty directory where the + SavedModel will be stored. + * `serving_key` (`str`, optional): The serving key for the signature + definition. + * `function_alias` (`str`, optional): An alias for the function. + * `enable_xla` (`bool`, optional): Whether to enable XLA compilation. """ ep = torch.export.export(torch_model, args) save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key, @@ -105,6 +136,16 @@ def save_torch_module_as_tf_saved_model( def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): + """Converts a `torch.export.ExportedProgram` to a TFLite flatbuffer. + + **Arguments:** + + * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to convert. + + **Returns:** + + A TFLite flatbuffer model. + """ tfm = exported_program_to_tf_module(ep) tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) converter = tf.lite.TFLiteConverter.from_concrete_functions( @@ -115,5 +156,16 @@ def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module, args: Tuple[Any]): + """Converts a `torch.nn.Module` to a TFLite flatbuffer. + + **Arguments:** + + * `torch_model` (`torch.nn.Module`): The PyTorch model to convert. + * `args` (`Tuple[Any]`): A tuple of arguments to trace the model with. + + **Returns:** + + A TFLite flatbuffer model. + """ ep = torch.export.export(torch_model, args) return exported_program_to_tflite_flatbuffer(ep) diff --git a/torchax/torchax/train.py b/torchax/torchax/train.py index fb4e16fc48ee..6f7ea24576dc 100644 --- a/torchax/torchax/train.py +++ b/torchax/torchax/train.py @@ -12,21 +12,34 @@ def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None): - """Make a function that do one train step given model and loss. - - model_fn: a function representing the model's forward: - i.e. has signature Callable[weights, buffers, args] -> result. Where, - weights is a pytree of trainable parameters - buffers is a pytree of non-trainable parameters / constants - args is the input data loaded from the data set - result is the return value of the model - loss_fn: a function to compute loss. - i.e. it has signature of Callable[result, label] -> loss - where, result is what model_fn returned - loss is loaded from the dataloader. - optax_optimizer: the optimizer from optax library. for example, optax.adam - remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how - to do gradient checkpointing. If None, then it means checkpoint everything. + """Creates a function that performs one training step. + + This function is designed to be used with JAX's `jit` for efficient training. + It takes a model function, a loss function, and an Optax optimizer, and + returns a function that computes the loss, calculates gradients, and updates + the model's weights. + + **Arguments:** + + * `model_fn`: A function representing the model's forward pass. It should + have the signature `Callable[weights, buffers, args] -> result`, where: + * `weights` is a pytree of trainable parameters. + * `buffers` is a pytree of non-trainable parameters and constants. + * `args` is the input data from the dataset. + * `result` is the model's output. + * `loss_fn`: A function to compute the loss. It should have the signature + `Callable[result, label] -> loss`, where: + * `result` is the output of `model_fn`. + * `label` is the ground truth from the dataloader. + * `optax_optimizer`: An optimizer from the Optax library (e.g., `optax.adam`). + * `remat_policy` (optional): A policy from `jax.ad_checkpoint.checkpoint_policies` + that specifies how to perform gradient checkpointing. If `None`, all + intermediate activations will be checkpointed. + + **Returns:** + + A function that performs one training step. It has the signature + `Callable[weights, buffers, opt_state, args, label] -> (loss, new_weights, new_opt_state)`. """ env = torchax.default_env() @@ -58,6 +71,18 @@ class Container: class ScannedModule(torch.nn.Module): + """A `torch.nn.Module` that applies a list of identical modules sequentially. + + This module is designed to be used with `jax.lax.scan` for efficient + execution of repeated layers. It takes a list of modules, stacks their + weights, and applies the same module function to the input in a loop. + + **Attributes:** + + * `checkpoint_policy`: The gradient checkpointing policy to use. + * `params`: A `torch.nn.ParameterDict` containing the stacked weights of the + input modules. + """ def __init__(self, module_list, checkpoint_policy=None): super().__init__() diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index 040fa24ef9e8..a2fcc0d9831b 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -22,77 +22,77 @@ class ViewInfoType(Enum): class ViewInfo(ABC): + """Abstract base class for all view operations. + + This class defines the interface for applying and updating view transformations + on JAX arrays. Each subclass represents a specific type of view, such as + a slice, reshape, or permutation. """ - Abstract base class for all view operations. - Defines the interface for applying and updating view transformations. - """ def __init__( self, view_info_type: ViewInfoType = ViewInfoType.INVALID, ): - """ - Initialize a ViewInfo object. + """Initializes a ViewInfo object. - Args: - view_info_type: The type of view operation - """ + Args: + view_info_type: The type of view operation. + """ self.view_info_type = view_info_type @abstractmethod def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array: - """ - Apply this view transformation to a JAX array and update its value. + """Applies this view transformation to a JAX array and updates its value. - Args: - new_value: The new values to set in the view - jax_array: The parent array to update + Args: + new_value: The new values to set in the view. + jax_array: The parent array to update. - Returns: - Updated array - """ + Returns: + The updated array. + """ pass @abstractmethod def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - """ - Apply this view transformation to a JAX array. + """Applies this view transformation to a JAX array. - Args: - jax_array: The array to transform + Args: + jax_array: The array to transform. - Returns: - Transformed array - """ + Returns: + The transformed array. + """ pass @abstractmethod def calculate_output_shape(self, source: jax.Array) -> List[int]: - """ - Calculate the resulting shape after applying this view. + """Calculates the resulting shape after applying this view. - Args: - source: Original jax array before transformation + Args: + source: The original JAX array before transformation. - Returns: - Resulting shape after transformation - """ + Returns: + The resulting shape after transformation. + """ pass class NarrowInfo(ViewInfo): + """Represents a slicing operation on a tensor. + + This class handles operations like `tensor[1:3, :, 2:5:2]`. """ - Represents a slicing operation on a tensor. - Handles operations like tensor[1:3, :, 2:5:2]. - """ def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: + """Initializes a NarrowInfo object. + + Args: + slices: The slice(s) to apply to the tensor. + For example, `jax_array.at[slices]` will return the + transformed tensor. """ - Args: - slices: The slice(s) to apply to the tensor. - E.g. jax_array.at[slices] will return the transformed tensor. - """ super().__init__(ViewInfoType.NARROW) self.slices = slices @@ -116,10 +116,10 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class SelectInfo(ViewInfo): + """Represents a selection operation on a tensor. + + This is typically used for indexing operations that select specific elements. """ - Represents a selection operation on a tensor. - Typically used for indexing operations that select specific elements. - """ def __init__(self, dim: int = 0, @@ -151,9 +151,7 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class AsStridedInfo(ViewInfo): - """ - Information for as_strided operations. - """ + """Represents an `as_strided` operation on a tensor.""" def __init__(self, stride: List[int], offset: int = 0) -> None: super().__init__(ViewInfoType.AS_STRIDED) @@ -178,18 +176,19 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class DiagonalInfo(ViewInfo): + """Represents a diagonal operation on a tensor. + + This class is used to extract diagonal elements from a tensor. """ - Information for diagonal operations. - Extracts diagonal elements from a tensor. - """ def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: + """Initializes a DiagonalInfo object. + + Args: + offset: The offset from the main diagonal. + dim1: The first dimension for diagonal extraction. + dim2: The second dimension for diagonal extraction. """ - Args: - offset: Offset from the main diagonal - dim1: First dimension for diagonal extraction - dim2: Second dimension for diagonal extraction - """ super().__init__(ViewInfoType.DIAGONAL) self.offset: int = offset self.dim1: int = dim1 @@ -214,20 +213,24 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class View(torch.Tensor): + """A `torch.Tensor` subclass that represents a view of another tensor. + + A `View` holds a reference to a parent `torchax.Tensor` or another `View`, + along with a `ViewInfo` object that describes the transformation to be + applied. This allows for lazy evaluation of view operations and efficient + in-place updates. """ - A View is a reference to another Tensor or another View, - with a transformation applied to it. - """ @staticmethod def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo, env: Any) -> "View": + """Creates a new `View` object. + + Args: + parent: The parent tensor or view. + view_info: Information about the view transformation. + env: The `torchax` environment. """ - Args: - parent: Parent tensor or view - view_info: Information about the view transformation - env: Environment for tensor operations - """ shape = view_info.calculate_output_shape(parent.jax()) return torch.Tensor._make_wrapper_subclass( cls, @@ -245,9 +248,7 @@ def __init__(self, parent: Union["torchax.Tensor", "View"], self._env = env def get_transformation_chain(self) -> List[ViewInfo]: - """ - Get all view transformations from the source tensor to this view. - """ + """Returns the chain of view transformations from the source tensor to this view.""" if isinstance(self.parent, View): transformations = self.parent.get_transformation_chain() transformations.append(self.view_info) @@ -258,18 +259,14 @@ def get_transformation_chain(self) -> List[ViewInfo]: __torch_function__ = torch._C._disabled_torch_function_impl def source_jax(self) -> jax.Array: - """ - Returns the source tensor. - """ + """Returns the underlying `jax.Array` of the source tensor.""" if isinstance(self.parent, View): return self.parent.source_jax() else: return self.parent.jax() def replace_source_jax(self, new_value: jax.Array) -> None: - """ - Update the source tensor with new values. - """ + """Updates the source tensor with a new `jax.Array`.""" if isinstance(self.parent, View): self.parent.replace_source_jax(new_value) else: @@ -277,9 +274,7 @@ def replace_source_jax(self, new_value: jax.Array) -> None: self.parent._elem = new_value def torch(self) -> "torchax.Tensor": - """ - Returns a Torchax tensor representing this view after all transformations - """ + """Returns a `torchax.Tensor` representing this view after all transformations.""" from torchax.tensor import Tensor return Tensor(self.jax(), self._env) @@ -289,11 +284,11 @@ def update( new_values: Union[jax.Array, "View", "torchax.Tensor"], view_infos: Optional[List[ViewInfo]] = None, ) -> None: + """Updates this view with new values, propagating changes back to the source. + + If `view_infos` is not provided, it will use the transformation chain + from the source tensor. """ - Update this view with new values, propagating changes back to source. - If view_infos is None, it will use the transformation chain - from the source tensor. - """ if view_infos is None: view_infos = self.get_transformation_chain() @@ -338,18 +333,14 @@ def __torch_dispatch__( 'call torchax.enable_globally() before.') def create_sub_view(self, view_info: ViewInfo) -> "View": - """ - Create a new view that is a child of this view. - """ + """Creates a new view that is a child of this view.""" return View(self, view_info, self._env) def __str__(self) -> str: return f"View({self.torch()})" def jax(self) -> jax.Array: - """ - Returns a copy of the source tensor after transformations. - """ + """Returns a copy of the source tensor after all transformations have been applied.""" result = self.source_jax() for view_info in self.get_transformation_chain(): result = view_info.transform_tensor(result) From 2889f69e94aa4f5b7d9ec70e3d03a789a2877d31 Mon Sep 17 00:00:00 2001 From: qihqi Date: Mon, 4 Aug 2025 08:07:47 -0700 Subject: [PATCH 035/133] Revert 2 accidental commits that I made. (#9536) --- torchax/torchax/CONTRIBUTING.md | 38 ++++++ torchax/torchax/__init__.py | 79 +------------ torchax/torchax/amp.py | 13 --- torchax/torchax/config.py | 30 ----- torchax/torchax/decompositions.py | 23 ++-- torchax/torchax/device_module.py | 8 -- torchax/torchax/export.py | 62 ++-------- torchax/torchax/flax.py | 25 ---- torchax/torchax/interop.py | 99 +--------------- torchax/torchax/mesh_util.py | 173 ++++++++++++++-------------- torchax/torchax/ops/mappings.py | 32 ----- torchax/torchax/ops/op_base.py | 46 ++------ torchax/torchax/ops/ops_registry.py | 60 ---------- torchax/torchax/tensor.py | 71 ++++-------- torchax/torchax/tf_integration.py | 116 +++++-------------- torchax/torchax/train.py | 55 +++------ torchax/torchax/view.py | 153 ++++++++++++------------ 17 files changed, 313 insertions(+), 770 deletions(-) create mode 100644 torchax/torchax/CONTRIBUTING.md diff --git a/torchax/torchax/CONTRIBUTING.md b/torchax/torchax/CONTRIBUTING.md new file mode 100644 index 000000000000..c61462850652 --- /dev/null +++ b/torchax/torchax/CONTRIBUTING.md @@ -0,0 +1,38 @@ +# Contributing to TorchXLA2 + +We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels. + +If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. + + +# Developer setup + +## Mac setup: +@qihqi + +I am able to develop directly on mac (m1) laptop for most of parts. Using steps +in README.md works. The condensed version for easy copy & paste: + +```bash +conda create --name python=3.10 +conda activate +pip install --upgrade "jax[cpu]" torch +pip install -r test_requirements.txt +pip install -e . +pytest test +``` + +### VSCode + +I use vscode on my Mac. I loosely followed instruction in +https://code.visualstudio.com/docs/python/python-tutorial +to setup a proper python environment. + +The plugins I installed (a subset of the ones listed above) are: +* VSCode's official Python plugin +* Ruff formatter +* Python Debugger + +I also changed Python interpreter to point at the one in my conda env. +That is all the changes I have. + diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index 2cbcc3ad36e9..fe4c1c8ff046 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -40,30 +40,7 @@ def default_env(): def extract_jax(mod: torch.nn.Module, env=None): - """Extracts the state of a `torch.nn.Module` into a JAX-compatible format. - - **Arguments:** - - * `mod` (`torch.nn.Module`): The PyTorch model to extract the state from. - * `env` (optional): The `torchax` environment to use. If not provided, the default environment is used. - - **Returns:** - - A tuple containing: - - * A `pytree` of `jax.ndarray` representing the model's state (parameters and buffers). - * A JAX-callable function that executes the model's forward pass. - - **Usage:** - - ```python - import torch - import torchax - - model = torch.nn.Linear(10, 20) - states, jax_func = torchax.extract_jax(model) - ``` - """ + """Returns a pytree of jax.ndarray and a jax callable.""" if env is None: env = default_env() states = dict(mod.named_buffers()) @@ -83,31 +60,11 @@ def jax_func(states, args, kwargs=None): def enable_globally(): - """Enables `torchax` globally, which intercepts PyTorch operations and routes them to the JAX backend. This is the primary entry point for using `torchax`. - - **Usage:** - - ```python - import torchax - - torchax.enable_globally() - ``` - """ env = default_env().enable_torch_modes() return env def disable_globally(): - """Disables the `torchax` backend. After calling this, PyTorch operations will revert to their default behavior. - - **Usage:** - - ```python - import torchax - - torchax.disable_globally() - ``` - """ global env default_env().disable_torch_modes() @@ -153,40 +110,6 @@ class CompileOptions: def compile(fn, options: Optional[CompileOptions] = None): - """Compiles a function or `torch.nn.Module` for optimized execution with JAX. - - **Arguments:** - - * `fn`: The function or `torch.nn.Module` to compile. - * `options` (`CompileOptions`, optional): A `CompileOptions` object to configure the compilation process. - - **`CompileOptions`:** - - * `methods_to_compile` (`List[str]`, default=`['forward']`): A list of methods to compile when `fn` is a `torch.nn.Module`. - * `jax_jit_kwargs` (`Dict[str, Any]`, default=`{}`): A dictionary of keyword arguments to pass to `jax.jit`. - * `mode` (`str`, default=`'jax'`): The compilation mode. Currently, only `'jax'` is supported. - - **Returns:** - - A compiled version of the input function or module. - - **Usage:** - - ```python - import torch - import torchax - - model = torch.nn.Linear(10, 20) - compiled_model = torchax.compile(model) - - # With options - options = torchax.CompileOptions( - methods_to_compile=['forward', 'encode'], - jax_jit_kwargs={'static_argnums': (0,)} - ) - compiled_model = torchax.compile(model, options) - ``` - """ options = options or CompileOptions() if options.mode == 'jax': from torchax import interop diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py index 0fc2bb1b23b4..ccbc63bead63 100644 --- a/torchax/torchax/amp.py +++ b/torchax/torchax/amp.py @@ -57,19 +57,6 @@ def is_float(a): @contextlib.contextmanager def autocast(device, dtype=torch.bfloat16, env=None): - """A context manager for automatic mixed precision (AMP). - - This context manager enables automatic mixed precision, which can improve - performance by using lower-precision data types for certain operations. - - **Arguments:** - - * `device`: The device to use for autocasting (e.g., "cuda", "cpu"). - * `dtype` (`torch.dtype`, optional): The lower-precision data type to use. - Defaults to `torch.bfloat16`. - * `env` (optional): The `torchax` environment. If not provided, the default - environment is used. - """ del device if env is None: import torchax diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py index 62fd98ea8585..f439c656287b 100644 --- a/torchax/torchax/config.py +++ b/torchax/torchax/config.py @@ -3,36 +3,6 @@ @dataclasses.dataclass class Configuration: - """A dataclass for configuring the behavior of `torchax`. - - **Attributes:** - - * `debug_print_each_op` (`bool`): If `True`, prints each operation as it is - dispatched. - * `debug_accuracy_for_each_op` (`bool`): If `True`, checks the accuracy of - each operation by comparing its output with the equivalent PyTorch - operation on the CPU. - * `debug_mixed_tensor` (`bool`): If `True`, enables debugging for mixed - tensor operations. - * `debug_print_each_op_operands` (`bool`): If `True`, prints the operands of - each operation. - * `use_int32_for_index` (`bool`): If `True`, uses `int32` for indexing - operations. - * `allow_mixed_math_with_scalar_tensor` (`bool`): If `True`, allows mixed - math operations between `torchax.Tensor` and scalar `torch.Tensor`s. - * `force_materialize_views` (`bool`): If `True`, eagerly materializes `View` - objects into `torchax.Tensor`s. - * `use_dlpack_for_data_conversion` (`bool`): If `True`, uses DLPack for - converting between `jax.Array` and `torch.Tensor`. - * `use_tpu_flash_attention` (`bool`): If `True`, uses TPU-optimized flash - attention. - * `shmap_flash_attention` (`bool`): If `True`, uses `shard_map` for flash - attention. - * `treat_cuda_as_jax_device` (`bool`): If `True`, treats CUDA devices as JAX - devices. - * `internal_respect_torch_return_dtypes` (`bool`): If `True`, respects the - return data types of PyTorch operations. - """ debug_print_each_op: bool = False debug_accuracy_for_each_op: bool = False debug_mixed_tensor: bool = False diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index cbae0714bd12..d1c1f463d88a 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -1,10 +1,10 @@ -"""This file contains PyTorch operator decompositions that are not available in -the stable version of PyTorch. +"""This file contains some decompositons that are not available in torch stable. -The decompositions are primarily sourced from the `main` branch of the PyTorch -repository and are included here to provide support for newer operators. This -module can also contain decompositions of a PyTorch op in terms of other -PyTorch ops. +Most likely from Content of +https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py +at main branch HEAD that we find useful here. + +Can also contain decompositions of a torch op in terms of other torch ops. """ import functools @@ -104,7 +104,6 @@ def _reflection_or_replication_pad( def bernoulli(self, *, generator=None): - """Decomposition for the `bernoulli` operator.""" return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) @@ -112,13 +111,11 @@ def bernoulli(self, *, generator=None): def rand_like(self, **kwargs): - """Decomposition for the `rand_like` operator.""" dtype = kwargs.get("dtype", self.dtype) return torch.rand(self.shape, dtype=dtype) def channel_shuffle(self, groups): - """Decomposition for the `channel_shuffle` operator.""" batchsize, channels, height, width = self.shape channels_per_group = channels // groups self = self.reshape(batchsize, groups, channels_per_group, height, width) @@ -134,7 +131,6 @@ def channel_shuffle(self, groups): def bernoulli_float(self, p=0.5): - """Decomposition for the `bernoulli_` operator with a float probability.""" return self.bernoulli_(p) @@ -154,10 +150,9 @@ def _grid_sampler_3d( padding_mode: int = 0, align_corners: bool = False, ) -> Tensor: - """Decomposition for the `grid_sampler_3d` operator. + """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 - This implementation is based on the 2D version in the PyTorch repository: - https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 + The above implement the 2d case. """ _expand_grid = False torch._check( @@ -778,4 +773,4 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, MUTABLE_DECOMPOSITION = [ torch.ops.aten.bernoulli_.Tensor, torch.ops.aten.bernoulli_.float, -] \ No newline at end of file +] diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py index 41c98b5cad42..be028cfcc21d 100644 --- a/torchax/torchax/device_module.py +++ b/torchax/torchax/device_module.py @@ -2,40 +2,32 @@ def _is_in_bad_fork(): - """Returns `False` as forking is not applicable in the same way as CUDA.""" return False def manual_seed_all(seed): - """A placeholder for API compatibility; does not affect JAX's PRNG.""" pass def device_count(): - """Returns `1` as JAX manages devices as a single logical device.""" return 1 def get_rng_state(): - """Returns an empty list for API compatibility.""" return [] def set_rng_state(new_state, device): - """A placeholder for API compatibility; does not affect JAX's PRNG.""" pass def is_available(): - """Returns `True` if JAX is available.""" return True def current_device(): - """Returns `0` as JAX manages devices as a single logical device.""" return 0 def get_amp_supported_dtype(): - """Returns the data types supported by AMP (Automatic Mixed Precision).""" return [torch.float16, torch.bfloat16] diff --git a/torchax/torchax/export.py b/torchax/torchax/export.py index be2da17e1cb8..987fb92ba6ee 100644 --- a/torchax/torchax/export.py +++ b/torchax/torchax/export.py @@ -16,13 +16,7 @@ class JaxInterpreter(torch.fx.Interpreter): - """An `fx.Interpreter` that executes a PyTorch FX graph using JAX. - - This interpreter traverses an FX graph and replaces PyTorch operations with - their corresponding JAX implementations from the `torchax` operator registry. - It is a key component in the process of exporting PyTorch models to JAX and - StableHLO. - """ + """Experimental.""" def __init__(self, graph_module): super().__init__(graph_module) @@ -80,24 +74,11 @@ def _extract_states_from_exported_program(exported_model): def exported_program_to_jax(exported_program, export_raw: bool = False): - """Converts a `torch.export.ExportedProgram` to a JAX-compatible function and state. - - This function takes a PyTorch `ExportedProgram`, runs the necessary - decompositions, and returns a JAX-compatible function and the model's state - (parameters and buffers) as JAX arrays. - - **Arguments:** + """returns a pytree of jax arrays(state), and - * `exported_program` (`torch.export.ExportedProgram`): The PyTorch - `ExportedProgram` to convert. - * `export_raw` (`bool`, optional): If `True`, returns the raw states and - function without converting them to JAX arrays. Defaults to `False`. + a callable(func) that is jax function. - **Returns:** - - A tuple containing: - * A pytree of JAX arrays representing the model's state. - * A JAX-callable function that takes the state and inputs as arguments. + func(state, input) would be how you call it. """ if torch.__version__ >= '2.2': # torch version 2.1 didn't expose this yet @@ -134,19 +115,8 @@ def func(states, inputs): def extract_avals(exported): - """Returns JAX abstract values (`ShapeDtypeStruct`) for all input parameters of the exported program. - - This function supports dynamic batch dimensions, including those with - constraints. - - **Arguments:** - - * `exported` (`torch.export.ExportedProgram`): The exported PyTorch program. - - **Returns:** - - A list of `jax.ShapeDtypeStruct` objects representing the abstract values of - the input parameters. + """Return JAX Abstract Value shapes for all input parameters of the exported + program. This supports dynamic batch dimensions, including with constraints. """ def _to_aval(arg_meta, symbolic_shapes): @@ -262,24 +232,12 @@ def _build_symbolic_shape(sym, constraint, free_symbols): def exported_program_to_stablehlo(exported_program): - """Converts a `torch.export.ExportedProgram` to StableHLO. - - This function serves as a replacement for - `torch_xla.stablehlo.exported_program_to_stablehlo`. It supports dynamic - dimension sizes and generates explicit checks for Dynamo guards in the IR - using `shape_assertion` custom calls. - - **Arguments:** - - * `exported_program` (`torch.export.ExportedProgram`): The exported PyTorch - program. + """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo - **Returns:** + Convert a program exported via torch.export to StableHLO. - A tuple containing: - * The model's state (weights) as a pytree of JAX arrays. - * A `jax.export.Exported` object containing the StableHLO representation of - the model. + This supports dynamic dimension sizes and generates explicit checks for + dynamo guards in the IR using shape_assertion custom_call ops. """ weights, func = exported_program_to_jax(exported_program) jax_avals = extract_avals(exported_program) diff --git a/torchax/torchax/flax.py b/torchax/torchax/flax.py index 4f3dce83e587..28542d79c90e 100644 --- a/torchax/torchax/flax.py +++ b/torchax/torchax/flax.py @@ -6,32 +6,8 @@ class FlaxNNModule(torch.nn.Module): - """A `torch.nn.Module` that wraps a Flax module for interoperability. - - This class allows you to use a Flax module within a PyTorch model. It - initializes the Flax module, extracts its parameters, and wraps them in a - `torch.nn.ParameterDict` so they can be managed by PyTorch. The `forward` - pass then calls the Flax module's `apply` method with the appropriate - parameters. - - **Attributes:** - - * `_params` (`torch.nn.Module`): A nested `torch.nn.Module` that holds the - parameters of the Flax module. - * `_flax_module`: The original Flax module. - """ def __init__(self, env, flax_module, sample_args, sample_kwargs=None): - """Initializes the `FlaxNNModule`. - - **Args:** - - * `env`: The `torchax` environment. - * `flax_module`: The Flax module to wrap. - * `sample_args`: A tuple of sample arguments to initialize the Flax module. - * `sample_kwargs` (optional): A dictionary of sample keyword arguments to - initialize the Flax module. - """ super().__init__() prng = env.prng_key sample_kwargs = sample_kwargs or {} @@ -58,7 +34,6 @@ def _decode_nested_dict(self, child_module): return result def forward(self, *args, **kwargs): - """Performs the forward pass by calling the wrapped Flax module.""" nested_dict_params = self._decode_nested_dict(self._params) return tx.interop.call_jax(self._flax_module.apply, nested_dict_params, *args, **kwargs) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index d746e9d03ba2..a87efe9dfe74 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -56,29 +56,6 @@ def set_one(module, prefix): class JittableModule(torch.nn.Module): - """A wrapper class that makes a `torch.nn.Module` compatible with `jax.jit`. It separates the model's parameters and buffers, allowing them to be passed as arguments to a functional version of the model. - - **Arguments:** - - * `m` (`torch.nn.Module`): The PyTorch model to wrap. - * `extra_jit_args` (`dict`, optional): A dictionary of extra arguments to pass to `jax.jit`. - * `dedup_parameters` (`bool`, optional): If `True`, deduplicates parameters that are shared within the model. - - **Usage:** - - ```python - import torch - import torchax - from torchax.interop import JittableModule - - model = torch.nn.Linear(10, 20) - jittable_model = JittableModule(model) - - # The first call will compile the model - inputs = torch.randn(5, 10, device='jax') - outputs = jittable_model(inputs) - ``` - """ def __init__(self, m: torch.nn.Module, @@ -253,17 +230,12 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue, def j2t_autograd(fn, call_jax=call_jax): - """Given a JAX function, returns a PyTorch `autograd` function that is implemented with `jax.vjp`. This allows you to define custom gradients for your PyTorch operations using JAX. - - **Arguments:** - - * `fn`: The JAX function for which to create a PyTorch `autograd` function. - * `call_jax` (optional): The function to use for calling JAX functions from PyTorch. + """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. - **Returns:** - - A PyTorch function with custom gradients defined by the JAX function. - """ + It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate + activations). The wrapped function is then run via `call_jax` and integrated into + the PyTorch autograd framework by saving the residuals into the context object. + """ @wraps(fn) def inner(*args, **kwargs): @@ -361,50 +333,11 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None): def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False): - """A decorator that applies `jax.jit` to a PyTorch function. - - **Arguments:** - - * `torch_function`: The PyTorch function to be JIT-compiled. - * `kwargs_for_jax_jit` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.jit`. - * `fix_for_buffer_donation` (`bool`, optional): A flag to enable a workaround for buffer donation issues. - - **Returns:** - - A JIT-compiled version of the PyTorch function. - - **Usage:** - - ```python - import torch - import torchax - from torchax.interop import jax_jit - - @jax_jit - def my_function(x, y): - return torch.sin(x) + torch.cos(y) - - x = torch.randn(5, 10, device='jax') - y = torch.randn(5, 10, device='jax') - result = my_function(x, y) - ``` - """ return wrap_jax_jit( torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit) def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): - """Applies `jax.experimental.shard_map` to a PyTorch function, allowing for data parallelism across multiple devices. - - **Arguments:** - - * `torch_function`: The PyTorch function to be sharded. - * `kwargs_for_jax_shard_map` (`dict`, optional): A dictionary of keyword arguments to pass to `shard_map`. - - **Returns:** - - A sharded version of the PyTorch function. - """ return wrap_jax_jit( torch_function, jax_jit_func=shard_map, @@ -412,17 +345,6 @@ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): - """Applies `jax.value_and_grad` to a PyTorch function. - - **Arguments:** - - * `torch_function`: The PyTorch function. - * `kwargs_for_value_and_grad` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.value_and_grad`. - - **Returns:** - - A function that computes both the value and the gradient of the input `torch_function`. - """ return wrap_jax_jit( torch_function, jax_jit_func=jax.value_and_grad, @@ -430,16 +352,5 @@ def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): def gradient_checkpoint(torch_function, kwargs=None): - """Applies `jax.checkpoint` to a PyTorch function. This is useful for reducing memory usage during training by recomputing intermediate activations during the backward pass instead of storing them. - - **Arguments:** - - * `torch_function`: The PyTorch function to checkpoint. - * `kwargs` (`dict`, optional): A dictionary of keyword arguments to pass to `jax.checkpoint`. - - **Returns:** - - A checkpointed version of the PyTorch function. - """ return wrap_jax_jit( torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs) diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py index e147546dbbe9..208d86a1bac6 100644 --- a/torchax/torchax/mesh_util.py +++ b/torchax/torchax/mesh_util.py @@ -38,52 +38,46 @@ def _shard_first_multiple_of(axis_name, shape, multiple_of): class SingleAxisSharder: - """A callable object that generates `PartitionSpec`s for single-axis sharding. + """A callable object that generates PartitionSpecs for single-axis sharding. - This sharding strategy attempts to shard the *first* dimension of a tensor + This sharder strategy attempts to shard the *first* dimension of a tensor that is divisible by the specified `axis_size` along the given `axis_name`. - It's useful for simple 1D mesh sharding scenarios like FSDP, where parameters + It's useful for simple 1D mesh sharding scenarios like FSDP where parameters are typically sharded along one dimension. - **Attributes:** - - * `axis_name` (`str`): The name of the mesh axis to shard along. - * `axis_size` (`int`): The size of the mesh axis (number of devices along - that axis). - * `replicate_unshardable` (`bool`): If `True`, tensors that cannot be sharded - will be replicated. + Attributes: + axis_name: The name of the mesh axis to shard along. + axis_size: The size of the mesh axis (number of devices along that axis). """ def __init__(self, axis_name, axis_size, replicate_unshardable=False): - """Initializes the `SingleAxisSharder`. - - **Args:** + """Initializes the SingleAxisSharder. - * `axis_name` (`str`): The name of the mesh axis (e.g., "fsdp", "data"). - * `axis_size` (`int`): The number of devices along the specified mesh axis. - * `replicate_unshardable` (`bool`): If `True`, returns a replicated sharding - (`P()`) when no dimension is divisible by the axis size. + Args: + axis_name: The name of the mesh axis (e.g., "fsdp", "data"). + axis_size: The number of devices along the specified mesh axis. + replicate_unshardable: indicate whether it should return replicated sharding + (P()) when none of the axis is divisible by the axis size. """ self.axis_name = axis_name self.axis_size = axis_size self.replicate_unshardable = replicate_unshardable def __call__(self, name, shapedtype): - """Generates a `PartitionSpec` for a given tensor name and shaped type. - - **Args:** - - * `name` (`str`): The name of the tensor (e.g., parameter name). This - argument is provided for compatibility with more complex sharders but is - not used by this simple sharder. - * `shapedtype`: An object with a `.shape` attribute describing the - tensor's shape, and a `.dtype` attribute describing its dtype (e.g., - `jax.Array`, `jax.ShapeDtypeStruct`, or a `torch.Tensor`). - - **Returns:** - - A `jax.sharding.PartitionSpec` determined by finding the first dimension - in `shapedtype.shape` that is divisible by `self.axis_size`. + """Generates a PartitionSpec for a given tensor name and shaped type. + + Args: + name: The name of the tensor (e.g., parameter name). This argument is + provided for compatibility with more complex sharders but is not used + by this simple sharder. + shapedtype: An object with a `.shape` attribute describing the tensor's shape, + and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct + or a torch.Tensor) + + Returns: + A jax.sharding.PartitionSpec determined by finding the first dimension + in `shapedtype.shape` divisible by `self.axis_size` using the helper + `_shard_first_multiple_of`. """ del name sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape, @@ -97,38 +91,36 @@ def __call__(self, name, shapedtype): class Mesh: - """A helper class that wraps a `jax.sharding.Mesh` object. - - This class provides helper methods for sharding PyTorch tensors and models - across a JAX device mesh, simplifying the process of initializing models - directly into a sharded state. - - **Attributes:** - - * `jax_mesh` (`jax.sharding.Mesh`): The underlying `jax.sharding.Mesh` object - that defines the device grid and axis names. - * `_sharder`: The default sharding strategy callable (like - `SingleAxisSharder`) used to determine the `PartitionSpec` for each - parameter if not overridden. + """A helper class that wraps `jax.sharding.Mesh` object. + + The goal of this class is to provide helper methods that facilitate the + sharding of PyTorch tensors or models given a JAX device mesh configuration. + It simplifies initializing models directly into a sharded state. + + Attributes: + jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid + and axis names. + _sharder: The default sharding strategy callable (like SingleAxisSharder) + used to determine the PartitionSpec for each parameter if not overridden + during method calls. Can be None if no default is appropriate or set. """ @classmethod def fsdp_mesh(cls, axis_name="fsdp"): - """Creates a `Mesh` instance suitable for 1D FSDP-style sharding. - - This method creates a 1D mesh that encompasses all available XLA devices and - assigns the specified `axis_name` to this dimension. It then creates a - `Mesh` instance with a `SingleAxisSharder` configured for this 1D mesh. + """Creates a Mesh instance suitable for 1D FSDP-style sharding. - **Args:** + This named constructor creates a 1D mesh encompassing all available XLA + devices. It assigns the specified `axis_name` to this single dimension. + It then creates a `Mesh` instance using this JAX mesh and a + `SingleAxisSharder` configured appropriately for this 1D mesh. - * `axis_name` (`str`, optional): The name to assign to the single mesh - axis. Defaults to `"fsdp"`. + Args: + axis_name: The name to assign to the single mesh axis (default: "fsdp"). + This name will be used by the default `SingleAxisSharder`. - **Returns:** - - A `Mesh` instance configured with a 1D JAX mesh and a corresponding - `SingleAxisSharder`. + Returns: + A Mesh instance configured with a 1D JAX mesh across all devices and a + corresponding SingleAxisSharder. """ ndevice = jax.device_count() jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) @@ -136,16 +128,19 @@ def fsdp_mesh(cls, axis_name="fsdp"): return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) def __init__(self, jax_mesh, sharder=None): - """Initializes the `Mesh` helper. - - **Args:** - - * `jax_mesh` (`jax.sharding.Mesh`): A pre-configured `jax.sharding.Mesh` - object that defines the physical device grid and logical axis names. - * `sharder` (optional): A callable that takes `(name, shapedtype)` and - returns a `jax.sharding.PartitionSpec`. This serves as the default - sharding strategy. If `None`, and the provided `jax_mesh` has exactly - one axis, a `SingleAxisSharder` is created automatically. + """Initializes the Mesh helper. + + Args: + jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the + physical device grid and logical axis names. + sharder: An optional callable (e.g., an instance of SingleAxisSharder) + that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. + This serves as the default sharding strategy. + If None, and the provided `jax_mesh` has exactly one axis, a + `SingleAxisSharder` is created automatically for that single axis. + If None and the mesh has multiple axes, `_sharder` remains None, and + an `override_sharder` must be provided to methods like + `initialize_model_sharded`. """ self.jax_mesh = jax_mesh if sharder is None: @@ -161,24 +156,35 @@ def initialize_model_sharded(self, override_sharder=None): """Initializes a PyTorch model with its parameters sharded across the mesh. - This method initializes a `torch.nn.Module` such that its parameters are - created directly on the target devices according to the sharding - specifications derived from the mesh and the chosen sharder. - - **Args:** + This method orchestrates the initialization of a `torch.nn.Module` such + that its parameters are created directly on the target devices according + to the sharding specifications derived from the mesh and the chosen sharder. + It leverages `torchax.interop.jax_jit` to achieve this. - * `model_class`: The PyTorch model class (a subclass of `torch.nn.Module`). - * `init_args`: A tuple of positional arguments for the `model_class.__init__` - method. - * `init_kwargs` (optional): A dictionary of keyword arguments for the + Args: + model_class: The PyTorch model class (a subclass of `torch.nn.Module`). + init_args: A tuple containing the positional arguments required by the `model_class.__init__` method. - * `override_sharder` (optional): A callable sharding strategy to use for - this initialization, which takes precedence over the default sharder. - - **Returns:** - - An instance of `model_class` with its parameters initialized and sharded - across the devices in the `jax_mesh`. + init_kwargs: An optional dictionary containing the keyword arguments for + the `model_class.__init__` method. Defaults to None (treated as {}). + override_sharder: An optional callable sharding strategy to use + specifically for this initialization. If provided, it takes precedence + over the mesh's default `_sharder`. It must accept `(name, shapedtype)` + and return a `PartitionSpec`. If None, the mesh's default `_sharder` + is used. + + Returns: + An instance of `model_class` whose parameters have been initialized and + are represented by sharded tensors distributed across the devices in the + `jax_mesh`. + + Raises: + ValueError: If no sharder is available (i.e., `override_sharder` is None + and the mesh's default `_sharder` is also None). + AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) + if it fails to determine a valid sharding for any parameter. + TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. + Other errors from JAX JIT compilation or PyTorch model initialization. """ init_kwargs = init_kwargs or {} with torch.device("meta"), torchax.disable_temporarily(): @@ -205,7 +211,6 @@ def model_initializer(): return model def shard_model(self, model, override_sharder=None): - """Shards the parameters of an existing model across the mesh.""" sharder = override_sharder or self._sharder states = model.state_dict() output_shards = { diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py index fd9fbabede95..409a6d8350be 100644 --- a/torchax/torchax/ops/mappings.py +++ b/torchax/torchax/ops/mappings.py @@ -8,21 +8,6 @@ def t2j(t, use_dlpack=True): - """Converts a `torch.Tensor` to a `jax.Array`. - - This function handles the conversion of a PyTorch tensor to a JAX array, - with an option to use DLPack for zero-copy conversion where possible. - - **Arguments:** - - * `t` (`torch.Tensor`): The PyTorch tensor to convert. - * `use_dlpack` (`bool`, optional): If `True`, attempts to use DLPack for - zero-copy conversion. Defaults to `True`. - - **Returns:** - - A `jax.Array` that is equivalent to the input tensor. - """ is_bool = False if t.dtype == torch.bool: is_bool = True @@ -58,21 +43,6 @@ def t2j(t, use_dlpack=True): def j2t(x, use_dlpack=True): - """Converts a `jax.Array` to a `torch.Tensor`. - - This function handles the conversion of a JAX array to a PyTorch tensor, - with an option to use DLPack for zero-copy conversion where possible. - - **Arguments:** - - * `x` (`jax.Array`): The JAX array to convert. - * `use_dlpack` (`bool`, optional): If `True`, attempts to use DLPack for - zero-copy conversion. Defaults to `True`. - - **Returns:** - - A `torch.Tensor` that is equivalent to the input array. - """ with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): res = None if use_dlpack: @@ -156,7 +126,6 @@ def j2t(x, use_dlpack=True): def t2j_dtype(dtype): - """Converts a `torch.dtype` to a JAX dtype.""" if dtype not in TORCH_DTYPE_TO_JAX: raise RuntimeError( f'Attempting to convert unknown type: {dtype} to jax type,') @@ -164,7 +133,6 @@ def t2j_dtype(dtype): def j2t_dtype(dtype): - """Converts a JAX dtype to a `torch.dtype`.""" if dtype not in JAX_DTYPE_TO_TORCH: raise RuntimeError( f'Attempting to convert unknown type: {dtype} to torch type,') diff --git a/torchax/torchax/ops/op_base.py b/torchax/torchax/ops/op_base.py index 9d1f6b585483..d69e85ae50a6 100644 --- a/torchax/torchax/ops/op_base.py +++ b/torchax/torchax/ops/op_base.py @@ -12,21 +12,6 @@ class InplaceOp: - """A wrapper for creating in-place versions of functional operators. - - This class takes a functional operator and creates an in-place version of it. - It handles the mutation of the input tensor, including the case where the - input is a `View`. - - **Attributes:** - - * `functional`: The functional operator to wrap. - * `replace` (`bool`): If `True`, the underlying `jax.Array` of the input - tensor is replaced with the new value. Otherwise, the new value is - copied into the input tensor. - * `position_to_mutate` (`int`): The position of the argument to be mutated. - * `is_jax_func` (`bool`): `True` if the functional operator is a JAX function. - """ def __init__(self, functional_op, @@ -66,11 +51,6 @@ def __call__(self, *args, **kwargs): class OutVariant: - """A wrapper for creating out-of-place versions of functional operators. - - This class takes a functional operator and creates an out-of-place version - that writes the result to the `out` keyword argument. - """ def __call__(self, *args, **kwargs): to_mutate = kwargs['out'] @@ -83,16 +63,13 @@ def __call__(self, *args, **kwargs): def convert_dtype(use_default_dtype: bool = True): - """A decorator that converts the `dtype` kwarg of a function from `torch.dtype` to a JAX dtype. - - **Args:** + """Converts `dtype` kwarg of function from torch to JAX. - * `use_default_dtype` (`bool`): If `True`, uses the default PyTorch dtype if - no `dtype` is provided. + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. - **Returns:** - - A decorator that wraps a JAX implementation of a PyTorch function. + Returns: + A decorator that wraps a JAX implementation of a torch function. """ def decorator(func: types.TorchCallable): @@ -117,10 +94,9 @@ def wrapper(*args: P.args, def maybe_convert_constant_dtype(val: Optional[types.JaxValue], dtype: Optional[jnp.dtype]): - """Optionally converts the dtype of a scalar constant using NumPy. + """Optionally converts scalar constant's dtype using `numpy` - This function is useful in cases where you require a constant and cannot - handle a traced array. + Use in cases where you require a constant and can't handle a traced array. """ if val and dtype: if isinstance(val, jax.Array): @@ -132,7 +108,7 @@ def maybe_convert_constant_dtype(val: Optional[types.JaxValue], def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]): - """A decorator that promotes the first integer input of a function to `float32`.""" + """If the first argument is an int array, promote it to float32.""" @functools.wraps(f) def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): @@ -147,11 +123,9 @@ def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): def foreach_loop(seq: jax.Array, fn: Callable[[jax.Array, jax.Array], jax.Array], init_val=0.0): - """Applies a function to each element of a 1D array. + """Run `fn` for each element of 1D array `seq`. - This function is similar to `functools.reduce`, but is implemented with - `jax.lax.fori_loop` for efficient execution on accelerators. - """ + Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" assert len(seq.shape) == 1 return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val) diff --git a/torchax/torchax/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py index 0359879a747a..aa0d61cbb491 100644 --- a/torchax/torchax/ops/ops_registry.py +++ b/torchax/torchax/ops/ops_registry.py @@ -7,22 +7,6 @@ @dataclasses.dataclass class Operator: - """A dataclass that represents a `torchax` operator. - - This class holds the implementation of a PyTorch operator, along with - metadata that describes how it should be handled by the `torchax` dispatcher. - - **Attributes:** - - * `torch_op` (`TorchCallable`): The original PyTorch operator. - * `func` (`Union[TorchCallable, JaxCallable]`): The implementation of the - operator, which can be either a PyTorch callable or a JAX callable. - * `is_jax_function` (`bool`): `True` if the implementation is a JAX function. - * `is_user_defined` (`bool`): `True` if the operator is defined by the user. - * `needs_env` (`bool`): `True` if the operator needs access to the `torchax` - environment. - * `is_view_op` (`bool`): `True` if the operator is a view operation. - """ torch_op: TorchCallable func: Union[TorchCallable, JaxCallable] is_jax_function: bool @@ -41,28 +25,6 @@ def register_torch_dispatch_op(aten_op, is_user_defined=False, needs_env=False, is_view_op=False): - """Registers a `torch_dispatch` operator. - - This function is used to register an implementation for a PyTorch ATen - operator. - - **Arguments:** - - * `aten_op`: The ATen operator to register (e.g., `torch.ops.aten.add`). - * `impl_callable`: The implementation of the operator. - * `is_jax_function` (`bool`, optional): `True` if the implementation is a JAX - function. Defaults to `True`. - * `is_user_defined` (`bool`, optional): `True` if the operator is defined by - the user. Defaults to `False`. - * `needs_env` (`bool`, optional): `True` if the operator needs access to the - `torchax` environment. Defaults to `False`. - * `is_view_op` (`bool`, optional): `True` if the operator is a view - operation. Defaults to `False`. - - **Returns:** - - The implementation callable. - """ op = Operator( aten_op, impl_callable, @@ -82,28 +44,6 @@ def register_torch_function_op(torch_func, is_user_defined=False, needs_env=False, is_view_op=False): - """Registers a `torch_function` operator. - - This function is used to register an implementation for a `torch_function` - operator (e.g., `torch.add`). - - **Arguments:** - - * `torch_func`: The `torch_function` operator to register. - * `impl_callable`: The implementation of the operator. - * `is_jax_function` (`bool`, optional): `True` if the implementation is a JAX - function. Defaults to `True`. - * `is_user_defined` (`bool`, optional): `True` if the operator is defined by - the user. Defaults to `False`. - * `needs_env` (`bool`, optional): `True` if the operator needs access to the - `torchax` environment. Defaults to `False`. - * `is_view_op` (`bool`, optional): `True` if the operator is a view - operation. Defaults to `False`. - - **Returns:** - - The implementation callable. - """ op = Operator( torch_func, impl_callable, diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 867f626e9b65..3916fe6501b8 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -38,18 +38,6 @@ def log_nested(env, message): class Tensor(torch.Tensor): - """A `torch.Tensor` subclass that wraps a `jax.Array`. - - This class is the core of `torchax`, allowing PyTorch operations to be - dispatched to JAX. It holds a `jax.Array` internally and overrides - the necessary methods to ensure that operations are correctly routed - through the `torchax` dispatch mechanism. - - **Attributes:** - - * `_elem` (`jax.Array`): The underlying JAX array. - * `_env` (`Environment`): The `torchax` environment this tensor belongs to. - """ @staticmethod def __new__(cls, elem, env, requires_grad=False): @@ -125,21 +113,17 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 'call torchax.enable_globally() before.') def detach(self): - """Detaches the tensor from the computation graph.""" return Tensor(jax.lax.stop_gradient(self.jax()), self._env) def numpy(self) -> numpy.ndarray: - """Converts the tensor to a NumPy array.""" import numpy as np return np.array(self._elem) def jax(self) -> jax.Array: - """Returns the underlying `jax.Array`.""" return self._elem def torch(self) -> torch.Tensor: - """Converts the tensor to a standard `torch.Tensor`.""" return self._env.j2t_copy(self.jax()) @property @@ -169,22 +153,18 @@ def data(self, other): self._elem = other._elem def apply_jax(self, jax_function, *args, **kwargs): - """Applies a JAX function to the underlying `jax.Array`.""" # Call a jax function on _elem res = jax_function(self._elem, *args, **kwargs) return self._env.j2t_iso(res) def apply_jax_(self, jax_function, *args, **kwargs): - """Applies a JAX function in-place to the underlying `jax.Array`.""" self._elem = jax_function(self._elem, *args, **kwargs) return self def tolist(self): - """Converts the tensor to a list.""" return self._elem.tolist() def shard_(self, sharding): - """Applies a sharding constraint to the tensor in-place.""" self.apply_jax_(jax.lax.with_sharding_constraint, sharding) @@ -267,8 +247,6 @@ def __torch_function__(self, class XLADispatchMode(torch_dispatch.TorchDispatchMode): - """A `TorchDispatchMode` that intercepts PyTorch operations and dispatches them to the JAX backend through the `Environment`. - """ def __init__(self, env): self.env = env @@ -351,12 +329,16 @@ def __getattr__(self, name): class Environment(contextlib.ContextDecorator): - """Manages the execution environment for `torchax`. + """This class holds a set of configurations and "globals" needed + + for executing torch program using jax. + Things included so far: - This class holds the configuration, operator registry, PRNG key, and other - "global" state needed to execute PyTorch programs using the JAX backend. - It also provides helper functions for dispatching operations and converting - tensors between PyTorch and JAX representations. + op registry + PRNGKey + Configs + + Also helper functions to manipulate those. """ def __init__(self, configuration=None): @@ -388,7 +370,6 @@ def param(self): return self._property.content[-1] def manual_seed(self, key): - """Sets the seed for the JAX random number generator.""" jax_key = jax.random.PRNGKey(key) new_prop = self.param.override(prng=jax_key) self._property.content.append(new_prop) @@ -541,7 +522,6 @@ def _torch_Tensor_to(self, args, kwargs): return self._to_copy(the_tensor, dtype, device) def dispatch(self, func, types, args, kwargs): - """Dispatches a PyTorch operation to the appropriate JAX implementation.""" kwargs = kwargs or {} if func in TENSOR_CONSTRUCTORS: return self._handle_tensor_constructor(func, args, kwargs) @@ -620,13 +600,11 @@ def is_not_torchax_tensor(x): return res def enable_torch_modes(self): - """Enables the `torchax` dispatch modes.""" self._dispatch_mode.__enter__() self._function_mode.__enter__() self.enabled = True def disable_torch_modes(self, *exc): - """Disables the `torchax` dispatch modes.""" if not exc: exc = (None, None, None) self._function_mode.__exit__(*exc) @@ -656,12 +634,10 @@ def to_xla(self, torchvalues): return res def t2j_iso(self, torchtensors): - """Converts `torchax.Tensor`s to `jax.Array`s without copying. - - This function unwraps the underlying `jax.Array` from each `torchax.Tensor` - in the input pytree. - - Note: "iso" is short for "isomorphic". + """Convert torchax Tensor to jax array. + + This function will not copy, will just unwrap the inner jax array out. + Note: iso is short for "isomorphic" """ def to_jax(x): @@ -681,7 +657,6 @@ def to_jax(x): return res def v2t_iso(self, views): - """Converts `torchax.View`s to `torchax.Tensor`s without copying.""" def to_tensor(x): if isinstance(x, View): @@ -692,18 +667,18 @@ def to_tensor(x): return res def j2t_iso(self, jaxarray): - """Converts `jax.Array`s to `torchax.Tensor`s without copying. - - This function wraps each `jax.Array` in the input pytree with a - `torchax.Tensor`. - - Note: "iso" is short for "isomorphic". + """Convert jax array to torchax Tensor. + + This function will not copy, will just wrap the jax array with a torchax Tensor + Note: iso is short for "isomorphic" """ return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), jaxarray) def j2t_copy(self, args): - """Converts `jax.Array`s to `torch.Tensor`s on the CPU, potentially copying the data. + """Convert torch.Tensor in cpu to a jax array + + This might involves copying the data (depending if dlpack is enabled) """ return torch_pytree.tree_map_only( jax.Array, @@ -711,7 +686,9 @@ def j2t_copy(self, args): args) def t2j_copy(self, args): - """Converts `torch.Tensor`s to `jax.Array`s, potentially copying the data. + """Convert jax array to torch.Tensor in cpu. + + This might involves copying the data (depending if dlpack is enabled) """ return torch_pytree.tree_map_only( torch.Tensor, @@ -719,7 +696,6 @@ def t2j_copy(self, args): args) def override_op_definition(self, op_to_override, op_impl): - """Overrides the implementation of a PyTorch operator.""" self._ops[op_to_override] = ops_registry.Operator( op_to_override, op_impl, @@ -730,7 +706,6 @@ def override_op_definition(self, op_to_override, op_impl): @contextlib.contextmanager def override_property(self, **kwargs): - """A context manager to temporarily override properties of the environment.""" new_prop = self.param.override(**kwargs) self._property.content.append(new_prop) yield diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py index 0575e6f0b812..c9842089bfcf 100644 --- a/torchax/torchax/tf_integration.py +++ b/torchax/torchax/tf_integration.py @@ -9,21 +9,6 @@ def exported_program_to_tf_function(ep, enable_xla=True): - """Converts a `torch.export.ExportedProgram` to a TensorFlow function. - - This function takes a PyTorch `ExportedProgram`, converts it to a JAX program, - and then wraps it as a TensorFlow function using `jax2tf`. - - **Arguments:** - - * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to convert. - * `enable_xla` (`bool`, optional): Whether to enable XLA compilation for the - converted TensorFlow function. Defaults to `True`. - - **Returns:** - - A TensorFlow function that is equivalent to the input `ExportedProgram`. - """ weights, jax_program = export.exported_program_to_jax(ep) wrapped = lambda *args: jax_program(weights, (args,)) avals = export.extract_avals(ep) @@ -45,21 +30,6 @@ def exported_program_to_tf_function(ep, enable_xla=True): def exported_program_to_tf_module(ep: torch.export.ExportedProgram, enable_xla=True) -> tf.Module: - """Converts a `torch.export.ExportedProgram` to a `tf.Module`. - - This function wraps the TensorFlow function created by - `exported_program_to_tf_function` in a `tf.Module` for easier use and saving. - - **Arguments:** - - * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to convert. - * `enable_xla` (`bool`, optional): Whether to enable XLA compilation. Defaults - to `True`. - - **Returns:** - - A `tf.Module` containing the converted TensorFlow function. - """ tfm = tf.Module() tfm.f = exported_program_to_tf_function(ep, enable_xla) return tfm @@ -72,23 +42,22 @@ def save_exported_program_as_tf_saved_model( function_alias: str = "", enable_xla=True, ): - """Exports and saves a PyTorch `ExportedProgram` to the TensorFlow SavedModel format. - - The resulting SavedModel can be used for inference with TensorFlow Serving or - further converted to TFLite for on-device deployment. - - **Arguments:** - - * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to save. - * `saved_model_dir` (`os.PathLike`): The path to an empty directory where the - SavedModel will be stored. - * `serving_key` (`str`, optional): The serving key to use for the signature - definition. This is used by TensorFlow Serving to identify the function - to run. Defaults to `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. - * `function_alias` (`str`, optional): An alias for the function, which can be - used by other tools. - * `enable_xla` (`bool`, optional): Whether to enable XLA compilation. Defaults - to `True`. + """This function will export and save a pytorch ExportedProgram to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. """ tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) signatures = { @@ -113,22 +82,22 @@ def save_torch_module_as_tf_saved_model( function_alias: str = "", enable_xla=True, ): - """Exports and saves a `torch.nn.Module` to the TensorFlow SavedModel format. - - This function first exports the `torch.nn.Module` to an `ExportedProgram` - and then saves it as a SavedModel. - - **Arguments:** - - * `torch_model` (`torch.nn.Module`): The PyTorch model to export and save. - * `args` (`Tuple[Any]`): A tuple of arguments to trace the model with (i.e., - `torch_model(*args)` must be a valid call). - * `saved_model_dir` (`os.PathLike`): The path to an empty directory where the - SavedModel will be stored. - * `serving_key` (`str`, optional): The serving key for the signature - definition. - * `function_alias` (`str`, optional): An alias for the function. - * `enable_xla` (`bool`, optional): Whether to enable XLA compilation. + """This function will export and save a pytorch nn.Module to tf.saved_model format. + + The resulting tf.saved_model can be used inference using tf.serving model + server + or further convert to tflite flatbuffer for on-device serving. + + Args: + torch_model: torch.nn.Module - model to export and save + args: Tuple[Any] - a set of args to trace the model with, i.e. + torch_model(*args) must run + saved_model_dir: os.PathLike - location to an empty directory to store the + saved_model + serving_key: str - serving key tag, this is used by tf.serving to know + which function to run. + function_alias: str - passed through saved_model.save, used to tag a + function for inference converter or other tools. """ ep = torch.export.export(torch_model, args) save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key, @@ -136,16 +105,6 @@ def save_torch_module_as_tf_saved_model( def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): - """Converts a `torch.export.ExportedProgram` to a TFLite flatbuffer. - - **Arguments:** - - * `ep` (`torch.export.ExportedProgram`): The PyTorch `ExportedProgram` to convert. - - **Returns:** - - A TFLite flatbuffer model. - """ tfm = exported_program_to_tf_module(ep) tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) converter = tf.lite.TFLiteConverter.from_concrete_functions( @@ -156,16 +115,5 @@ def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module, args: Tuple[Any]): - """Converts a `torch.nn.Module` to a TFLite flatbuffer. - - **Arguments:** - - * `torch_model` (`torch.nn.Module`): The PyTorch model to convert. - * `args` (`Tuple[Any]`): A tuple of arguments to trace the model with. - - **Returns:** - - A TFLite flatbuffer model. - """ ep = torch.export.export(torch_model, args) return exported_program_to_tflite_flatbuffer(ep) diff --git a/torchax/torchax/train.py b/torchax/torchax/train.py index 6f7ea24576dc..fb4e16fc48ee 100644 --- a/torchax/torchax/train.py +++ b/torchax/torchax/train.py @@ -12,34 +12,21 @@ def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None): - """Creates a function that performs one training step. - - This function is designed to be used with JAX's `jit` for efficient training. - It takes a model function, a loss function, and an Optax optimizer, and - returns a function that computes the loss, calculates gradients, and updates - the model's weights. - - **Arguments:** - - * `model_fn`: A function representing the model's forward pass. It should - have the signature `Callable[weights, buffers, args] -> result`, where: - * `weights` is a pytree of trainable parameters. - * `buffers` is a pytree of non-trainable parameters and constants. - * `args` is the input data from the dataset. - * `result` is the model's output. - * `loss_fn`: A function to compute the loss. It should have the signature - `Callable[result, label] -> loss`, where: - * `result` is the output of `model_fn`. - * `label` is the ground truth from the dataloader. - * `optax_optimizer`: An optimizer from the Optax library (e.g., `optax.adam`). - * `remat_policy` (optional): A policy from `jax.ad_checkpoint.checkpoint_policies` - that specifies how to perform gradient checkpointing. If `None`, all - intermediate activations will be checkpointed. - - **Returns:** - - A function that performs one training step. It has the signature - `Callable[weights, buffers, opt_state, args, label] -> (loss, new_weights, new_opt_state)`. + """Make a function that do one train step given model and loss. + + model_fn: a function representing the model's forward: + i.e. has signature Callable[weights, buffers, args] -> result. Where, + weights is a pytree of trainable parameters + buffers is a pytree of non-trainable parameters / constants + args is the input data loaded from the data set + result is the return value of the model + loss_fn: a function to compute loss. + i.e. it has signature of Callable[result, label] -> loss + where, result is what model_fn returned + loss is loaded from the dataloader. + optax_optimizer: the optimizer from optax library. for example, optax.adam + remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how + to do gradient checkpointing. If None, then it means checkpoint everything. """ env = torchax.default_env() @@ -71,18 +58,6 @@ class Container: class ScannedModule(torch.nn.Module): - """A `torch.nn.Module` that applies a list of identical modules sequentially. - - This module is designed to be used with `jax.lax.scan` for efficient - execution of repeated layers. It takes a list of modules, stacks their - weights, and applies the same module function to the input in a loop. - - **Attributes:** - - * `checkpoint_policy`: The gradient checkpointing policy to use. - * `params`: A `torch.nn.ParameterDict` containing the stacked weights of the - input modules. - """ def __init__(self, module_list, checkpoint_policy=None): super().__init__() diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index a2fcc0d9831b..040fa24ef9e8 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -22,77 +22,77 @@ class ViewInfoType(Enum): class ViewInfo(ABC): - """Abstract base class for all view operations. - - This class defines the interface for applying and updating view transformations - on JAX arrays. Each subclass represents a specific type of view, such as - a slice, reshape, or permutation. """ + Abstract base class for all view operations. + Defines the interface for applying and updating view transformations. + """ def __init__( self, view_info_type: ViewInfoType = ViewInfoType.INVALID, ): - """Initializes a ViewInfo object. - - Args: - view_info_type: The type of view operation. """ + Initialize a ViewInfo object. + + Args: + view_info_type: The type of view operation + """ self.view_info_type = view_info_type @abstractmethod def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array: - """Applies this view transformation to a JAX array and updates its value. + """ + Apply this view transformation to a JAX array and update its value. - Args: - new_value: The new values to set in the view. - jax_array: The parent array to update. + Args: + new_value: The new values to set in the view + jax_array: The parent array to update - Returns: - The updated array. - """ + Returns: + Updated array + """ pass @abstractmethod def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - """Applies this view transformation to a JAX array. + """ + Apply this view transformation to a JAX array. - Args: - jax_array: The array to transform. + Args: + jax_array: The array to transform - Returns: - The transformed array. - """ + Returns: + Transformed array + """ pass @abstractmethod def calculate_output_shape(self, source: jax.Array) -> List[int]: - """Calculates the resulting shape after applying this view. + """ + Calculate the resulting shape after applying this view. - Args: - source: The original JAX array before transformation. + Args: + source: Original jax array before transformation - Returns: - The resulting shape after transformation. - """ + Returns: + Resulting shape after transformation + """ pass class NarrowInfo(ViewInfo): - """Represents a slicing operation on a tensor. - - This class handles operations like `tensor[1:3, :, 2:5:2]`. """ + Represents a slicing operation on a tensor. + Handles operations like tensor[1:3, :, 2:5:2]. + """ def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: - """Initializes a NarrowInfo object. - - Args: - slices: The slice(s) to apply to the tensor. - For example, `jax_array.at[slices]` will return the - transformed tensor. """ + Args: + slices: The slice(s) to apply to the tensor. + E.g. jax_array.at[slices] will return the transformed tensor. + """ super().__init__(ViewInfoType.NARROW) self.slices = slices @@ -116,10 +116,10 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class SelectInfo(ViewInfo): - """Represents a selection operation on a tensor. - - This is typically used for indexing operations that select specific elements. """ + Represents a selection operation on a tensor. + Typically used for indexing operations that select specific elements. + """ def __init__(self, dim: int = 0, @@ -151,7 +151,9 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class AsStridedInfo(ViewInfo): - """Represents an `as_strided` operation on a tensor.""" + """ + Information for as_strided operations. + """ def __init__(self, stride: List[int], offset: int = 0) -> None: super().__init__(ViewInfoType.AS_STRIDED) @@ -176,19 +178,18 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class DiagonalInfo(ViewInfo): - """Represents a diagonal operation on a tensor. - - This class is used to extract diagonal elements from a tensor. """ + Information for diagonal operations. + Extracts diagonal elements from a tensor. + """ def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: - """Initializes a DiagonalInfo object. - - Args: - offset: The offset from the main diagonal. - dim1: The first dimension for diagonal extraction. - dim2: The second dimension for diagonal extraction. """ + Args: + offset: Offset from the main diagonal + dim1: First dimension for diagonal extraction + dim2: Second dimension for diagonal extraction + """ super().__init__(ViewInfoType.DIAGONAL) self.offset: int = offset self.dim1: int = dim1 @@ -213,24 +214,20 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class View(torch.Tensor): - """A `torch.Tensor` subclass that represents a view of another tensor. - - A `View` holds a reference to a parent `torchax.Tensor` or another `View`, - along with a `ViewInfo` object that describes the transformation to be - applied. This allows for lazy evaluation of view operations and efficient - in-place updates. """ + A View is a reference to another Tensor or another View, + with a transformation applied to it. + """ @staticmethod def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo, env: Any) -> "View": - """Creates a new `View` object. - - Args: - parent: The parent tensor or view. - view_info: Information about the view transformation. - env: The `torchax` environment. """ + Args: + parent: Parent tensor or view + view_info: Information about the view transformation + env: Environment for tensor operations + """ shape = view_info.calculate_output_shape(parent.jax()) return torch.Tensor._make_wrapper_subclass( cls, @@ -248,7 +245,9 @@ def __init__(self, parent: Union["torchax.Tensor", "View"], self._env = env def get_transformation_chain(self) -> List[ViewInfo]: - """Returns the chain of view transformations from the source tensor to this view.""" + """ + Get all view transformations from the source tensor to this view. + """ if isinstance(self.parent, View): transformations = self.parent.get_transformation_chain() transformations.append(self.view_info) @@ -259,14 +258,18 @@ def get_transformation_chain(self) -> List[ViewInfo]: __torch_function__ = torch._C._disabled_torch_function_impl def source_jax(self) -> jax.Array: - """Returns the underlying `jax.Array` of the source tensor.""" + """ + Returns the source tensor. + """ if isinstance(self.parent, View): return self.parent.source_jax() else: return self.parent.jax() def replace_source_jax(self, new_value: jax.Array) -> None: - """Updates the source tensor with a new `jax.Array`.""" + """ + Update the source tensor with new values. + """ if isinstance(self.parent, View): self.parent.replace_source_jax(new_value) else: @@ -274,7 +277,9 @@ def replace_source_jax(self, new_value: jax.Array) -> None: self.parent._elem = new_value def torch(self) -> "torchax.Tensor": - """Returns a `torchax.Tensor` representing this view after all transformations.""" + """ + Returns a Torchax tensor representing this view after all transformations + """ from torchax.tensor import Tensor return Tensor(self.jax(), self._env) @@ -284,11 +289,11 @@ def update( new_values: Union[jax.Array, "View", "torchax.Tensor"], view_infos: Optional[List[ViewInfo]] = None, ) -> None: - """Updates this view with new values, propagating changes back to the source. - - If `view_infos` is not provided, it will use the transformation chain - from the source tensor. """ + Update this view with new values, propagating changes back to source. + If view_infos is None, it will use the transformation chain + from the source tensor. + """ if view_infos is None: view_infos = self.get_transformation_chain() @@ -333,14 +338,18 @@ def __torch_dispatch__( 'call torchax.enable_globally() before.') def create_sub_view(self, view_info: ViewInfo) -> "View": - """Creates a new view that is a child of this view.""" + """ + Create a new view that is a child of this view. + """ return View(self, view_info, self._env) def __str__(self) -> str: return f"View({self.torch()})" def jax(self) -> jax.Array: - """Returns a copy of the source tensor after all transformations have been applied.""" + """ + Returns a copy of the source tensor after transformations. + """ result = self.source_jax() for view_info in self.get_transformation_chain(): result = view_info.transform_tensor(result) From 43589c03b72d5ee07046886f6eb5212c67ff1ed5 Mon Sep 17 00:00:00 2001 From: aws-cph Date: Mon, 4 Aug 2025 15:08:35 -0700 Subject: [PATCH 036/133] Implement XLAShardedTensor.redistribute and test (#9529) --- test/neuron/run_tests.sh | 1 + test/run_tests.sh | 1 + test/spmd/test_dtensor_redistribute.py | 269 ++++++++++++++++++ test/tpu/run_tests.sh | 1 + .../distributed/spmd/xla_sharded_tensor.py | 36 ++- 5 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 test/spmd/test_dtensor_redistribute.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index ecc302aa30fc..d8ee9a39b03e 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -256,6 +256,7 @@ function run_xla_op_tests3 { #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" + run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index 66c8bbff0406..54c893c7b405 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -256,6 +256,7 @@ function run_xla_op_tests3 { run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py" diff --git a/test/spmd/test_dtensor_redistribute.py b/test/spmd/test_dtensor_redistribute.py new file mode 100644 index 000000000000..dc7febf59305 --- /dev/null +++ b/test/spmd/test_dtensor_redistribute.py @@ -0,0 +1,269 @@ +import sys +import unittest +import torch +from torch.distributed.tensor.placement_types import Shard, Replicate, Partial +import torch_xla.runtime as xr +import torch_xla.distributed.spmd as xs +import torch_xla +import numpy as np +import test_xla_sharding_base +from absl.testing import parameterized + + +class DTensorRedistributeTest(test_xla_sharding_base.XlaShardingTest, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + xr.use_spmd() + + def _verify_sharding_spec(self, tensor, expected_devices=None): + """Verify tensor sharding spec after mark_step""" + torch_xla.sync() + sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(tensor) + if expected_devices: + self.assertIn(expected_devices, sharding_spec) + return sharding_spec + + # Test tensor shapes: 0D, 1D, 2D, 3D + @parameterized.parameters( + ((), ()), # 0D scalar + ((8,), (0,)), # 1D + ((8, 16), (0, None)), # 2D + ((4, 8, 16), (0, None, None)) # 3D + ) + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") + def test_tensor_shapes(self, shape, partition_spec): + device_count = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(device_count), (device_count,)) + + if len(shape) == 0: + tensor = torch.tensor(1.0).to('xla') + placements = [Replicate()] + expected_spec = () + else: + tensor = torch.randn(shape).to('xla') + sharded_tensor = xs.mark_sharding(tensor, mesh, partition_spec) + placements = [Shard(0)] + expected_spec = partition_spec + + redistributed = sharded_tensor.redistribute(mesh, placements) + self.assertEqual(redistributed.partition_spec, expected_spec) + + # Convert partition spec to expected devices pattern + devices_pattern = [ + str(device_count) if spec == 0 else '1' for spec in expected_spec + ] + expected_devices = f"devices=[{','.join(devices_pattern)}]" + + # Skip HLO verification for 4D tensors due to XLA optimization issues + if len(shape) < 4: + self._verify_sharding_spec(redistributed.global_tensor, + expected_devices) + + # Test tensor dtypes: bf16, f32, int32 + @parameterized.parameters(torch.bfloat16, torch.float32, torch.int32) + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") + def test_tensor_dtypes(self, dtype): + device_count = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(device_count), (device_count,)) + + if dtype == torch.int32: + tensor = torch.randint(0, 100, (8, 16), dtype=dtype).to('xla') + else: + tensor = torch.randn(8, 16, dtype=dtype).to('xla') + + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) + placements = [Shard(0)] + + redistributed = sharded_tensor.redistribute(mesh, placements) + self.assertEqual(redistributed.partition_spec, (0, None)) + self.assertEqual(redistributed.global_tensor.dtype, dtype) + + # Verify HLO sharding + expected_devices = f"devices=[{device_count},1]" + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) + + # Test device mesh dimensions: 1D, 2D + @unittest.skipIf(xr.global_runtime_device_count() < 4, "Need ≥4 devices") + def test_device_mesh_dimensions(self): + device_count = xr.global_runtime_device_count() + + # 1D mesh + mesh_1d = xs.Mesh(np.arange(device_count), (device_count,)) + tensor = torch.randn(8, 16).to('xla') + sharded_tensor = xs.mark_sharding(tensor, mesh_1d, (0, None)) + + redistributed = sharded_tensor.redistribute(mesh_1d, [Shard(1)]) + self.assertEqual(redistributed.partition_spec, (None, 0)) + + # Verify HLO sharding for 1D mesh + expected_devices = f"devices=[1,{device_count}]" + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) + + # 2D mesh + if device_count >= 4 and device_count % 2 == 0: + mesh_2d = xs.Mesh(np.arange(device_count), (2, device_count // 2)) + tensor_2d = torch.randn(8, 16).to('xla') + sharded_tensor = xs.mark_sharding(tensor_2d, mesh_2d, (0, None)) + + redistributed = sharded_tensor.redistribute( + mesh_2d, [Replicate(), Shard(1)]) + self.assertEqual(redistributed.partition_spec, (None, 1)) + + # Verify HLO sharding for 2D mesh + expected_devices = f"devices=[1,{device_count // 2},{device_count // 2}]" + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) + + # Test placement types: Replicate, Shard + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") + def test_placement_types(self): + device_count = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(device_count), (device_count,)) + tensor = torch.randn(8, 16).to('xla') + + # Test Replicate + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) + redistributed = sharded_tensor.redistribute(mesh, [Replicate()]) + self.assertEqual(redistributed.partition_spec, (None, None)) + + # Verify HLO sharding for replicated + self._verify_sharding_spec(redistributed.global_tensor, "replicated") + + # Test Shard on different dimensions + for dim in [0, 1]: + with self.subTest(shard_dim=dim): + redistributed = sharded_tensor.redistribute(mesh, [Shard(dim)]) + expected_spec = [None, None] + expected_spec[dim] = 0 + self.assertEqual(redistributed.partition_spec, tuple(expected_spec)) + + # Verify HLO sharding + devices_pattern = [ + str(device_count) if i == dim else '1' for i in range(2) + ] + expected_devices = f"devices=[{','.join(devices_pattern)}]" + self._verify_sharding_spec(redistributed.global_tensor, + expected_devices) + + # Test error cases with invalid inputs + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") + def test_invalid_inputs(self): + device_count = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(device_count), (device_count,)) + tensor = torch.randn(8, 16).to('xla') + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) + + # Test invalid shard dimension (tensor only has dims 0,1 but asking for dim 2) + with self.assertRaises(IndexError): + sharded_tensor.redistribute(mesh, [Shard(2)]) + + # Test mismatched placements length (1D mesh expects 1 placement, not 2) + with self.assertRaises(ValueError): + sharded_tensor.redistribute(mesh, [Shard(0), Shard(1)]) + + # Test Partial placement - should raise error about not being implemented + with self.assertRaises(NotImplementedError): + sharded_tensor.redistribute(mesh, [Partial()]) + + # Test sharding propagation through operations + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") + def test_sharding_propagation(self): + device_count = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(device_count), (device_count,)) + + # Unary ops + tensor = torch.randn(8, 16).to('xla') + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) + redistributed = sharded_tensor.redistribute(mesh, [Shard(0)]) + + relu_result = torch.relu(redistributed.global_tensor) + self.assertEqual(relu_result.shape, (8, 16)) + self.assertTrue(torch.all(relu_result >= 0)) + + # Binary ops + tensor2 = torch.randn(8, 16).to('xla') + sharded_tensor2 = xs.mark_sharding(tensor2, mesh, (0, None)) + redistributed2 = sharded_tensor2.redistribute(mesh, [Shard(0)]) + + add_result = redistributed.global_tensor + redistributed2.global_tensor + mul_result = redistributed.global_tensor * redistributed2.global_tensor + + # Verify operation results + self.assertEqual(add_result.shape, (8, 16)) + self.assertEqual(mul_result.shape, (8, 16)) + + # Verify operations work correctly + self.assertTrue( + torch.allclose( + add_result, + redistributed.global_tensor + redistributed2.global_tensor)) + self.assertTrue( + torch.allclose( + mul_result, + redistributed.global_tensor * redistributed2.global_tensor)) + + # Test comprehensive redistribute scenarios + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") + def test_comprehensive_redistribute(self): + device_count = xr.global_runtime_device_count() + mesh = xs.Mesh(np.arange(device_count), (device_count,)) + + tensor = torch.randn(8, 16).to('xla') + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) + + # Test all placement combinations for 1D mesh + placement_types = [Replicate(), Shard(0), Shard(1)] + + for placement in placement_types: + with self.subTest(placement=placement): + placements = [placement] + + if isinstance(placement, Shard): + expected_spec = [None] * 2 + expected_spec[placement.dim] = 0 + expected_spec = tuple(expected_spec) + else: + expected_spec = (None, None) + + redistributed = sharded_tensor.redistribute(mesh, placements) + self.assertEqual(redistributed.partition_spec, expected_spec) + + # Verify HLO sharding + if isinstance(placement, Shard): + devices_pattern = [ + str(device_count) if i == placement.dim else '1' for i in range(2) + ] + expected_devices = f"devices=[{','.join(devices_pattern)}]" + else: + expected_devices = "replicated" + self._verify_sharding_spec(redistributed.global_tensor, + expected_devices) + + # Test async redistribute + @unittest.skipIf(xr.global_runtime_device_count() < 4, "Need ≥4 devices") + def test_async_redistribute(self): + device_count = xr.global_runtime_device_count() + mesh_shape = (2, device_count // 2) + mesh = xs.Mesh(np.arange(device_count), mesh_shape) + + tensor = torch.randn(8, 16).to('xla') + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) + + # Test async redistribute + placements = [Replicate(), Shard(0)] + redistributed = sharded_tensor.redistribute(mesh, placements, async_op=True) + self.assertEqual(redistributed.partition_spec, (1, None)) + + # Verify async operation creates different tensor object + self.assertIsNot(redistributed.global_tensor, sharded_tensor.global_tensor) + + # Verify HLO sharding for async redistribute (XLA generates more complex pattern) + expected_devices = f"devices=[2,1,{device_count // 2}]" + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 017fed5294fa..440db8bd28ad 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -62,6 +62,7 @@ run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" +run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v run_test "$_TEST_DIR/test_autocast.py" diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index c6a9a5d4f58c..5a049b5864e3 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -9,7 +9,7 @@ import torch_xla.runtime as xr from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor.placement_types import Placement, Shard, Replicate +from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial from torch.utils._pytree import tree_map_only @@ -264,6 +264,40 @@ def invalidate_spec_cache(self): """Invalidate the cached DTensorSpec.""" self._cached_spec = None + def redistribute(self, device_mesh, placements, *, async_op: bool = False): + # Validate inputs + if len(placements) != len(device_mesh.mesh_shape): + raise ValueError( + f"Number of placements ({len(placements)}) must match mesh dimensions ({len(device_mesh.mesh_shape)})" + ) + + # Check for unsupported placement types + for placement in placements: + if isinstance(placement, Partial): + raise NotImplementedError( + "Partial placement is not yet implemented and may have unexpected behavior. " + "Use Shard or Replicate placements instead.") + + # Convert placements to partition spec + partition_spec = [None] * len(self.global_tensor.shape) + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + if placement.dim >= len(self.global_tensor.shape): + raise IndexError( + f"Shard dimension {placement.dim} is out of bounds for tensor with {len(self.global_tensor.shape)} dimensions" + ) + partition_spec[placement.dim] = mesh_dim + + result_tensor = self.global_tensor.clone( + ) if async_op else self.global_tensor + op_sharding = device_mesh.get_op_sharding(tuple(partition_spec)) + torch_xla._XLAC._xla_annotate_custom_sharding(result_tensor, op_sharding) + + return XLAShardedTensor( + result_tensor, + mesh_shape=device_mesh.mesh_shape, + partition_spec=tuple(partition_spec)) + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) From 15496cd2d89a4eeefb0e0fb971a68a47b1ab4c2a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 5 Aug 2025 16:20:20 -0300 Subject: [PATCH 037/133] Do not set `PJRT_DEVICE=CUDA` automatically on import. (#9540) --- torch_xla/runtime.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 2e274190db75..d83bcba8a1dd 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -63,16 +63,6 @@ def _maybe_select_default_device(): if torch_xla._found_libtpu and tpu.num_available_chips() > 0: logging.warning('libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.') os.environ[xenv.PJRT_DEVICE] = 'TPU' - elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0: - logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=CUDA') - os.environ[xenv.PJRT_DEVICE] = 'CUDA' - elif torch.cuda.is_available() and torch.cuda.device_count() > 0: - num_devices_str = str(torch.cuda.device_count()) - logging.warning( - 'Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES=' - + num_devices_str) - os.environ[xenv.PJRT_DEVICE] = 'CUDA' - os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str elif torch_xla._found_libneuronxla: logging.warning('Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.') os.environ[xenv.PJRT_DEVICE] = 'NEURON' From e5e75a840db05321d83731cd753cf924796c281f Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Wed, 6 Aug 2025 15:26:50 -0700 Subject: [PATCH 038/133] Add triggers for release 2.8.0 (#9545) --- .../artifacts.auto.tfvars | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index d43a88b76675..fd47e4b63f79 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -3,6 +3,51 @@ manual_nightly_builds = [ ] manual_versioned_builds = [ + { + git_tag = "v2.8.0" + package_version = "2.8.0" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.10" + bundle_libtpu = "0" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.0" + package_version = "2.8.0" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.10" + bundle_libtpu = "1" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.0" + package_version = "2.8.0" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.11" + bundle_libtpu = "0" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.0" + package_version = "2.8.0" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.12" + bundle_libtpu = "0" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.0" + package_version = "2.8.0" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.13" + bundle_libtpu = "0" + cxx11_abi = "1" + }, { git_tag = "v2.7.0" package_version = "2.7.0" From 30ad68a2965958c9530bb6b0ac3eaa4f615f7eaf Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 7 Aug 2025 09:13:58 -0300 Subject: [PATCH 039/133] Update torchbench pin location. (#9543) --- test/benchmarks/run_torchbench_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/benchmarks/run_torchbench_tests.sh b/test/benchmarks/run_torchbench_tests.sh index e3c2d934eb5b..9ac2d90235b0 100755 --- a/test/benchmarks/run_torchbench_tests.sh +++ b/test/benchmarks/run_torchbench_tests.sh @@ -54,7 +54,7 @@ function install_package() { function install_torchbench_models() { pushd $CDIR - torchbench_commit_hash=$(cat $PYTORCH_DIR/.github/ci_commit_pins/torchbench.txt) + torchbench_commit_hash=$(cat $PYTORCH_DIR/.ci/docker/ci_commit_pins/torchbench.txt) git clone --quiet https://github.com/pytorch/benchmark.git "$TORCHBENCH_DIR" cd $TORCHBENCH_DIR git checkout $torchbench_commit_hash From 6050927ce079cdc0d5f68a3fbbf970606fcf1f8e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 7 Aug 2025 11:44:02 -0300 Subject: [PATCH 040/133] Improve error message of functions related to `GetXlaTensor()`. (#9520) --- test/quantized_ops/test_dot_general.py | 19 +++++ torch_xla/csrc/BUILD | 1 + torch_xla/csrc/aten_xla_bridge.cpp | 50 +++++++++--- torch_xla/csrc/aten_xla_bridge.h | 9 +++ torch_xla/csrc/init_python_bindings.cpp | 103 ++++++++++++++++-------- torch_xla/csrc/status.cpp | 10 ++- torch_xla/csrc/status.h | 14 ++++ 7 files changed, 160 insertions(+), 46 deletions(-) diff --git a/test/quantized_ops/test_dot_general.py b/test/quantized_ops/test_dot_general.py index 71a39ff56e96..68418945a066 100644 --- a/test/quantized_ops/test_dot_general.py +++ b/test/quantized_ops/test_dot_general.py @@ -56,6 +56,25 @@ def test_dot_general_int32_dtype(self): preferred_element_type=torch.int32) self.assertTrue(torch.allclose(xla_out.cpu(), expected_out)) + def test_raises_error_on_non_xla_tensor(self): + lhs = torch.rand(10, 3, 4, dtype=torch.bfloat16) + rhs = torch.rand(10, 4, 5, dtype=torch.bfloat16) + + def test(args, non_xla_tensor_arg): + arg_number_to_str = ["first", "second"] + position = arg_number_to_str[non_xla_tensor_arg] + try: + torch_xla._XLAC._xla_dot_general(*args, (([2], [1]), ([0], [0]))) + except RuntimeError as err: + error_message = ( + f"Expected input tensor ({position} argument) to be an actual XLA tensor. " + f"Got: CPUBFloat16Type. Consider moving it ({position} argument) to XLA." + ) + self.assertEqual(str(err), error_message) + + test((lhs, rhs.to(device)), non_xla_tensor_arg=0) + test((lhs.to(device), rhs), non_xla_tensor_arg=1) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 31ab65dbbcaf..8f2a1bdc67cb 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -280,6 +280,7 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:xla_coordinator", "//torch_xla/csrc/runtime:xla_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:variant", diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 05d92101383a..0f1969e64d5c 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -7,7 +7,7 @@ #include #include -#include "absl/status/status.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/debug_macros.h" @@ -80,8 +80,12 @@ static absl::StatusOr GetXlaTensorImpl( XLATensorImpl* impl = dynamic_cast(inner_tensor.unsafeGetTensorImpl()); if (impl == nullptr) { - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( - "Input tensor is not an XLA tensor: ", tensor.toString()))); + auto error_message = + absl::StrCat("Failed retrieving the inner XLATensorImpl* from ", + tensor.toString(), ". ", + "It's likely that `tensor` is not an actual XLA tensor, " + "i.e. it wasn't created inside PyTorch/XLA."); + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(error_message)); } return impl; } @@ -99,7 +103,9 @@ absl::StatusOr GetXlaTensor( // To make sure we have the most updated version of tensor. at::functionalization::impl::sync(tensor); } - XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor)); + XLA_ASSIGN_OR_RETURN( + XLATensorImpl * impl, GetXlaTensorImpl(tensor), + absl::StrCat("Expected XLA tensor. Got: ", tensor.toString())); return impl->tensor(); } @@ -107,33 +113,53 @@ absl::StatusOr> GetXlaTensors( const at::ITensorListRef& tensors) { std::vector xla_tensors; xla_tensors.reserve(tensors.size()); + std::size_t index = 0; for (const auto& tensor : tensors) { - XLA_ASSIGN_OR_RETURN(XLATensorPtr ptr, bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_RETURN( + XLATensorPtr ptr, bridge::GetXlaTensor(tensor), + absl::StrCat("Expected all tensors in the given list to be XLA " + "tensors. Element at index ", + index, " is not an XLA tensor. Got: ", tensor.toString())); xla_tensors.push_back(std::move(ptr)); + index += 1; } return xla_tensors; } +absl::StatusOr GetInputXlaTensor( + const at::Tensor& tensor, const std::string_view param) { + XLA_ASSIGN_OR_RETURN( + XLATensorPtr ptr, GetXlaTensor(tensor), + absl::StrCat("Expected input tensor (", param, + ") to be an actual XLA tensor. Got: ", tensor.toString(), + ". Consider moving it (", param, ") to XLA.")); + return ptr; +} + bool IsXlaTensor(const at::Tensor& tensor) { return GetXlaTensorImpl(tensor).ok(); } absl::Status ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) { - XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor)); + XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor), + "Failed replacing the XLA tensor in the given tensor."); impl->set_tensor(std::move(new_xla_tensor)); return absl::OkStatus(); } absl::Status ReplaceXlaTensor(const std::vector& tensors, const std::vector new_xla_tensors) { - if (tensors.size() != new_xla_tensors.size()) { - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( - absl::StrCat("The size of tensors and new_xla_tensors are not equal: ", - tensors.size(), " vs. ", new_xla_tensors.size()))); - } + ABSL_CHECK(tensors.size() == new_xla_tensors.size()) + << "Expected the size of the list of tensors (" << tensors.size() + << ") to match the size of the list of XLATensorPtr (" + << new_xla_tensors.size() << ")"; for (size_t i = 0; i < tensors.size(); ++i) { - XLA_RETURN_IF_ERROR(ReplaceXlaTensor(tensors[i], new_xla_tensors[i])); + XLA_RETURN_IF_ERROR( + ReplaceXlaTensor(tensors[i], new_xla_tensors[i]), + absl::StrCat( + "Failed replacing the XLA tensor at index ", i, + ". The reason being that that tensor is not an XLA tensor.")); } return absl::OkStatus(); } diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h index d04873ec8ff3..2577a7a9d42f 100644 --- a/torch_xla/csrc/aten_xla_bridge.h +++ b/torch_xla/csrc/aten_xla_bridge.h @@ -59,6 +59,15 @@ absl::StatusOr GetXlaTensor( absl::StatusOr> GetXlaTensors( const at::ITensorListRef& tensors); +// Retrieves the underlying `XLATensorPtr` from `tensor`. +// +// If `tensor` is not an actual XLA tensor, this function will craft a +// specialized error message for PyTorch operations or Python API +// functions, i.e. functions where the parameter name makes sense for +// the end user. +absl::StatusOr GetInputXlaTensor( + const at::Tensor& tensor, std::string_view param); + bool IsXlaTensor(const at::Tensor& tensor); // Replaces the XLA tensor embedded within `tensor`'s XLA TensorImpl with diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8d23ae7f9862..1d9a4b3c3ebd 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -25,7 +25,7 @@ #include #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/variant.h" @@ -38,6 +38,7 @@ #include "pybind11/pytypes.h" #include "pybind11/stl.h" #include "pybind11/stl_bind.h" +#include "status.h" #include "torch_xla/csrc/XLANativeFunctions.h" #include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_fallback.h" @@ -87,6 +88,23 @@ namespace { constexpr int64_t kSeedInfoId = -127389; +// Traits related to the return type of the lambda function that wraps the +// actual implementation inside PythonScope. +template +struct RemoveStatus { + using type = T; +}; + +template <> +struct RemoveStatus { + using type = void; +}; + +template +struct RemoveStatus> { + using type = T; +}; + // Wraps a python scope (e.g. py::module) to provide more convenient APIs. // It behaves like a Scope object but has enhanced behaviors for the def*() // methods. This class has reference semantics, just like the Scope class. @@ -153,15 +171,29 @@ class PythonScope : public Scope { template static void Bind(Scope& scope, const char* const name, F&& f, const Extra&... extra) { - using RetType = + // `f` return type. + using FnRetType = typename c10::guts::infer_function_traits::type::return_type; - auto lambda = [f = std::move(f)](Args... args) -> RetType { + // Wrapper lambda return type. + // This is needed for handling returning status types. + using LambdaRetType = typename RemoveStatus::type; + // FnRetType is a status type iff after unwrapping the status type, + // the resulting type (i.e. LambdaRetType) is NOT the same as FnRetType. + constexpr bool returns_status_type = + !std::is_same::value; + + auto lambda = [f = std::move(f)](Args... args) -> LambdaRetType { // RAII for emitting Python warnings. // // This turns messages passed to `TORCH_WARN()` in `f` into Python // warnings. torch::PyWarningHandler handler; - return f(args...); + + if constexpr (returns_status_type) { + return GetValueOrThrow(f(args...)); + } else { + return f(args...); + } }; if constexpr (kind == FunctionKind::kInit) { @@ -237,13 +269,11 @@ std::string GetTensorsDump( const std::vector& tensors, const std::function< std::string(absl::Span)>& coverter) { + auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors)); std::vector nodes; - std::vector values; - for (auto& tensor : tensors) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); - values.push_back(xtensor->GetIrValue()); - nodes.push_back(values.back().node.get()); - } + std::transform( + xtensors.begin(), xtensors.end(), std::back_inserter(nodes), + [](const XLATensorPtr& ptr) { return ptr->GetIrValue().node.get(); }); return coverter(nodes); } @@ -363,7 +393,7 @@ std::vector> ExtractXlaDotGeneralDimVectors( return dim_vectors; } -at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs, +at::Tensor XlaDotGeneral(const XLATensorPtr& xlhs, const XLATensorPtr& xrhs, const std::vector>& dim_vectors, std::optional preferred_element_type) { std::optional at_preferred_element_type; @@ -373,9 +403,7 @@ at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs, ->scalar_type; } return bridge::AtenFromXlaTensor(tensor_methods::xla_dot_general( - GetValueOrThrow(bridge::GetXlaTensor(lhs)), - GetValueOrThrow(bridge::GetXlaTensor(rhs)), dim_vectors, - at_preferred_element_type)); + xlhs, xrhs, dim_vectors, at_preferred_element_type)); } std::vector> CreateSourceTargetPairs( @@ -1841,10 +1869,11 @@ void InitXlaModuleBindings(py::module m) { }) .def( "_xla_dot_general", - [](const at::Tensor& lhs, const at::Tensor& rhs, + [](const at::Tensor& lhs, + const at::Tensor& rhs, py::tuple dimension_numbers, std::optional& precision_config, - std::optional& preferred_element_type) -> at::Tensor { + std::optional& preferred_element_type) -> absl::StatusOr { // Python binding for xla::DotGeneral // https://openxla.org/xla/operation_semantics#dotgeneral std::vector> dim_vectors = @@ -1852,9 +1881,13 @@ void InitXlaModuleBindings(py::module m) { XLA_CHECK(!precision_config.has_value()) << "_xla_dot_general: precision_config is not supported yet, " "default precision setting will be applied."; - at::Tensor result = - XlaDotGeneral(lhs, rhs, dim_vectors, preferred_element_type); - return result; + XLA_ASSIGN_OR_RETURN( + XLATensorPtr xlhs, + bridge::GetInputXlaTensor(lhs, /* param= */ "first argument")); + XLA_ASSIGN_OR_RETURN( + XLATensorPtr xrhs, + bridge::GetInputXlaTensor(rhs, /* param= */ "second argument")); + return XlaDotGeneral(xlhs, xrhs, dim_vectors, preferred_element_type); }, py::arg("lhs"), // py::arg("rhs"), // @@ -3340,19 +3373,25 @@ void InitXlaModuleBindings(py::module m) { opt_device ? &opt_device.value() : nullptr); return check_materialization_helper(xtensors); }) - .def( - "_get_graph_hash", - [](const std::vector& tensors) { - std::vector xtensors; - xtensors.reserve(tensors.size()); - for (auto& tensor : tensors) { - xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor))); - } - torch::lazy::hash_t hash = - XLAGraphExecutor::Get()->GetGraphHash(xtensors); - std::string bin((const char*)&hash, sizeof(hash)); - return py::bytes(bin); - }) + .def("_get_graph_hash", + [](const std::vector& tensors) -> py::bytes { + absl::StatusOr> + xtensors_status = bridge::GetXlaTensors(tensors); + ABSL_CHECK(xtensors_status.ok()) + << "_get_graph_hash(): error retrieving the XLA tensors from " + << "the given tensor arguments. " + << "This is a bug! Please, open an issue in the PyTorch/XLA " + << "GitHub repository: https://github.com/pytorch/xla" + << std::endl + << "Status Error: " + << BuildStatusErrorMessage(xtensors_status.status()); + std::vector xtensors = + xtensors_status.value(); + torch::lazy::hash_t hash = + XLAGraphExecutor::Get()->GetGraphHash(xtensors); + std::string bin((const char*)&hash, sizeof(hash)); + return py::bytes(bin); + }) .def("_clear_pending_irs", [](const std::string& device) { // Use with caution. Those tensor whole ir was cleared diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 270f34878675..a4e4eddc0cc2 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -119,9 +119,15 @@ static std::string MaybeGetMessageWithLineBreak(const absl::Status& status) { : std::string(status.message()); } +std::string BuildStatusErrorMessage(const absl::Status& status) { + return absl::StrCat(MaybeGetMessageWithLineBreak(status), + GetFormattedStatusPropagationTrace(status)); +} + void MaybeThrow(const absl::Status& status) { - TORCH_CHECK(status.ok(), MaybeGetMessageWithLineBreak(status), - GetFormattedStatusPropagationTrace(status)); + TORCH_CHECK(status.ok(), BuildStatusErrorMessage(status)); } +void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); } + } // namespace torch_xla diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index 2f53b37381fb..c922e2f511f3 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -174,6 +174,17 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, } // namespace status_internal +// Builds the complete error message for the given `status`. +// +// If `TORCH_SHOW_CPP_STACKTRACES` is enabled, returns the concatenation of +// `status.message()` with its inner status propagation trace. +// +// TODO(ysiraichi): this call does not append the C++ stacktrace, which, +// ideally, should. It can be done by not using `TORCH_CHECK()` macro directly +// in `MaybeThrow()`, but using PyTorch `c10::get_lazy_backtrace()` +// (at c10/util/Backtrace.h). +std::string BuildStatusErrorMessage(const absl::Status& status); + // Maybe throws an exception if `status` has a non-ok code. // // Ideally, this function should be used only used in the project's @@ -200,6 +211,9 @@ T GetValueOrThrow(absl::StatusOr&& status) { return std::move(status).value(); } +// `GetValueOrThrow` overload for `Status`. +void GetValueOrThrow(const absl::Status& status); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_STATUS_H_ From 41bfd6259df8aceb242aa32f90c94617aeba2e57 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 7 Aug 2025 15:34:58 -0700 Subject: [PATCH 041/133] Update artifacts_builds.tf for rc5 --- infra/tpu-pytorch-releases/artifacts_builds.tf | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 9ede0110ff3b..099a2402afe9 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -2,8 +2,8 @@ # Define common configuration parameters for 2.8 release and nightly locals { tpu_python_versions = ["3.9", "3.10", "3.11", "3.12", "3.13"] - release_git_tag = "v2.8.0-rc4" - release_package_version = "2.8.0-rc4" + release_git_tag = "v2.8.0-rc5" + release_package_version = "2.8.0-rc5" release_pytorch_git_rev = "v2.8.0-rc8" nightly_package_version = "2.9.0" cuda_versions = { From 57cd41c038e20ac14d6c28a8c5ee3f378b30fb27 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 8 Aug 2025 15:31:05 -0300 Subject: [PATCH 042/133] Refactor the status error message builder. (#9546) --- torch_xla/csrc/init_python_bindings.cpp | 13 ++++++++----- torch_xla/csrc/status.cpp | 26 ++++++++++--------------- torch_xla/csrc/status.h | 5 +---- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1d9a4b3c3ebd..a2b799a5f0e7 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -3378,13 +3378,16 @@ void InitXlaModuleBindings(py::module m) { absl::StatusOr> xtensors_status = bridge::GetXlaTensors(tensors); ABSL_CHECK(xtensors_status.ok()) - << "_get_graph_hash(): error retrieving the XLA tensors from " - << "the given tensor arguments. " + << "\n\n" + << "Internal Error:\n" + << " _get_graph_hash(): error retrieving the XLA tensors " + "from the given tensor arguments. " << "This is a bug! Please, open an issue in the PyTorch/XLA " << "GitHub repository: https://github.com/pytorch/xla" - << std::endl - << "Status Error: " - << BuildStatusErrorMessage(xtensors_status.status()); + << "\n\n" + << "Status Error:\n" + << " " << BuildStatusErrorMessage(xtensors_status.status()) + << "\n"; std::vector xtensors = xtensors_status.value(); torch::lazy::hash_t hash = diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index a4e4eddc0cc2..2e1c7002e897 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -103,29 +103,23 @@ static std::string GetFormattedStatusPropagationTrace( auto status_propagation_trace = GetStatusPropagationTraceOrEmpty(status); return status_propagation_trace.empty() ? "" - : absl::StrCat("\nStatus Propagation Trace:", - status_propagation_trace.Flatten(), "\n"); -} - -// Get the status message followed by a line break, if we are printing the -// C++ stacktraces. -// -// This is needed so we have a blank line in between the status message and -// the dumped C++ traces (either the status propagation one, or the C++ -// stacktrace). -static std::string MaybeGetMessageWithLineBreak(const absl::Status& status) { - return torch::get_cpp_stacktraces_enabled() - ? absl::StrCat(status.message(), "\n") - : std::string(status.message()); + : absl::StrCat("\n\nStatus Propagation Trace:", + status_propagation_trace.Flatten()); } std::string BuildStatusErrorMessage(const absl::Status& status) { - return absl::StrCat(MaybeGetMessageWithLineBreak(status), + return absl::StrCat(status.message(), GetFormattedStatusPropagationTrace(status)); } +// Return a line break if torch::get_cpp_stacktraces_enabled() is true. +static std::string LineBreakIfCppStacktracesEnabled() { + return torch::get_cpp_stacktraces_enabled() ? "\n" : ""; +} + void MaybeThrow(const absl::Status& status) { - TORCH_CHECK(status.ok(), BuildStatusErrorMessage(status)); + TORCH_CHECK(status.ok(), absl::StrCat(BuildStatusErrorMessage(status), + LineBreakIfCppStacktracesEnabled())); } void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); } diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index c922e2f511f3..b2d508076a3b 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -179,10 +179,7 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, // If `TORCH_SHOW_CPP_STACKTRACES` is enabled, returns the concatenation of // `status.message()` with its inner status propagation trace. // -// TODO(ysiraichi): this call does not append the C++ stacktrace, which, -// ideally, should. It can be done by not using `TORCH_CHECK()` macro directly -// in `MaybeThrow()`, but using PyTorch `c10::get_lazy_backtrace()` -// (at c10/util/Backtrace.h). +// It doesn't add a trailing line break. std::string BuildStatusErrorMessage(const absl::Status& status); // Maybe throws an exception if `status` has a non-ok code. From 8c1449f6d700fa4251616fa5d3f9f5526d4ce976 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 8 Aug 2025 15:31:55 -0300 Subject: [PATCH 043/133] Use `TORCH_CHECK()` instead of throwing `std::runtime_error` in `XLA_CHECK*()` macros' implementation. (#9542) --- test/cpp/BUILD | 2 + test/cpp/test_aten_xla_tensor_1.cpp | 2 +- test/cpp/test_aten_xla_tensor_4.cpp | 2 +- test/cpp/test_debug_macros.cpp | 49 ++++-- test/cpp/test_ir.cpp | 2 +- test/cpp/test_status_common.h | 165 ++++++++---------- .../test_status_dont_show_cpp_stacktraces.cpp | 6 + test/cpp/test_status_show_cpp_stacktraces.cpp | 6 + torch_xla/csrc/runtime/BUILD | 2 + torch_xla/csrc/runtime/debug_macros.h | 7 + torch_xla/csrc/runtime/runtime.h | 2 +- torch_xla/csrc/runtime/tf_logging.cpp | 5 +- 12 files changed, 134 insertions(+), 116 deletions(-) diff --git a/test/cpp/BUILD b/test/cpp/BUILD index e752eab4f670..483d5ef7c01e 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -167,6 +167,7 @@ cc_library( name = "test_status_common", hdrs = ["test_status_common.h"], deps = [ + ":cpp_test_util", "//torch_xla/csrc:status", "//torch_xla/csrc/runtime:env_vars", "@com_google_absl//absl/status:statusor", @@ -196,6 +197,7 @@ ptxla_cc_test( name = "test_debug_macros", srcs = ["test_debug_macros.cpp"], deps = [ + ":cpp_test_util", "//torch_xla/csrc:status", "//torch_xla/csrc/runtime:debug_macros", "//torch_xla/csrc/runtime:env_vars", diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index 694c45945639..e2813b88a944 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -2429,7 +2429,7 @@ TEST_F(AtenXlaTensorTest, TestCount_Nonzero_error_case) { torch::Tensor xla_a = CopyToDevice(a, device); std::vector dims = {0, 0}; - EXPECT_THROW(torch::count_nonzero(xla_a, dims), std::runtime_error); + EXPECT_THROW(torch::count_nonzero(xla_a, dims), c10::Error); dims = {10}; EXPECT_THROW(torch::count_nonzero(xla_a, dims), c10::Error); diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index 20b0789ca8cc..5b1d99524b8c 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -307,7 +307,7 @@ TEST_F(AtenXlaTensorTest, TestGettingSizeOnDynamicTensor) { torch::TensorOptions(torch::kFloat)); torch::Tensor xla_b = CopyToDevice(b, device); torch::Tensor xla_nonzero = torch::nonzero(xla_b); - EXPECT_THROW(xla_nonzero.sizes(), std::runtime_error); + EXPECT_THROW(xla_nonzero.sizes(), c10::Error); EXPECT_NO_THROW(xla_nonzero.sym_sizes()); }); } diff --git a/test/cpp/test_debug_macros.cpp b/test/cpp/test_debug_macros.cpp index 52c01668791b..e62e40484fe6 100644 --- a/test/cpp/test_debug_macros.cpp +++ b/test/cpp/test_debug_macros.cpp @@ -1,29 +1,42 @@ #include #include +#include "test/cpp/cpp_test_util.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" -namespace torch_xla { +namespace torch_xla::cpp_test { namespace { -using absl::StrCat; +// Prefix of the C++ stacktrace PyTorch adds to the error message. +constexpr char kTorchCppStacktracePrefix[] = + "Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:"; TEST(DebugMacrosTest, Check) { - auto line = __LINE__ + 1; - EXPECT_THAT([] { XLA_CHECK(false) << "Error message."; }, - testing::ThrowsMessage(testing::StartsWith( - StrCat("Check failed: false: Error message. (at ", __FILE__, - ":", line, ")\n*** Begin stack trace ***")))); + int32_t line; + try { + line = __LINE__ + 1; + XLA_CHECK(false) << "Error message."; + } catch (const c10::Error& error) { + EXPECT_THAT(error.what(), + testing::StartsWith(absl::StrCat( + "Check failed: false: Error message. (at ", __FILE__, ":", + line, ")\n\n", kTorchCppStacktracePrefix))); + } } -#define TEST_XLA_CHECK_OP_(opstr, lhs, rhs, compstr, valstr) \ - TEST(DebugMacrosTest, Check##opstr) { \ - EXPECT_THAT( \ - [] { XLA_CHECK_##opstr(lhs, rhs) << " Error message."; }, \ - testing::ThrowsMessage(testing::StartsWith(StrCat( \ - "Check failed: " compstr " (" valstr ") Error message. (at ", \ - __FILE__, ":", __LINE__, ")\n*** Begin stack trace ***")))); \ +#define TEST_XLA_CHECK_OP_(opstr, lhs, rhs, compstr, valstr) \ + TEST(DebugMacrosTest, Check##opstr) { \ + try { \ + XLA_CHECK_##opstr(lhs, rhs) << " Error message."; \ + } catch (const c10::Error& error) { \ + EXPECT_THAT( \ + error.what(), \ + ::testing::StartsWith(absl::StrCat( \ + "Check failed: " compstr " (" valstr ") Error message. (at ", \ + __FILE__, ":", __LINE__, ")\n\n", \ + ::torch_xla::cpp_test::kTorchCppStacktracePrefix))); \ + } \ } #define TEST_XLA_CHECK_OP(opstr, op, lhs, rhs) \ @@ -52,15 +65,15 @@ TEST_XLA_CHECK_OP(LT, <, 5, 1) TEST_XLA_CHECK_GREATER(GE, <=, 5, 8) TEST_XLA_CHECK_GREATER(GT, <, 5, 8) -} // namespace -} // namespace torch_xla - static void SetUp() { setenv("TORCH_SHOW_CPP_STACKTRACES", /* value= */ "1", /* replace= */ 1); } +} // namespace +} // namespace torch_xla::cpp_test + int main(int argc, char** argv) { - SetUp(); + ::torch_xla::cpp_test::SetUp(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/test/cpp/test_ir.cpp b/test/cpp/test_ir.cpp index ab75714dc314..39b155a564ff 100644 --- a/test/cpp/test_ir.cpp +++ b/test/cpp/test_ir.cpp @@ -358,7 +358,7 @@ TEST_F(IrTest, TestSizeDivNodeDynamicByZero) { std::shared_ptr dim_node_div = std::dynamic_pointer_cast(node_div); - EXPECT_THROW(dim_node_div->getDynamicValue(), std::runtime_error); + EXPECT_THROW(dim_node_div->getDynamicValue(), c10::Error); } } // namespace cpp_test diff --git a/test/cpp/test_status_common.h b/test/cpp/test_status_common.h index 4d4b173f6431..5cf8285f5ebf 100644 --- a/test/cpp/test_status_common.h +++ b/test/cpp/test_status_common.h @@ -23,11 +23,13 @@ #include #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "test/cpp/cpp_test_util.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/status.h" @@ -75,7 +77,11 @@ class StatusTest : public testing::TestWithParam { [](const ::testing::TestParamInfo<::torch_xla::CppStacktracesMode>& \ info) { return ToString(info.param); }) -namespace testing { +namespace cpp_test { + +// Prefix of the C++ stacktrace PyTorch adds to the error message. +constexpr inline char kTorchCppStacktracePrefix[] = + "Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:"; constexpr inline char kNewMessage[] = "New test error message"; constexpr inline char kMessage[] = "Test error message"; @@ -84,29 +90,6 @@ constexpr inline char kFunction[] = "foo"; constexpr inline char kEntryPrefix[] = "\n "; constexpr inline int32_t kLine = 42; -// The PyTorch C++ stacktrace is ALWAYS appended to the error message. -// More specifically, when `what()` function is called. -// -// However, it's only when the raised `c10::Error` gets translated to a -// Python exception that PyTorch checks the value of the -// `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually -// controls whether the stacktrace will get shown or not by calling -// `what_without_backtraces()`, instead. -// -// Therefore, we need to mimic this behavior. -#define THROW_RUNTIME_ERROR_FROM_C10_ERROR(block) \ - try { \ - block; \ - } catch (const c10::Error& error) { \ - throw std::runtime_error(IsShowCppStacktracesMode() \ - ? error.what() \ - : error.what_without_backtrace()); \ - } - -// Prefix of the C++ stacktrace PyTorch adds to the error message. -constexpr inline char kTorchCppStacktracePrefix[] = - "Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:"; - inline std::string GetStatusPropagationTrace(const absl::Status& status) { if (status.ok()) { return ""; @@ -123,21 +106,18 @@ TEST_P(StatusTest, MaybeThrowWithOkStatus) { } TEST_P(StatusTest, MaybeThrowWithErrorStatus) { - auto throw_exception = [=]() { - THROW_RUNTIME_ERROR_FROM_C10_ERROR({ - absl::Status error_status = absl::InvalidArgumentError(kMessage); - MaybeThrow(error_status); - }); - }; - - if (IsShowCppStacktracesMode()) { - std::string expected_prefix = - absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix); - EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( - ::testing::StartsWith(expected_prefix))); - } else { - EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( - ::testing::Eq(kMessage))); + try { + absl::Status error_status = absl::InvalidArgumentError(kMessage); + MaybeThrow(error_status); + } catch (const c10::Error& error) { + if (IsShowCppStacktracesMode()) { + EXPECT_THAT(std::string_view(error.what()), + ::testing::StartsWith(absl::StrCat( + kMessage, "\n\n", kTorchCppStacktracePrefix))); + } else { + EXPECT_EQ(std::string_view(error.what_without_backtrace()), + std::string_view(kMessage)); + } } } @@ -149,20 +129,18 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) { } TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) { - auto throw_exception = [=]() { - THROW_RUNTIME_ERROR_FROM_C10_ERROR({ - absl::StatusOr error_status = absl::InvalidArgumentError(kMessage); - int value = GetValueOrThrow(error_status); - }); - }; - if (IsShowCppStacktracesMode()) { - std::string expected_prefix = - absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix); - EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( - ::testing::StartsWith(expected_prefix))); - } else { - EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( - ::testing::Eq(kMessage))); + try { + absl::StatusOr error_status = absl::InvalidArgumentError(kMessage); + int value = GetValueOrThrow(error_status); + } catch (const c10::Error& error) { + if (IsShowCppStacktracesMode()) { + EXPECT_THAT(std::string_view(error.what()), + ::testing::StartsWith(absl::StrCat( + kMessage, "\n\n", kTorchCppStacktracePrefix))); + } else { + EXPECT_EQ(std::string_view(error.what_without_backtrace()), + std::string_view(kMessage)); + } } } @@ -272,14 +250,14 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) { EXPECT_EQ(result.message(), std::string_view(kMessage)); if (IsShowCppStacktracesMode()) { - auto frame0 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, - ":", errline0, " (error: ", kMessage, ")"); - auto frame1 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, - ":", errline1); - auto frame2 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__, - ":", errline2); - EXPECT_EQ(GetStatusPropagationTrace(result), - absl::StrCat(frame0, frame1, frame2)); + std::ostringstream oss; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" << errline0 + << " (error: " << kMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline1; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline2; + EXPECT_EQ(GetStatusPropagationTrace(result), oss.str()); } } @@ -383,39 +361,44 @@ TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) { return absl::OkStatus(); }; - auto throw_exception = [&]() { - THROW_RUNTIME_ERROR_FROM_C10_ERROR(MaybeThrow(outerfn())); - }; - - if (IsShowCppStacktracesMode()) { - // Expected Error Message Prefix - // ============================= - // - // New test error kMessage - // - // Status Propagation Stacktrace: - // From: ./test/cpp/test_status_common.h:329 (error: Test error - // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test - // error kMessage) From: ./test/cpp/test_status_common.h:342 - // - // C++ Stacktrace: - // - std::string expected_prefix = absl::StrCat( - kNewMessage, "\n\nStatus Propagation Trace:", kEntryPrefix, - "From: operator() at ", __FILE__, ":", errline0, " (error: ", kMessage, - ")", kEntryPrefix, "From: operator() at ", __FILE__, ":", errline1, - " (error: ", kNewMessage, ")", kEntryPrefix, "From: operator() at ", - __FILE__, ":", errline2, "\n\n", kTorchCppStacktracePrefix); - - EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( - ::testing::StartsWith(expected_prefix))); - } else { - EXPECT_THAT(throw_exception, ::testing::ThrowsMessage( - ::testing::Eq(kNewMessage))); + try { + MaybeThrow(outerfn()); + } catch (const c10::Error& error) { + if (IsShowCppStacktracesMode()) { + // Expected Error Message Prefix + // ============================= + // + // New test error kMessage + // + // Status Propagation Stacktrace: + // From: ./test/cpp/test_status_common.h:329 (error: Test error + // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New + // test error kMessage) From: ./test/cpp/test_status_common.h:342 + // + // C++ Stacktrace: + // + std::ostringstream oss; + oss << kNewMessage; + oss << "\n\n"; + oss << "Status Propagation Trace:"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline0 << " (error: " << kMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline1 << " (error: " << kNewMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline2; + oss << "\n\n"; + oss << kTorchCppStacktracePrefix; + EXPECT_THAT(std::string_view(error.what()), + ::testing::StartsWith(oss.str())); + } else { + EXPECT_EQ(std::string_view(error.what_without_backtrace()), + std::string_view(kNewMessage)); + } } } -} // namespace testing +} // namespace cpp_test } // namespace torch_xla #endif // XLA_TEST_CPP_TEST_STATUS_COMMON_H_ diff --git a/test/cpp/test_status_dont_show_cpp_stacktraces.cpp b/test/cpp/test_status_dont_show_cpp_stacktraces.cpp index a6555f293825..16d49a38e7dd 100644 --- a/test/cpp/test_status_dont_show_cpp_stacktraces.cpp +++ b/test/cpp/test_status_dont_show_cpp_stacktraces.cpp @@ -2,6 +2,9 @@ using torch_xla::StatusTest; +namespace torch_xla::cpp_test { +namespace { + // This file instantiates the parameterized tests defined in // `test_status_common.h`. It specifically configures the test environment by // explicitly setting the `TORCH_SHOW_CPP_STACKTRACES` environment variable to @@ -11,3 +14,6 @@ using torch_xla::StatusTest; // automatically be run in this mode (without C++ error context). // INSTANTIATE_WITH_CPP_STACKTRACES_MODE(StatusTest, StatusTest, kHide); + +} // namespace +} // namespace torch_xla::cpp_test diff --git a/test/cpp/test_status_show_cpp_stacktraces.cpp b/test/cpp/test_status_show_cpp_stacktraces.cpp index 61f881a86293..a06e540a7025 100644 --- a/test/cpp/test_status_show_cpp_stacktraces.cpp +++ b/test/cpp/test_status_show_cpp_stacktraces.cpp @@ -2,6 +2,9 @@ using torch_xla::StatusTest; +namespace torch_xla::cpp_test { +namespace { + // This file instantiates the parameterized tests defined in // `test_status_common.h`. It specifically configures the test environment by // explicitly setting the `TORCH_SHOW_CPP_STACKTRACES` environment variable to @@ -11,3 +14,6 @@ using torch_xla::StatusTest; // automatically be run in this mode (with C++ error context). INSTANTIATE_WITH_CPP_STACKTRACES_MODE(StatusWithCppErrorContextTest, StatusTest, kShow); + +} // namespace +} // namespace torch_xla::cpp_test diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index c4760783f4d0..b381d3feff7c 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -395,6 +395,8 @@ cc_library( hdrs = ["tf_logging.h"], deps = [ "//torch_xla/csrc:status", + "@torch//:headers", + "@torch//:runtime_headers", "@tsl//tsl/platform:stacktrace", "@tsl//tsl/platform:statusor", "@xla//xla/service:platform_util", diff --git a/torch_xla/csrc/runtime/debug_macros.h b/torch_xla/csrc/runtime/debug_macros.h index 5bfc90d81fd6..6ebdafea6fb2 100644 --- a/torch_xla/csrc/runtime/debug_macros.h +++ b/torch_xla/csrc/runtime/debug_macros.h @@ -6,6 +6,13 @@ #include "tsl/platform/stacktrace.h" #include "tsl/platform/statusor.h" +// DEPRECATED +// ========== +// These macros are deprecated in favor of error handling by propagating abseil +// status types (e.g. `absl::Status` and `absl::StatusOr`). +// +// Description +// =========== // TORCH_SHOW_CPP_STACKTRACES environment variable changes the behavior of the // macros below, such as XLA_CHECK(), XLA_CHECK_EQ(), etc. (except for // XLA_CHECK_OK) in the following way: diff --git a/torch_xla/csrc/runtime/runtime.h b/torch_xla/csrc/runtime/runtime.h index f6af26cb66f2..6a1588935e6f 100644 --- a/torch_xla/csrc/runtime/runtime.h +++ b/torch_xla/csrc/runtime/runtime.h @@ -11,7 +11,7 @@ namespace torch_xla::runtime { const absl::StatusOr& GetComputationClient(); ABSL_DEPRECATED( - "Use status::GetComputationClient(), instead. " + "Use GetComputationClient(), instead. " "This function throws an exception on error, instead of " "actually handling the StatusOr return value, which is " "safer.") diff --git a/torch_xla/csrc/runtime/tf_logging.cpp b/torch_xla/csrc/runtime/tf_logging.cpp index 3b268f608810..3f6419a5127b 100644 --- a/torch_xla/csrc/runtime/tf_logging.cpp +++ b/torch_xla/csrc/runtime/tf_logging.cpp @@ -1,5 +1,6 @@ #include "torch_xla/csrc/runtime/tf_logging.h" +#include #include #include @@ -19,12 +20,10 @@ void ErrorGenerator::operator&(const std::basic_ostream& oss) const { if (torch::get_cpp_stacktraces_enabled()) { ess << " (at " << file_ << ":" << line_ << ")\n"; - ess << tsl::CurrentStackTrace(); } TF_VLOG(1) << ess.str(); - // We cannot use AT_ERROR() here, due to layering issues. - throw std::runtime_error(ess.str()); + TORCH_CHECK(false, ess.str()); } } // namespace internal From 095faec1e7b6cc47220181e74ae9cde2605f9b00 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 8 Aug 2025 15:32:11 -0300 Subject: [PATCH 044/133] Error Handling: make `XLATensor::Create()` return status type. (#9544) --- test/cpp/test_tensor.cpp | 103 ++++++++++++++++++----------- torch_xla/csrc/aten_xla_bridge.cpp | 8 ++- torch_xla/csrc/tensor.cpp | 14 ++-- torch_xla/csrc/tensor.h | 5 +- torch_xla/csrc/tensor_methods.cpp | 5 +- 5 files changed, 85 insertions(+), 50 deletions(-) diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp index eff668c8809a..6d962c900496 100644 --- a/test/cpp/test_tensor.cpp +++ b/test/cpp/test_tensor.cpp @@ -101,8 +101,8 @@ TEST_F(TensorTest, TestAdd) { at::Tensor c = a.add(b, 1.0); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_a = XLATensor::Create(a, device); - XLATensorPtr dev_b = XLATensor::Create(b, device); + XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device)); + XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device)); XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, 1.0); AllClose(c, dev_c); @@ -121,8 +121,8 @@ TEST_F(TensorTest, TestIntegerAdd) { at::isIntegralType(type) ? at::Scalar(int64_t(1)) : at::Scalar(1.0); at::Tensor c = a.add(b, one); - XLATensorPtr dev_a = XLATensor::Create(a, device); - XLATensorPtr dev_b = XLATensor::Create(b, device); + XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device)); + XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device)); XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, one); EXPECT_TRUE(EqualValuesNoElementTypeCheck( @@ -135,7 +135,7 @@ TEST_F(TensorTest, TestSize) { at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat)); int rank = input.dim(); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); for (int dim = -rank; dim < rank; ++dim) { EXPECT_EQ(input.size(dim), dev_input->size(dim)); } @@ -151,8 +151,10 @@ TEST_F(TensorTest, TestRrelu) { at::Tensor noise = at::zeros_like(input); at::Tensor output = at::rrelu_with_noise(input, noise, lower, upper, training); - XLATensorPtr dev_input = XLATensor::Create(input, device); - XLATensorPtr dev_noise = XLATensor::Create(noise, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); + XLATensorPtr dev_noise = + GetValueOrThrow(XLATensor::Create(noise, device)); XLATensorPtr dev_outputs = tensor_methods::rrelu_with_noise( dev_input, dev_noise, lower, upper, training); AllClose(output, dev_outputs); @@ -167,7 +169,7 @@ TEST_F(TensorTest, TestThreshold) { float value = 20; at::Tensor output = at::threshold(input, threshold, value); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::threshold(dev_input, threshold, value); AllClose(output, dev_output); @@ -185,9 +187,10 @@ TEST_F(TensorTest, TestAddMatMul) { at::Tensor bias = at::rand({labels}, at::TensorOptions(at::kFloat)); at::Tensor output = at::addmm(bias, input, weight); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); - XLATensorPtr dev_weight = XLATensor::Create(weight, device); - XLATensorPtr dev_bias = XLATensor::Create(bias, device); + XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); + XLATensorPtr dev_weight = + GetValueOrThrow(XLATensor::Create(weight, device)); + XLATensorPtr dev_bias = GetValueOrThrow(XLATensor::Create(bias, device)); XLATensorPtr dev_output = tensor_methods::addmm(dev_input, dev_weight, dev_bias); AllClose(output, dev_output); @@ -198,7 +201,7 @@ TEST_F(TensorTest, TestTranspose) { at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat)); at::Tensor output = at::transpose(input, 0, 1); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::transpose(dev_input, 0, 1); AllClose(output, dev_output); }); @@ -208,7 +211,7 @@ TEST_F(TensorTest, TestView) { at::Tensor input = at::rand({32, 20, 4, 4}, at::TensorOptions(at::kFloat)); at::Tensor output = input.view({-1, 320}); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::view(dev_input, {-1, 320}); AllClose(output, dev_output); }); @@ -289,7 +292,8 @@ TEST_F(TensorTest, TestMaxPool2D) { /*padding=*/{padding, padding}, /*dilation=*/{1, 1}, /*ceil_mode=*/false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); auto dev_output = tensor_methods::max_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -313,7 +317,8 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) { /*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1}, /*ceil_mode=*/false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); auto dev_output = tensor_methods::max_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -341,7 +346,8 @@ TEST_F(TensorTest, TestAvgPool2D) { /*ceil_mode=*/false, count_include_pad, /*divisor_override=*/std::nullopt); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::avg_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -371,7 +377,8 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) { /*count_include_pad=*/count_include_pad, /*divisor_override=*/std::nullopt); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::avg_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -409,15 +416,20 @@ TEST_F(TensorTest, TestBatchNorm1D) { /*running_mean=*/running_mean, /*running_var=*/running_var, /*training=*/training, /*momentum=*/momentum, /*eps=*/eps); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr xla_input = XLATensor::Create(input, device); - XLATensorPtr xla_weight = undef_weight_bias - ? XLATensorPtr() - : XLATensor::Create(weight, device); - XLATensorPtr xla_bias = undef_weight_bias - ? XLATensorPtr() - : XLATensor::Create(bias, device); - XLATensorPtr xla_running_mean = XLATensor::Create(running_mean, device); - XLATensorPtr xla_running_var = XLATensor::Create(running_var, device); + XLATensorPtr xla_input = + GetValueOrThrow(XLATensor::Create(input, device)); + XLATensorPtr xla_weight = + undef_weight_bias + ? XLATensorPtr() + : GetValueOrThrow(XLATensor::Create(weight, device)); + XLATensorPtr xla_bias = + undef_weight_bias + ? XLATensorPtr() + : GetValueOrThrow(XLATensor::Create(bias, device)); + XLATensorPtr xla_running_mean = + GetValueOrThrow(XLATensor::Create(running_mean, device)); + XLATensorPtr xla_running_var = + GetValueOrThrow(XLATensor::Create(running_var, device)); auto xla_output = tensor_methods::native_batch_norm( /*input=*/xla_input, /*weight=*/xla_weight, /*bias=*/xla_bias, /*running_mean=*/xla_running_mean, /*running_var=*/xla_running_var, @@ -474,11 +486,14 @@ TEST_F(TensorTest, TestConv2D) { /*output_padding=*/{output_padding, output_padding}, /*groups=*/groups, false, false, false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); - XLATensorPtr dev_weight = XLATensor::Create(weight, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); + XLATensorPtr dev_weight = + GetValueOrThrow(XLATensor::Create(weight, device)); XLATensorPtr dev_output; if (with_bias) { - XLATensorPtr dev_bias = XLATensor::Create(bias, device); + XLATensorPtr dev_bias = + GetValueOrThrow(XLATensor::Create(bias, device)); dev_output = tensor_methods::convolution_overrideable( dev_input, dev_weight, dev_bias, /*stride=*/{stride, stride}, @@ -543,11 +558,14 @@ TEST_F(TensorTest, TestConv2DNonSquare) { /*groups=*/groups, false, false, false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); - XLATensorPtr dev_weight = XLATensor::Create(weight, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); + XLATensorPtr dev_weight = + GetValueOrThrow(XLATensor::Create(weight, device)); XLATensorPtr dev_output; if (with_bias) { - XLATensorPtr dev_bias = XLATensor::Create(bias, device); + XLATensorPtr dev_bias = + GetValueOrThrow(XLATensor::Create(bias, device)); dev_output = tensor_methods::convolution_overrideable( dev_input, dev_weight, dev_bias, /*stride=*/{stride, stride + 1}, @@ -616,11 +634,14 @@ TEST_F(TensorTest, TestConv3D) { {output_padding, output_padding, output_padding}, /*groups=*/groups, false, false, false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = XLATensor::Create(input, device); - XLATensorPtr dev_weight = XLATensor::Create(weight, device); + XLATensorPtr dev_input = + GetValueOrThrow(XLATensor::Create(input, device)); + XLATensorPtr dev_weight = + GetValueOrThrow(XLATensor::Create(weight, device)); XLATensorPtr dev_output; if (with_bias) { - XLATensorPtr dev_bias = XLATensor::Create(bias, device); + XLATensorPtr dev_bias = + GetValueOrThrow(XLATensor::Create(bias, device)); dev_output = tensor_methods::convolution_overrideable( dev_input, dev_weight, dev_bias, /*stride=*/{stride, stride, stride}, @@ -688,10 +709,14 @@ TEST_F(TensorTest, TestConv3D) { // {output_padding, output_padding + 1, output_padding}, // /*groups=*/groups, false, false, false); // ForEachDevice([&](const torch::lazy::BackendDevice& device) { -// XLATensorPtr dev_input = XLATensor::Create(input, device); -// XLATensorPtr dev_weight = XLATensor::Create(weight, -// device); XLATensorPtr dev_output; if (with_bias) { -// XLATensorPtr dev_bias = XLATensor::Create(bias, device); +// XLATensorPtr dev_input = +// GetValueOrThrow(XLATensor::Create(input, device)); +// XLATensorPtr dev_weight = +// GetValueOrThrow(XLATensor::Create(weight, device); +// XLATensorPtr dev_output; +// if (with_bias) { +// XLATensorPtr dev_bias = +// GetValueOrThrow(XLATensor::Create(bias, device)); // dev_output = tensor_methods::convolution_overrideable( // dev_input, dev_weight, dev_bias, // /*stride=*/{stride, stride + 1, stride + 1}, diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 0f1969e64d5c..8bc0cc32a615 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -186,8 +186,9 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor, } auto xtensor = GetXlaTensor(tensor); - return xtensor.ok() ? xtensor.value() - : XLATensor::Create(inner_tensor, device); + return xtensor.ok() + ? xtensor.value() + : GetValueOrThrow(XLATensor::Create(inner_tensor, device)); } XLATensorPtr GetOrCreateXlaTensor(const std::optional& tensor, @@ -478,7 +479,8 @@ at::Tensor CreateXlaTensor( at::Tensor tensor, const std::optional& device) { if (tensor.defined() && device) { - XLATensorPtr xla_tensor = XLATensor::Create(std::move(tensor), *device); + XLATensorPtr xla_tensor = + GetValueOrThrow(XLATensor::Create(std::move(tensor), *device)); tensor = AtenFromXlaTensor(xla_tensor); } return tensor; diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 6459293a87ff..106b2603e843 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -61,9 +61,15 @@ bool CanApplySharding(const XLATensor::ShardingSpecPtr sharding) { XLATensor::Data::~Data() { XLAGraphExecutor::Get()->UnregisterTensor(this); } -XLATensorPtr XLATensor::Create(const at::Tensor& tensor, - const torch::lazy::BackendDevice& device) { - XLA_CHECK_EQ(tensor.device().type(), at::kCPU); +absl::StatusOr XLATensor::Create( + const at::Tensor& tensor, const torch::lazy::BackendDevice& device) { + if (!tensor.is_cpu()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "Could not create an XLATensor out of the provided tensor. Expected " + "tensor data to be on CPU. Got: ", + c10::DeviceTypeName(tensor.device().type()), + ". Consider moving the tensor to CPU."))); + } XLATensorPtr xtensor = c10::make_intrusive(XLATensor(tensor, device)); XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); @@ -621,7 +627,7 @@ std::vector XLATensor::MakeOutputTensors( XLATensorPtr XLATensor::CopyTensorToDevice( const torch::lazy::BackendDevice& device) { // TODO: This can be optimized via proper XRT/XLA computation. - return Create(ToTensor(/*detached=*/true), device); + return GetValueOrThrow(Create(ToTensor(/*detached=*/true), device)); } torch::lazy::Value XLATensor::MaybeCastIrValue( diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 0d49e98b67f7..69bd8aa3a562 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -9,6 +9,7 @@ #include #include +#include "absl/base/nullability.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/view.h" @@ -149,8 +150,8 @@ class XLATensor : public torch::lazy::LazyTensor { bool is_cloned = false; }; - static XLATensorPtr Create(const at::Tensor& tensor, - const torch::lazy::BackendDevice& device); + static absl::StatusOr Create( + const at::Tensor& tensor, const torch::lazy::BackendDevice& device); static XLATensorPtr Create( torch::lazy::BackendDataPtr handle, std::optional logical_element_type = std::nullopt); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 5916376c1061..b86e4bc23ad7 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1336,8 +1336,9 @@ std::tuple cummax(const XLATensorPtr& input, at::Tensor val = at::empty(shape_, at::TensorOptions().dtype(input->dtype())); at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong)); - return std::make_tuple(input->Create(val, input->GetDevice()), - input->Create(idx, input->GetDevice())); + return std::make_tuple( + GetValueOrThrow(XLATensor::Create(val, input->GetDevice())), + GetValueOrThrow(XLATensor::Create(idx, input->GetDevice()))); } torch::lazy::NodePtr node = torch_xla::MakeNode(input->GetIrValue(), canonical_dim); From 38e0f03796d3c251e9e9f83a5b7f6e9d67e67a60 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 11 Aug 2025 13:08:04 -0300 Subject: [PATCH 045/133] `cat`: improve error handling and error messages. (#9548) --- test/test_operations.py | 13 +++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 7 ++++--- torch_xla/csrc/tensor_methods.cpp | 23 ++++++++++++++++------- torch_xla/csrc/tensor_methods.h | 5 +++-- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 68aa0b6c2c82..9db0364e7fe8 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2473,6 +2473,19 @@ def test_construct_large_tensor_raises_error(self): # OOM is raised when we try to bring data from the device. b.cpu() + def test_cat_raises_error_on_incompatible_shapes(self): + a = torch.rand(2, 2, device=torch_xla.device()) + b = torch.rand(5, 1, device=torch_xla.device()) + + try: + torch.cat([a, b]) + except RuntimeError as e: + expected_error = ( + "cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] " + "at dimension 0. Expected shapes to be equal (except at dimension 0) " + "or that either of them was a 1D empty tensor of size (0,).") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 64354c893a13..d4d9120fa514 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1314,9 +1314,10 @@ at::Tensor XLANativeFunctions::bmm(const at::Tensor& self, at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::cat(GetValueOrThrow(bridge::GetXlaTensors(tensors)), dim, - at::native::result_type(tensors))); + auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors)); + auto output = GetValueOrThrow( + tensor_methods::cat(xtensors, dim, at::native::result_type(tensors))); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::celu(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index b86e4bc23ad7..de91e1d1cda4 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -9,6 +9,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "torch_xla/csrc/LazyIr.h" @@ -1160,8 +1161,8 @@ std::vector broadcast_tensors( return tensors.front()->MakeOutputTensors(node); } -XLATensorPtr cat(absl::Span tensors, int64_t dim, - at::ScalarType dtype) { +absl::StatusOr cat( + absl::Span tensors, int64_t dim, at::ScalarType dtype) { // Shape checks for cat: // - If not empty, every tensor shape must be the same. // - Empty tensor passes but is simply ignore in implementation, @@ -1169,9 +1170,10 @@ XLATensorPtr cat(absl::Span tensors, int64_t dim, // - If empty dimension, other dimensions must be the same. // e.g. ([4, 0, 32, 32], [4, 2, 32, 32], dim=1) passes. // ([4, 0, 32, 32], [4, 2, 31, 32], dim=1) throws. - XLA_CHECK_GT(tensors.size(), 0); + ABSL_CHECK(tensors.size() > 0); std::vector values; std::vector shapes; + size_t last_tensor_index; for (size_t i = 0; i < tensors.size(); ++i) { xla::Shape tensor_shape = tensors[i]->shape(); if (tensor_shape.dimensions_size() == 1 && @@ -1181,13 +1183,20 @@ XLATensorPtr cat(absl::Span tensors, int64_t dim, dim = torch::lazy::GetCanonicalDimensionIndex( dim, tensor_shape.dimensions_size()); tensor_shape.DeleteDimension(dim); - if (!shapes.empty()) { - XLA_CHECK(xla::ShapeUtil::CompatibleIgnoringElementType(shapes.back(), - tensor_shape)) - << shapes.back() << " vs. " << tensor_shape; + if (!shapes.empty() && !xla::ShapeUtil::CompatibleIgnoringElementType( + shapes.back(), tensor_shape)) { + auto last_tensor = tensors[last_tensor_index]; + auto tensor = tensors[i]; + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "cat(): cannot concatenate tensors of shape ", + last_tensor->shape().get().ToString(), " with ", + tensor->shape().get().ToString(), " at dimension ", dim, + ". Expected shapes to be equal (except at dimension ", dim, + ") or that either of them was a 1D empty tensor of size (0,)."))); } shapes.push_back(tensor_shape); values.push_back(tensors[i]->GetIrValue()); + last_tensor_index = i; } if (values.empty()) { return tensors[0]; diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 79f6acd8049d..4f771acb77aa 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -1,6 +1,7 @@ #ifndef XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_ #define XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_ +#include "absl/base/nullability.h" #include "torch_xla/csrc/cross_replica_reduces.h" #include "torch_xla/csrc/ops/custom_sharding.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -307,8 +308,8 @@ XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2); std::vector broadcast_tensors( absl::Span tensors); -XLATensorPtr cat(absl::Span tensors, int64_t dim, - at::ScalarType dtype); +absl::StatusOr cat( + absl::Span tensors, int64_t dim, at::ScalarType dtype); XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2, double p); From 23158fd559fb834977237bfaa8f2d270f86d28ee Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 11 Aug 2025 14:13:07 -0300 Subject: [PATCH 046/133] `div`: improve error handling and error messages. (#9549) --- test/test_operations.py | 11 +++++++++++ torch_xla/csrc/aten_xla_type.cpp | 3 ++- torch_xla/csrc/tensor_methods.cpp | 13 ++++++++----- torch_xla/csrc/tensor_methods.h | 2 +- torch_xla/csrc/tensor_ops.cpp | 20 ++++++++++---------- 5 files changed, 32 insertions(+), 17 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 9db0364e7fe8..4c0395ff2867 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2486,6 +2486,17 @@ def test_cat_raises_error_on_incompatible_shapes(self): "or that either of them was a 1D empty tensor of size (0,).") self.assertEqual(str(e), expected_error) + def test_div_raises_error_on_invalid_rounding_mode(self): + a = torch.rand(2, 2, device=torch_xla.device()) + + try: + torch.div(a, 2, rounding_mode="bad") + except RuntimeError as e: + expected_error = ( + "div(): invalid rounding mode `bad`. Expected it to be either " + "'trunc', 'floor', or be left unspecified.") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d4d9120fa514..4d5286c2b042 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1536,8 +1536,9 @@ at::Tensor XLANativeFunctions::div( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); at::ScalarType dtype = at::result_type(self, other); auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype)); - return bridge::AtenFromXlaTensor(tensor_methods::div( + auto output = GetValueOrThrow(tensor_methods::div( operands.first, operands.second, rounding_mode, dtype)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::div(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index de91e1d1cda4..ffeff7bab88e 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1421,9 +1421,10 @@ XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1, input->GetIrValue(), offset, canonical_dim1, canonical_dim2)); } -XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other, - const std::optional& rounding_mode, - std::optional logical_element_type) { +absl::StatusOr div( + const XLATensorPtr& input, const XLATensorPtr& other, + const std::optional& rounding_mode, + std::optional logical_element_type) { at::ScalarType scalar_type = at::typeMetaToScalarType(c10::get_default_dtype()); xla::PrimitiveType input_type = input->shape().get().element_type(); @@ -1446,8 +1447,10 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other, } else if (*rounding_mode == "floor") { res = torch_xla::MakeNode(res); } else { - XLA_CHECK(false) - << "rounding_mode must be one of None, 'trunc', or 'floor'"; + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("div(): invalid rounding mode `", *rounding_mode, + "`. Expected it to be either 'trunc', 'floor', or be " + "left unspecified."))); } } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 4f771acb77aa..395768fc867e 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -389,7 +389,7 @@ XLATensorPtr diag(const XLATensorPtr& input, int64_t offset); XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1, int64_t dim2); -XLATensorPtr div( +absl::StatusOr div( const XLATensorPtr& input, const XLATensorPtr& other, const std::optional& rounding_mode = std::nullopt, std::optional logical_element_type = std::nullopt); diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index ad8d58e13727..2b925d7c381a 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -148,9 +148,9 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output, XLATensorPtr grad_scale = tensor_methods::get_dimensions_size( broadcasted_input, XlaHelpers::GetAllDimensions(broadcasted_input->shape())); - return tensor_methods::mul( - tensor_methods::div(elementwise_loss_backward, grad_scale), - grad_output); + XLATensorPtr div_result = GetValueOrThrow( + tensor_methods::div(elementwise_loss_backward, grad_scale)); + return tensor_methods::mul(div_result, grad_output); } default: XLA_ERROR() << "Invalid reduction type: " @@ -174,12 +174,12 @@ XLATensorPtr SoftplusBackward(const XLATensorPtr& grad_output, XLATensorPtr z = tensor_methods::exp(scaled_input); XLATensorPtr one_vec = tensor_methods::full_like(z, 1, z->GetDevice(), z->dtype()); + XLATensorPtr div = GetValueOrThrow( + tensor_methods::div(z, tensor_methods::add(z, one_vec, 1))); - return tensor_methods::where( - tensor_methods::gt(scaled_input, threshold), grad_output, - tensor_methods::mul( - grad_output, - tensor_methods::div(z, tensor_methods::add(z, one_vec, 1)))); + return tensor_methods::where(tensor_methods::gt(scaled_input, threshold), + grad_output, + tensor_methods::mul(grad_output, div)); } XLATensorPtr Select(const XLATensorPtr& input, int64_t dim, int64_t index) { @@ -223,8 +223,8 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output, XLATensorPtr grad_weights_scale = tensor_methods::index(counts, {indices_rank1}, 0); // Scale the value of the gradient by the histogram. - grad = tensor_methods::div( - grad, tensor_methods::unsqueeze(grad_weights_scale, 1)); + grad = GetValueOrThrow(tensor_methods::div( + grad, tensor_methods::unsqueeze(grad_weights_scale, 1))); } // Don't accumulate gradients for indices which are equal with the given // padding_idx. From 1f787f1114a6e9247be1dde8d57372ce9dafd5d6 Mon Sep 17 00:00:00 2001 From: qihqi Date: Mon, 11 Aug 2025 16:25:50 -0700 Subject: [PATCH 047/133] Bug fixes (#9554) --- torchax/test/test_mutations.py | 34 +++++++++++++++++++++++++--------- torchax/torchax/ops/jaten.py | 3 ++- torchax/torchax/tensor.py | 5 ++--- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/torchax/test/test_mutations.py b/torchax/test/test_mutations.py index ab23623a7cf3..ccbc359485c8 100644 --- a/torchax/test/test_mutations.py +++ b/torchax/test/test_mutations.py @@ -8,28 +8,44 @@ class TestMutations(TestCase): def setUp(self): self.env = torchax.tensor.Environment() + self.env.config.debug_print_each_op = True def test_add(self): with self.env: - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32) + y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32) x.add_(y) - self.assertEqual(x, torch.tensor([5, 7, 9], dtype=torch.int32)) + torch.testing.assert_close(x.cpu(), + torch.tensor([5, 7, 9], dtype=torch.int32)) def test_sub(self): with self.env: - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32) + y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32) x.sub_(y) - self.assertEqual(x, torch.tensor([-3, -3, -3], dtype=torch.int32)) + torch.testing.assert_close(x.cpu(), + torch.tensor([-3, -3, -3], dtype=torch.int32)) def test_mul(self): with self.env: - x = torch.tensor([1, 2, 3], dtype=torch.int32) - y = torch.tensor([4, 5, 6], dtype=torch.int32) + x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32) + y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32) x.mul_(y) - self.assertEqual(x, torch.tensor([4, 10, 18], dtype=torch.int32)) + torch.testing.assert_close(x.cpu(), + torch.tensor([4, 10, 18], dtype=torch.int32)) + + def test_index_copy(self): + with self.env: + x = torch.zeros(5, 3, device='jax') + t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + device='jax', + dtype=torch.float) + index = torch.tensor([0, 4, 2], device='jax') + x.index_copy_(0, index, t) + expected = torch.tensor([[1., 2., 3.], [0., 0., 0.], [7., 8., 9.], + [0., 0., 0.], [4., 5., 6.]]) + torch.testing.assert_close(x.cpu(), expected) if __name__ == '__main__': diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 711b4bbe8b06..851a2d6103ef 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -736,7 +736,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): return jnp.empty(sizes, dtype=dtype) -@op(torch.ops.aten.index_put_) @op(torch.ops.aten.index_put) def _aten_index_put(self, indexes, values, accumulate=False): indexes = [slice(None, None, None) if i is None else i for i in indexes] @@ -5618,6 +5617,8 @@ def _aten__assert_tensor_metadata(*args, **kwargs): op_base.InplaceOp(torch.ops.aten.floor_divide), torch.ops.aten.remainder_: op_base.InplaceOp(torch.ops.aten.remainder), + torch.ops.aten.index_put_: + op_base.InplaceOp(torch.ops.aten.index_put), } # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`. diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 3916fe6501b8..67bc074177ef 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -70,9 +70,6 @@ def __str__(self): __repr__ = __str__ - def __jax_array__(self): - return self._elem - @property def shape(self): return torch.Size(self._elem.shape) @@ -494,6 +491,8 @@ def _handle_tensor_constructor(self, func, args, kwargs): op = self._get_op_or_decomp(func) if op.needs_env: kwargs['env'] = self + if op.is_jax_function: + (args, kwargs) = self.t2j_iso((args, kwargs)) res = op.func(*args, **kwargs) if isinstance(res, jax.Array): res = Tensor(res, self, requires_grad) From f400690edc383ccc6e71e9a0c1aedbc85c9c1c7d Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Tue, 12 Aug 2025 11:46:52 -0700 Subject: [PATCH 048/133] Run torchprime CI only when the pull requests have torchprimeci label (#9551) --- .github/workflows/build_and_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index f49b856b565e..b990f43f6971 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -83,7 +83,7 @@ jobs: with: timeout-minutes: 100 has_code_changes: ${{ needs.check_code_changes.outputs.has_code_changes }} - if: github.event_name == 'push' || github.event_name == 'pull_request' + if: contains(github.event.pull_request.labels.*.name, 'torchprimeci') secrets: inherit push-docs: From c8c9776b1855f7435428244ab0192f23b097ff18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Wed, 13 Aug 2025 18:04:20 -0400 Subject: [PATCH 049/133] [Documentation] Fixed typo in C++ debugging docs (#9559) --- docs/source/contribute/cpp_debugger.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/contribute/cpp_debugger.md b/docs/source/contribute/cpp_debugger.md index 855aeecda277..12d82a61d985 100644 --- a/docs/source/contribute/cpp_debugger.md +++ b/docs/source/contribute/cpp_debugger.md @@ -36,7 +36,7 @@ Debugging options are described in [setup.py on GitHub](https://github.com/pytor When defined, the `DEBUG` will cause the build process to generate debug symbols for all source files. It will also prevent the compiler from performing any -optimizations, which will cause the geenrated binary to run too slow to perform +optimizations, which will cause the generated binary to run too slow to perform meaningful work. We recommend using the `USE_CUSTOM_DEBINFO` environment variable to specify a semicolon separated list of source files for which you want to generate debug symbols. This lets you generate debug symbols for only the source From 40f58a6a78b220ffd769f27cf67255361a5fa483 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Wed, 13 Aug 2025 16:48:23 -0700 Subject: [PATCH 050/133] Update README.md to mention 2.8 release (#9560) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 84bac8db9743..8e31f4800c2a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ Note: Builds are available for Python 3.8 to 3.11; please use one of the support # - for conda # conda create -n py311 python=3.11 -pip install torch==2.7.0 'torch_xla[tpu]==2.7.0' +pip install torch==2.8.0 'torch_xla[tpu]==2.8.0' # Optional: if you're using custom kernels, install pallas dependencies pip install 'torch_xla[pallas]' From d5b9a6d9116580830e38ca42fbb4a1d21e1d5ebb Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 14 Aug 2025 10:19:01 -0300 Subject: [PATCH 051/133] `flip`: improve error handling and error messages. (#9550) --- test/test_operations.py | 14 ++++++ torch_xla/csrc/aten_xla_type.cpp | 6 ++- torch_xla/csrc/tensor_methods.cpp | 75 ++++++++++++++++++++++++++++--- torch_xla/csrc/tensor_methods.h | 3 +- 4 files changed, 89 insertions(+), 9 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 4c0395ff2867..cb790a074148 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2497,6 +2497,20 @@ def test_div_raises_error_on_invalid_rounding_mode(self): "'trunc', 'floor', or be left unspecified.") self.assertEqual(str(e), expected_error) + def test_flip_raises_error_on_duplicated_dims(self): + a = torch.rand(2, 2, 2, 2, device=torch_xla.device()) + dims = [0, 0, 0, 1, 2, 3, -1] + dims_suggestion = [0, 1, 2, 3] + + try: + torch.flip(a, dims=dims) + except RuntimeError as e: + expected_error = ( + "flip(): expected each dimension to appear at most once. Found " + "dimensions: 0 (3 times), 3 (2 times). Consider changing dims " + f"from {dims} to {dims_suggestion}.") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 4d5286c2b042..7d8ba3520592 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1804,8 +1804,10 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, at::Tensor XLANativeFunctions::flip(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::flip( - GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims))); + auto xself = GetValueOrThrow(bridge::GetXlaTensor(self)); + auto output = + GetValueOrThrow(tensor_methods::flip(xself, XlaHelpers::I64List(dims))); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index ffeff7bab88e..4a749d50ac79 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -8,6 +8,7 @@ #include #include +#include #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" @@ -345,6 +346,69 @@ XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input, return XLATensor::Create(node, input->GetDevice(), at::ScalarType::Bool); } +// Checks that the canonical dimensions out of the given dimensions are unique +// for the `flip` operation. +// +// This function fails if any canonical dimension appears more than once. +// Notice that its error message is specialized for the `flip` operation. +// +// @param rank Input rank +// @param dims (Error Message) `flip` operation original `dims` argument +// @param canonical_dims (Error Message) Canonical dimensions extracted from +// the `dims` argument +absl::Status CheckFlipDimensionsAreUnique( + int64_t rank, absl::Span dims, + absl::Span canonical_dims) { + // Counter that maps each given dimension to the number of times it has + // appeared. + std::vector count(rank, 0); + + // Count the number of times each dimension appears. + for (auto dim : canonical_dims) { + count[dim] += 1; + } + + bool any_dimension_appears_more_than_once = std::any_of( + count.begin(), count.end(), [](const auto n) { return n > 1; }); + + if (any_dimension_appears_more_than_once) { + // Suggestion for the value of dims that wouldn't raise an error. + std::vector dims_suggestion; + // Each "bad" dimension is represented as a string of the form: + // + // ( times) + // + // To be later joined with commas. + std::vector bad_count_str; + + // Iterates each dimension, populating both `dims_suggestion` and + // `bad_count_str`. + for (int64_t i : c10::irange(rank)) { + // Dimension does not appear. Do nothing. + if (count[i] == 0) { + continue; + } + + // Dimension appears in `dims`. Add it to the suggestion list. + dims_suggestion.push_back(i); + + // Dimension appears more than once. Add it to the "bad" list. + if (count[i] > 1) { + bad_count_str.push_back(absl::StrCat(i, " (", count[i], " times)")); + } + } + + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "flip(): expected each dimension to appear at most once. Found " + "dimensions: ", + absl::StrJoin(bad_count_str, /* sep= */ ", "), + ". Consider changing dims from [", absl::StrJoin(dims, /* sep= */ ", "), + "] to [", absl::StrJoin(dims_suggestion, /* sep= */ ", "), "]."))); + } + + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1680,12 +1744,11 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) { input->SetInPlaceIrValue(std::move(constant)); } -XLATensorPtr flip(const XLATensorPtr& input, absl::Span dims) { - auto dimensions = torch::lazy::GetCanonicalDimensionIndices( - torch_xla::runtime::util::ToVector(dims), - input->shape().get().dimensions_size()); - std::set unique_dims(dimensions.begin(), dimensions.end()); - XLA_CHECK_EQ(unique_dims.size(), dimensions.size()); +absl::StatusOr flip(const XLATensorPtr& input, + absl::Span dims) { + auto rank = input->shape().get().dimensions_size(); + auto dimensions = torch::lazy::GetCanonicalDimensionIndices(dims, rank); + XLA_RETURN_IF_ERROR(CheckFlipDimensionsAreUnique(rank, dims, dimensions)); return input->CreateFrom( torch_xla::MakeNode(input->GetIrValue(), dimensions)); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 395768fc867e..fb7eae93f8db 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -450,7 +450,8 @@ void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols); void fill_(XLATensorPtr& input, const at::Scalar& value); // Flips (reverses) the values in the dimensions of the input tensor. -XLATensorPtr flip(const XLATensorPtr& input, absl::Span dims); +absl::StatusOr flip(const XLATensorPtr& input, + absl::Span dims); XLATensorPtr fmod( const XLATensorPtr& input, const XLATensorPtr& other, From 2c343188f33c71939a3e20cd504746ef70244211 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 14 Aug 2025 15:47:48 -0300 Subject: [PATCH 052/133] Generalize crash message for non-ok status. (#9552) --- torch_xla/csrc/init_python_bindings.cpp | 14 +++---------- torch_xla/csrc/runtime/debug_macros.h | 1 - torch_xla/csrc/status.cpp | 26 +++++++++++++++++++++++++ torch_xla/csrc/status.h | 26 +++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 12 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a2b799a5f0e7..840bd555591d 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -3377,17 +3377,9 @@ void InitXlaModuleBindings(py::module m) { [](const std::vector& tensors) -> py::bytes { absl::StatusOr> xtensors_status = bridge::GetXlaTensors(tensors); - ABSL_CHECK(xtensors_status.ok()) - << "\n\n" - << "Internal Error:\n" - << " _get_graph_hash(): error retrieving the XLA tensors " - "from the given tensor arguments. " - << "This is a bug! Please, open an issue in the PyTorch/XLA " - << "GitHub repository: https://github.com/pytorch/xla" - << "\n\n" - << "Status Error:\n" - << " " << BuildStatusErrorMessage(xtensors_status.status()) - << "\n"; + XLA_CHECK_OK(xtensors_status, + "_get_graph_hash(): error retrieving the XLA tensors " + "from the given tensor arguments."); std::vector xtensors = xtensors_status.value(); torch::lazy::hash_t hash = diff --git a/torch_xla/csrc/runtime/debug_macros.h b/torch_xla/csrc/runtime/debug_macros.h index 6ebdafea6fb2..19e958a792ee 100644 --- a/torch_xla/csrc/runtime/debug_macros.h +++ b/torch_xla/csrc/runtime/debug_macros.h @@ -28,7 +28,6 @@ // unnecessary or undesirable. #define XLA_ERROR() TF_ERROR_STREAM() #define XLA_CHECK(c) TF_CHECK(c) -#define XLA_CHECK_OK(c) TF_CHECK_OK(c) #define XLA_CHECK_EQ(a, b) TF_CHECK_EQ(a, b) #define XLA_CHECK_NE(a, b) TF_CHECK_NE(a, b) #define XLA_CHECK_LE(a, b) TF_CHECK_LE(a, b) diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 2e1c7002e897..6f70ecf20e15 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -124,4 +124,30 @@ void MaybeThrow(const absl::Status& status) { void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); } +void OkOrDie(const absl::Status& status, const char* file, const int32_t line, + const char* function, std::string_view message) { + if (status.ok()) { + return; + } + + std::ostringstream oss; + oss << "\n\n" + << "Internal Error:\n"; + + if (!message.empty()) { + oss << " " << message << "\n"; + } + + oss << " This is a bug! Please, open an issue in the PyTorch/XLA " + << "GitHub repository: https://github.com/pytorch/xla" + << "\n\n" + << "Status Error:\n" + << " " + << BuildStatusErrorMessage( + status_internal::MaybeWithNewMessage(status, file, line, function)) + << "\n"; + + ABSL_CHECK(status.ok()) << oss.str(); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index b2d508076a3b..2dcde0ec7bd0 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -10,6 +10,8 @@ #ifndef XLA_TORCH_XLA_CSRC_STATUS_H_ #define XLA_TORCH_XLA_CSRC_STATUS_H_ +#include + #include "absl/status/statusor.h" namespace torch_xla { @@ -125,6 +127,22 @@ constexpr char kStatusPropagationTraceKey[] = lhs = std::move(XLA_STATUS_VAR_).value(), \ ##__VA_ARGS__) +// Crashes if `status` is not an ok status. +// +// Example: +// +// XLA_CHECK_OK( +// FnThatReturnStatus(), +// "New error message" +// ); +// +// If `FnThatReturnStatus()` returns a non-ok status, this macro will +// call `ABSL_CHECK()`, which will crash. +// +#define XLA_CHECK_OK(status, ...) \ + ::torch_xla::OkOrDie(::torch_xla::status_internal::GetStatus(status), \ + __FILE__, __LINE__, __FUNCTION__, ##__VA_ARGS__) + namespace status_internal { // Adds source location information to the status propagation trace if @@ -211,6 +229,14 @@ T GetValueOrThrow(absl::StatusOr&& status) { // `GetValueOrThrow` overload for `Status`. void GetValueOrThrow(const absl::Status& status); +// Checks that `status` is an ok status. +// +// Otherwise, it will create a new status instance with the given source +// location information, and incorporate its message (alongside the +// status propagation trace) to the crash report. +void OkOrDie(const absl::Status& status, const char* file, const int32_t line, + const char* function, std::string_view message = ""); + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_STATUS_H_ From 4199865482d4055ba6418e20b89dd1a25fc7f1d3 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 16 Aug 2025 10:09:53 -0300 Subject: [PATCH 053/133] Rename `MaybeThrow` to `OkOrThrow`. (#9561) --- test/cpp/test_status_common.h | 14 +++++------ torch_xla/csrc/aten_xla_type.cpp | 4 +-- torch_xla/csrc/dl_convertor.cpp | 2 +- torch_xla/csrc/init_python_bindings.cpp | 2 +- torch_xla/csrc/runtime/tensor_source.h | 2 +- torch_xla/csrc/status.cpp | 9 ++++--- torch_xla/csrc/status.h | 33 +++++++++++++------------ 7 files changed, 34 insertions(+), 32 deletions(-) diff --git a/test/cpp/test_status_common.h b/test/cpp/test_status_common.h index 5cf8285f5ebf..cb917942ffe3 100644 --- a/test/cpp/test_status_common.h +++ b/test/cpp/test_status_common.h @@ -81,7 +81,7 @@ namespace cpp_test { // Prefix of the C++ stacktrace PyTorch adds to the error message. constexpr inline char kTorchCppStacktracePrefix[] = - "Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:"; + "Exception raised from OkOrThrow at torch_xla/csrc/status.cpp:"; constexpr inline char kNewMessage[] = "New test error message"; constexpr inline char kMessage[] = "Test error message"; @@ -100,15 +100,15 @@ inline std::string GetStatusPropagationTrace(const absl::Status& status) { : ""; } -TEST_P(StatusTest, MaybeThrowWithOkStatus) { +TEST_P(StatusTest, OkOrThrowWithOkStatus) { absl::Status ok_status = absl::OkStatus(); - EXPECT_NO_THROW(MaybeThrow(ok_status)); + EXPECT_NO_THROW(OkOrThrow(ok_status)); } -TEST_P(StatusTest, MaybeThrowWithErrorStatus) { +TEST_P(StatusTest, OkOrThrowWithErrorStatus) { try { absl::Status error_status = absl::InvalidArgumentError(kMessage); - MaybeThrow(error_status); + OkOrThrow(error_status); } catch (const c10::Error& error) { if (IsShowCppStacktracesMode()) { EXPECT_THAT(std::string_view(error.what()), @@ -343,7 +343,7 @@ TEST_P(StatusTest, MacroErrorWithLocation) { } } -TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) { +TEST_P(StatusTest, OkOrThrowWithErrorPropagationWithNewMessage) { int32_t errline0 = __LINE__ + 2; auto innerfn = [&]() -> absl::Status { return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); @@ -362,7 +362,7 @@ TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) { }; try { - MaybeThrow(outerfn()); + OkOrThrow(outerfn()); } catch (const c10::Error& error) { if (IsShowCppStacktracesMode()) { // Expected Error Message Prefix diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 7d8ba3520592..005e0e98dcc7 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -664,7 +664,7 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, } else { auto dst_tensor = std::move(dst_tensor_status).value(); tensor_methods::copy_(dst_tensor, self_tensor_status.value()); - MaybeThrow(bridge::ReplaceXlaTensor(dst, dst_tensor)); + OkOrThrow(bridge::ReplaceXlaTensor(dst, dst_tensor)); } return dst; } @@ -3438,7 +3438,7 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr source_tensor = GetValueOrThrow(bridge::GetXlaTensor(source)); - MaybeThrow(bridge::ReplaceXlaTensor(self, source_tensor)); + OkOrThrow(bridge::ReplaceXlaTensor(self, source_tensor)); return self; } diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index 9adb63747dcb..638bcdbff67b 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -144,7 +144,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { pack->external_reference = GetValueOrThrow(pjrt_buffer->AcquireExternalReference()); xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture(); - MaybeThrow(future.Await()); + OkOrThrow(future.Await()); } pack->buffer_reference = pjrt_buffer; diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 840bd555591d..a2099f7d4ec1 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -429,7 +429,7 @@ void AllReduceInPlace(const std::string& reduce_type, replica_groups, pin_layout); std::vector new_xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors)); - MaybeThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors)); + OkOrThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors)); } at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index 280bc4f83484..cc8e646eee75 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -31,7 +31,7 @@ class TensorSource { virtual std::vector byte_strides() const { std::vector byte_strides(shape().dimensions_size()); - MaybeThrow( + OkOrThrow( xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); return byte_strides; } diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 6f70ecf20e15..dc9892d7d572 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -117,15 +117,16 @@ static std::string LineBreakIfCppStacktracesEnabled() { return torch::get_cpp_stacktraces_enabled() ? "\n" : ""; } -void MaybeThrow(const absl::Status& status) { +void OkOrThrow(const absl::Status& status) { TORCH_CHECK(status.ok(), absl::StrCat(BuildStatusErrorMessage(status), LineBreakIfCppStacktracesEnabled())); } -void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); } +void GetValueOrThrow(const absl::Status& status) { OkOrThrow(status); } -void OkOrDie(const absl::Status& status, const char* file, const int32_t line, - const char* function, std::string_view message) { +void status_internal::OkOrDie(const absl::Status& status, const char* file, + const int32_t line, const char* function, + std::string_view message) { if (status.ok()) { return; } diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index 2dcde0ec7bd0..28f16860479a 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -139,9 +139,10 @@ constexpr char kStatusPropagationTraceKey[] = // If `FnThatReturnStatus()` returns a non-ok status, this macro will // call `ABSL_CHECK()`, which will crash. // -#define XLA_CHECK_OK(status, ...) \ - ::torch_xla::OkOrDie(::torch_xla::status_internal::GetStatus(status), \ - __FILE__, __LINE__, __FUNCTION__, ##__VA_ARGS__) +#define XLA_CHECK_OK(status, ...) \ + ::torch_xla::status_internal::OkOrDie( \ + ::torch_xla::status_internal::GetStatus(status), __FILE__, __LINE__, \ + __FUNCTION__, ##__VA_ARGS__) namespace status_internal { @@ -190,6 +191,14 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, int32_t line, const char* function, std::string_view new_message = ""); +// Checks that `status` is an ok status. +// +// Otherwise, it will create a new status instance with the given source +// location information, and incorporate its message (alongside the +// status propagation trace) to the crash report. +void OkOrDie(const absl::Status& status, const char* file, const int32_t line, + const char* function, std::string_view message = ""); + } // namespace status_internal // Builds the complete error message for the given `status`. @@ -200,43 +209,35 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, // It doesn't add a trailing line break. std::string BuildStatusErrorMessage(const absl::Status& status); -// Maybe throws an exception if `status` has a non-ok code. +// Throws an exception if `status` has a non-ok code. // // Ideally, this function should be used only used in the project's // boundary, e.g. when we need to throw an exception for the user to see. -void MaybeThrow(const absl::Status& status); +void OkOrThrow(const absl::Status& status); // Either returns the value `status` holds, if it's an ok-status, or throw an // exception from its error status. template T& GetValueOrThrow(absl::StatusOr& status) { - MaybeThrow(status.status()); + OkOrThrow(status.status()); return status.value(); } template const T& GetValueOrThrow(const absl::StatusOr& status) { - MaybeThrow(status.status()); + OkOrThrow(status.status()); return status.value(); } template T GetValueOrThrow(absl::StatusOr&& status) { - MaybeThrow(status.status()); + OkOrThrow(status.status()); return std::move(status).value(); } // `GetValueOrThrow` overload for `Status`. void GetValueOrThrow(const absl::Status& status); -// Checks that `status` is an ok status. -// -// Otherwise, it will create a new status instance with the given source -// location information, and incorporate its message (alongside the -// status propagation trace) to the crash report. -void OkOrDie(const absl::Status& status, const char* file, const int32_t line, - const char* function, std::string_view message = ""); - } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_STATUS_H_ From a1c6ee92c85e8b0955c20892ed68f032a6015c09 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 16 Aug 2025 16:59:35 -0700 Subject: [PATCH 054/133] Add xla random generator. (#9539) This is the very first PR for https://github.com/pytorch/xla/issues/9159. It purely add the generator without any utilization of it. https://github.com/pytorch/xla/issues/9159#issuecomment-2994942717 comment outlines the steps for entire change. --- .github/scripts/run_tests.sh | 1 + BUILD | 7 +- test/cpp/BUILD | 12 ++++ test/cpp/run_tests.sh | 1 + test/cpp/test_xla_generator.cpp | 106 +++++++++++++++++++++++++++++++ torch_xla/csrc/BUILD | 2 + torch_xla/csrc/xla_generator.cpp | 84 ++++++++++++++++++++++++ torch_xla/csrc/xla_generator.h | 56 ++++++++++++++++ 8 files changed, 266 insertions(+), 3 deletions(-) create mode 100644 test/cpp/test_xla_generator.cpp create mode 100644 torch_xla/csrc/xla_generator.cpp create mode 100644 torch_xla/csrc/xla_generator.h diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index d685cc40ee49..ccdc0b5e3d70 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -55,6 +55,7 @@ function run_torch_xla_cpp_tests() { "test_tensor" # disable test_xla_backend_intf since it is flaky on upstream #"test_xla_backend_intf" + "test_xla_generator" "test_xla_sharding" "test_runtime" "test_status_dont_show_cpp_stacktraces" diff --git a/BUILD b/BUILD index ee4fa07844ac..900dfa4bc3b2 100644 --- a/BUILD +++ b/BUILD @@ -72,15 +72,16 @@ test_suite( "//test/cpp:test_aten_xla_tensor_4", "//test/cpp:test_aten_xla_tensor_5", "//test/cpp:test_aten_xla_tensor_6", + "//test/cpp:test_debug_macros", "//test/cpp:test_ir", "//test/cpp:test_lazy", "//test/cpp:test_replication", - "//test/cpp:test_tensor", - "//test/cpp:test_xla_sharding", "//test/cpp:test_runtime", "//test/cpp:test_status_dont_show_cpp_stacktraces", "//test/cpp:test_status_show_cpp_stacktraces", - "//test/cpp:test_debug_macros", + "//test/cpp:test_tensor", + "//test/cpp:test_xla_generator", + "//test/cpp:test_xla_sharding", "//torch_xla/csrc/runtime:pjrt_computation_client_test", # "//torch_xla/csrc/runtime:ifrt_computation_client_test", ], diff --git a/test/cpp/BUILD b/test/cpp/BUILD index 483d5ef7c01e..00568e8573f5 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -204,3 +204,15 @@ ptxla_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ptxla_cc_test( + name = "test_xla_generator", + srcs = ["test_xla_generator.cpp"], + deps = [ + ":cpp_test_util", + ":torch_xla_test", + "//torch_xla/csrc:tensor", + "//torch_xla/csrc:aten_cuda_functions", + "@com_google_googletest//:gtest_main", + ], +) \ No newline at end of file diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 8c3fea6bcdc8..2da0ccb55699 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -100,6 +100,7 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then # disable test_xla_backend_intf since it is flaky on upstream #"test_xla_backend_intf" "test_xla_sharding" + "test_xla_generator" "test_runtime" "test_status_dont_show_cpp_stacktraces" "test_status_show_cpp_stacktraces" diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp new file mode 100644 index 000000000000..d45991f72d39 --- /dev/null +++ b/test/cpp/test_xla_generator.cpp @@ -0,0 +1,106 @@ +#include +#include + +#include "test/cpp/torch_xla_test.h" +#include "torch_xla/csrc/xla_generator.h" + +namespace torch_xla { +namespace cpp_test { + +// Test fixture for XLAGenerator tests +class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest { + protected: + void SetUp() { + // Create a generator for XLA device 0 + gen_ = at::make_generator(0); + } + + at::Generator gen_; +}; + +TEST_F(XLAGeneratorTest, Constructor) { + // Check that the generator was created for the correct device + ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA); + ASSERT_EQ(gen_.device().index(), 0); + + // Check that the initial seed is 0 + ASSERT_EQ(gen_.current_seed(), 0); +} + +TEST_F(XLAGeneratorTest, Seed) { + // Test setting and getting the current seed + uint64_t seed_val = 12345; + gen_.set_current_seed(seed_val); + ASSERT_EQ(gen_.current_seed(), seed_val); + + // Test the seed() method, which should set a non-deterministic seed + uint64_t old_seed = gen_.current_seed(); + uint64_t new_seed = gen_.seed(); + // The new seed should be different from the old one and set as the current + // seed + ASSERT_NE(new_seed, old_seed); + ASSERT_EQ(gen_.current_seed(), new_seed); +} + +TEST_F(XLAGeneratorTest, GetAndSetState) { + uint64_t seed_val = 98765; + uint64_t offset_val = 0; + + // Set seed and offset on the original generator + gen_.set_current_seed(seed_val); + gen_.set_offset(offset_val); + + // Get the state from the original generator + at::Tensor state_tensor = gen_.get_state(); + + // Create a new generator + auto new_gen = at::make_generator(1); + ASSERT_NE(new_gen.current_seed(), seed_val); + + // Set the state of the new generator + new_gen.set_state(state_tensor); + + // Verify the state of the new generator + ASSERT_EQ(new_gen.current_seed(), seed_val); + ASSERT_EQ(new_gen.get_offset(), offset_val); +} + +TEST_F(XLAGeneratorTest, SetStateValidation) { + // Test that set_state throws with incorrect tensor properties + auto new_gen = at::make_generator(0); + + // Incorrect size + auto wrong_size_tensor = at::empty({10}, at::kByte); + EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error); + + // Incorrect dtype + auto wrong_dtype_tensor = at::empty({16}, at::kInt); + EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error); +} + +TEST_F(XLAGeneratorTest, Clone) { + uint64_t seed_val = 1; + uint64_t offset_val = 0; + + // Set state on the original generator + gen_.set_current_seed(seed_val); + gen_.set_offset(offset_val); + + // Clone the generator + auto cloned_gen = gen_.clone(); + + // Verify that the cloned generator has the same state but is a different + // object + ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_)); + ASSERT_EQ(cloned_gen.device(), gen_.device()); + ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed()); + ASSERT_EQ(cloned_gen.get_offset(), offset_val); + + // Modify the original generator's seed and check that the clone is unaffected + gen_.set_current_seed(9999); + ASSERT_EQ(cloned_gen.current_seed(), seed_val); + ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); +} + +} // namespace cpp_test +} // namespace torch_xla \ No newline at end of file diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 8f2a1bdc67cb..f99dca0a74ef 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -64,6 +64,7 @@ ptxla_cc_library( "torch_util.cpp", "view.cpp", "xla_backend_impl.cpp", + "xla_generator.cpp", "xla_graph_executor.cpp", "xla_lower_util.cpp", "xla_op_builder.cpp", @@ -107,6 +108,7 @@ ptxla_cc_library( "torch_util.h", "view.h", "xla_backend_impl.h", + "xla_generator.h", "xla_graph_executor.h", "xla_lower_util.h", "xla_op_builder.h", diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp new file mode 100644 index 000000000000..5d0a7c15866b --- /dev/null +++ b/torch_xla/csrc/xla_generator.cpp @@ -0,0 +1,84 @@ +#include "xla_generator.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { + +XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index) + : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), + DispatchKeySet(c10::DispatchKey::XLA)} { + state_ = c10::make_intrusive(); +} + +XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index, + c10::intrusive_ptr state) + : c10::GeneratorImpl{Device(DeviceType::XLA, device_index), + DispatchKeySet(c10::DispatchKey::XLA)}, + state_(std::move(state)) {} + +DeviceType XLAGeneratorImpl::device_type() { return DeviceType::XLA; } + +std::shared_ptr XLAGeneratorImpl::clone() const { + return std::shared_ptr(clone_impl()); +} + +XLAGeneratorImpl* XLAGeneratorImpl::clone_impl() const { + return new XLAGeneratorImpl(device_.index(), state_->clone()); +} + +void XLAGeneratorImpl::set_current_seed(uint64_t seed) { state_->seed_ = seed; } + +uint64_t XLAGeneratorImpl::current_seed() const { return state_->seed_; } + +uint64_t XLAGeneratorImpl::seed() { + uint64_t random = c10::detail::getNonDeterministicRandom(true); + set_current_seed(random); + return random; +} + +void XLAGeneratorImpl::set_offset(uint64_t offset) { state_->offset_ = offset; } + +uint64_t XLAGeneratorImpl::get_offset() const { return state_->offset_; } + +/* Serialize the generator state into a CPU tensor. */ +c10::intrusive_ptr XLAGeneratorImpl::get_state() const { + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(uint64_t); + static const size_t total_size = seed_size + offset_size; + + auto state_tensor = + at::empty({(int64_t)total_size}, + at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + uint8_t* data_ptr = state_tensor.data_ptr(); + memcpy(data_ptr, &state_->seed_, seed_size); + memcpy(data_ptr + seed_size, &state_->offset_, offset_size); + return state_tensor.getIntrusivePtr(); +} + +void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + static const size_t seed_size = sizeof(uint64_t); + static const size_t offset_size = sizeof(uint64_t); + static const size_t total_size = seed_size + offset_size; + + TORCH_CHECK(new_state.numel() == total_size, + "The given state must be a byte tensor of size ", total_size, + ", but was size ", new_state.numel()); + TORCH_CHECK(new_state.dtype() == at::kByte, + "The given state must be a byte tensor, but was ", + new_state.dtype()); + TORCH_CHECK(new_state.is_cpu(), "The given state must be a CPU tensor"); + + auto new_rng_state = new_state.data_dtype_initialized(); + memcpy(&state_->seed_, new_rng_state, seed_size); + memcpy(&state_->offset_, new_rng_state + seed_size, offset_size); +} + +} // namespace at diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h new file mode 100644 index 000000000000..330d32861200 --- /dev/null +++ b/torch_xla/csrc/xla_generator.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +#include + +namespace at { + +// Holds the actual state variables for the XLA generator. +struct XLAGeneratorState : c10::intrusive_ptr_target { + uint64_t seed_ = 0; + uint64_t offset_ = 0; + + // Constructor + XLAGeneratorState(uint64_t seed = 0, uint64_t offset = 0) + : seed_(seed), offset_(offset) {} + + // Cloning method + c10::intrusive_ptr clone() { + return c10::make_intrusive(seed_, offset_); + } +}; + +struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { + // Constructors + XLAGeneratorImpl(DeviceIndex device_index = -1); + XLAGeneratorImpl(DeviceIndex device_index, + c10::intrusive_ptr state); + ~XLAGeneratorImpl() override = default; + + // Cloning support + std::shared_ptr clone() const; + + // --- Core Virtual Methods to Override --- + void set_current_seed(uint64_t seed) override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + c10::intrusive_ptr get_state() const override; + void set_state(const c10::TensorImpl& new_state) override; + + // --- Additional Methods --- + static c10::DeviceType device_type(); + + private: + // Private clone implementation + XLAGeneratorImpl* clone_impl() const override; + + // The actual state is held in a separate, cloneable object. + c10::intrusive_ptr state_; +}; + +} // namespace at \ No newline at end of file From 0f56dec9a33a993d4c14cb755bdd25490cabba21 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 19 Aug 2025 10:35:38 -0500 Subject: [PATCH 055/133] [EZ] Replace `pytorch-labs` with `meta-pytorch` (#9556) This PR replaces all instances of `pytorch-labs` with `meta-pytorch` in this repository now that the `pytorch-labs` org has been renamed to `meta-pytorch` ## Changes Made - Replaced all occurrences of `pytorch-labs` with `meta-pytorch` - Only modified files with extensions: .py, .md, .sh, .rst, .cpp, .h, .txt, .yml - Skipped binary files and files larger than 1MB due to GitHub api payload limits in the script to cover all repos in this org. Will do a more manual second pass later to cover any larger files ## Files Modified This PR updates files that contained the target text. Generated by automated script on 2025-08-12T20:59:10.495582+00:00Z --- torchax/test/llama/llama_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchax/test/llama/llama_model.py b/torchax/test/llama/llama_model.py index d6a323ae6b9c..2aa3566ae0b1 100644 --- a/torchax/test/llama/llama_model.py +++ b/torchax/test/llama/llama_model.py @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -# This file is copied from https://github.com/pytorch-labs/gpt-fast +# This file is copied from https://github.com/meta-pytorch/gpt-fast # This is used for unit test purposes from dataclasses import dataclass import math From b84c83b46615f767e6d94cda959db8178ddd95b5 Mon Sep 17 00:00:00 2001 From: Tarun Paparaju Date: Thu, 21 Aug 2025 13:35:51 -0700 Subject: [PATCH 056/133] Added missing "#"s for the comments in triton.md (#9571) Hello! As the title suggests, I have added some missing "#"s to make sure the Triton demo code runs without errors. --- docs/source/features/triton.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/features/triton.md b/docs/source/features/triton.md index 991583aab221..33bf1a4d5861 100644 --- a/docs/source/features/triton.md +++ b/docs/source/features/triton.md @@ -49,8 +49,8 @@ block_size = 8 grid = (triton.cdiv(size, block_size),) # triton_call takes the same arguments as the triton.jit function, in addition -to the kernel itself and the grid that is used to execute the kernel. -All the tl.constexpr terms are passed as kwargs at the end. +# to the kernel itself and the grid that is used to execute the kernel. +# All the tl.constexpr terms are passed as kwargs at the end. payload = xla_triton.triton_call( x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size) From 6b6ef5c7d757f955565b2083c48d936bfd758dcd Mon Sep 17 00:00:00 2001 From: qihqi Date: Fri, 22 Aug 2025 12:03:42 -0700 Subject: [PATCH 057/133] Remove tests that are defined outside of this repo. (#9577) Includes: * TPU info * tests in pytorch/pytorch With this PR, I'd like to stablish the convention of tests: we test stuff defined in this repo only. If we want to test an interation of torch_xla with a third_party library, we would need to define the test itself inside of this repo. Reason for disabling TPU info test: Historically tpu-info commandline tool's source is in this repo, so the test is in this repo. It's source was then moved to https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics because it's not a Pytorch/XLA specific tool. Now, given that it is released independently, we cannot control it. (see https://github.com/pytorch/xla/issues/9568) Reason for disabling pytorch/pytorch: The test themselves are defined in pytorch/pytorch instead of pytorch/xla repo. (the ../.. path goes to the parent folder which is where pytorch/pytorch is cloned). If we find a particular test helpful, we should copy that test into pytorch/xla Historically this list are a tiny subset of the full pytorch/pytorch tests, and is often commented out if failing. i.e. this is merely a list of tests that we found out that happen to pass, not a list that we want to enforce passing. --- test/run_tests.sh | 24 ------- test/tpu/run_training_tests.sh | 5 -- test/tpu/tpu_info/test_cli.py | 114 --------------------------------- 3 files changed, 143 deletions(-) delete mode 100644 test/tpu/tpu_info/test_cli.py diff --git a/test/run_tests.sh b/test/run_tests.sh index 54c893c7b405..85ae9d8691ce 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -135,25 +135,6 @@ function run_pt_xla_debug_level2 { PT_XLA_DEBUG_LEVEL=2 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" } -function run_torch_op_tests { - run_dynamic "$_TEST_DIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA - run_test_without_functionalization "$_TEST_DIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA - run_test "$_TEST_DIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA - # TODO https://github.com/pytorch/xla/issues/9459: Investigate why this - # doesn't run any tests. - # run_test "$_TEST_DIR/../../test/test_torch.py" "$@" -v TestTensorDeviceOpsXLA - run_test "$_TEST_DIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA - run_test "$_TEST_DIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA - # run_dynamic "$_TEST_DIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/nn/test_dropout.py" "$@" -v TestDropoutNNDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/nn/test_pooling.py" "$@" -v TestPoolingNNDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/nn/test_embedding.py" "$@" -v TestEmbeddingNNDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/nn/test_convolution.py" "$@" -v TestConvolutionNNDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/nn/test_multihead_attention.py" "$@" -v TestMultiheadAttentionNNDeviceTypeXLA - run_dynamic "$_TEST_DIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA -} - ####################################################################################### ################################# XLA OP TESTS SHARDS ################################# ####################################################################################### @@ -300,7 +281,6 @@ function run_xla_op_tests5 { ####################################################################################### function run_op_tests { - run_torch_op_tests run_xla_op_tests1 run_xla_op_tests2 run_xla_op_tests3 @@ -350,7 +330,6 @@ function run_tests { run_xla_op_tests5 elif [[ "$RUN_TORCH_MP_OP_TESTS" == "torch_mp_op" ]]; then echo "Running torch op tests..." - run_torch_op_tests PJRT_DEVICE=CPU XLA_CUDA=0 run_mp_op_tests else @@ -362,9 +341,6 @@ function run_tests { run_xla_op_tests4 run_xla_op_tests5 fi - if [[ "$XLA_SKIP_TORCH_OP_TESTS" != "1" ]]; then - run_torch_op_tests - fi if [[ "$XLA_SKIP_MP_OP_TESTS" != "1" ]]; then run_mp_op_tests fi diff --git a/test/tpu/run_training_tests.sh b/test/tpu/run_training_tests.sh index d87e03f34f6c..a938be996e52 100755 --- a/test/tpu/run_training_tests.sh +++ b/test/tpu/run_training_tests.sh @@ -32,8 +32,3 @@ if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then python3 "$_TEST_DIR/../examples/eager/train_decoder_only_eager_multi_process.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$_TEST_DIR/ds/test_dynamic_shapes.py" -v fi - -if [[ -n "$TPU_VERSION" && "$TPU_VERSION" != "6" ]]; then - # Test `tpu-info` CLI compatibility - python3 "$_TPU_DIR/tpu_info/test_cli.py" -fi diff --git a/test/tpu/tpu_info/test_cli.py b/test/tpu/tpu_info/test_cli.py deleted file mode 100644 index 3de3330d788f..000000000000 --- a/test/tpu/tpu_info/test_cli.py +++ /dev/null @@ -1,114 +0,0 @@ -import threading -from absl.testing import absltest, parameterized -import contextlib -import multiprocessing -import os -from typing import Dict, Optional -import torch_xla -import torch_xla.runtime as xr -import torch_xla.distributed.xla_multiprocessing as xmp -from tpu_info import cli, device, metrics - - -class TpuInfoCliTest(parameterized.TestCase): - - @classmethod - def setUpClass(cls): - xr.set_device_type("TPU") - chip_type, num_chips = device.get_local_chips() - assert chip_type is not None - cls.chip_type = chip_type - cls.num_chips = num_chips - - @staticmethod - def _init_tpu_and_wait( - # accept index as first arg to make xmp.spawn happy - index: int, - q: multiprocessing.Queue, - done: multiprocessing.Event, - env: Optional[Dict[str, str]] = None, - ): - if env: - os.environ.update(**env) - torch_xla.device() - q.put(os.getpid()) - done.wait() - - @contextlib.contextmanager - def _torch_xla_process(self, env: Dict[str, str]): - with multiprocessing.Manager() as m: - q = m.Queue() - done = m.Event() - p = multiprocessing.Process( - target=self._init_tpu_and_wait, args=(0, q, done, env)) - p.start() - pid = q.get(timeout=20.0) - with contextlib.ExitStack() as e: - e.callback(done.set) - yield pid - # Wait for process to exit before next test - p.join() - - @parameterized.named_parameters([ - ("all_chips", {}), - ("one_chip", { - "TPU_VISIBLE_CHIPS": "0", - "TPU_PROCESS_BOUNDS": "1,1,1", - "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1" - }), - ]) - def test_single_process_e2e(self, extra_env): - with self._torch_xla_process(extra_env) as subprocess_pid: - owners = device.get_chip_owners() - for _, pid in owners.items(): - self.assertEqual(pid, subprocess_pid) - usages = metrics.get_chip_usage(self.chip_type) - for u in usages: - self.assertGreater(u.total_memory, 0) - self.assertEqual(u.duty_cycle_pct, 0.0) - one_gb = 1 << 30 - self.assertLess(u.memory_usage, one_gb) - # TODO(https://github.com/pytorch/xla/issues/9462): Uncomment after - # libtpu is fixed for python 3.12 - # cli.print_chip_info() - - @contextlib.contextmanager - def _torch_xla_spawn(self): - with multiprocessing.Manager() as m: - q = m.Queue() - done = m.Event() - # HACK: run xmp.spawn in a thread because `join` arg is not implemented - t = threading.Thread( - target=xmp.spawn, - args=(self._init_tpu_and_wait,), - kwargs={'args': (q, done)}) - t.start() - - # v2 and v3 may have duplicates due to multithreading - child_pids = set() - for _ in range(self.chip_type.value.devices_per_chip * self.num_chips): - child_pids.add(q.get(timeout=20.0)) - with contextlib.ExitStack() as e: - e.callback(done.set) - yield child_pids - - t.join() - - def test_multiprocessing_e2e(self): - with self._torch_xla_spawn() as subprocess_pids: - owners = device.get_chip_owners() - self.assertSetEqual( - set(pid for _, pid in owners.items()), subprocess_pids) - usages = metrics.get_chip_usage(self.chip_type) - for u in usages: - self.assertGreater(u.total_memory, 0) - self.assertEqual(u.duty_cycle_pct, 0.0) - one_gb = 1 << 30 - self.assertLess(u.memory_usage, one_gb) - # TODO(https://github.com/pytorch/xla/issues/9462): Uncomment after - # libtpu is fixed for python 3.12 - # cli.print_chip_info() - - -if __name__ == "__main__": - absltest.main() From 748ac9b1032cea9499f8062a10607eceb4a84cb7 Mon Sep 17 00:00:00 2001 From: qihqi Date: Fri, 22 Aug 2025 13:53:36 -0700 Subject: [PATCH 058/133] Update XLA pin then fix up to make it compile (#9565) Updates: * Modify call to some OpenXLA C++ functions because their calling convention changed. * bazel is now more strict: explicit header and explicit deps are enforced * removed FORTIFY_SOURCE defines via a patch, it is making gcc segfault on compile * updated gcc version to 11 --- .circleci/common.sh | 3 +- .github/workflows/_test.yml | 2 +- .github/workflows/_tpu_ci.yml | 4 +- CONTRIBUTING.md | 2 +- README.md | 2 +- WORKSPACE | 25 ++++++----- bazel/rules_def.bzl | 8 +--- openxla_patches/count_down.diff | 14 ------- openxla_patches/no_fortify.diff | 40 ++++++++++++++++++ scripts/build_torch_wheels.sh | 8 ++-- scripts/update_deps.py | 4 +- setup.py | 42 +++++++------------ test/tpu/xla_test_job.yaml | 2 +- torch_xla/_internal/jax_workarounds.py | 2 + torch_xla/csrc/init_python_bindings.cpp | 2 +- .../csrc/runtime/ifrt_computation_client.cpp | 6 +-- torch_xla/csrc/runtime/profiler.cpp | 17 +++++++- torch_xla/csrc/runtime/xla_coordinator.cpp | 6 +-- torch_xla/csrc/status.cpp | 1 + 19 files changed, 110 insertions(+), 80 deletions(-) delete mode 100644 openxla_patches/count_down.diff create mode 100644 openxla_patches/no_fortify.diff diff --git a/.circleci/common.sh b/.circleci/common.sh index df145603db42..3093a8006942 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -112,7 +112,8 @@ function build_torch_xla() { # Need to uncomment the line below. # Currently it fails upstream XLA CI. # pip install plugins/cuda -v - pip install 'torch_xla[pallas]' + pip install --pre torch_xla[pallas] --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + popd } diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 413a5aef8322..4ef00dcedaed 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -140,7 +140,7 @@ jobs: set -x pip install expecttest unittest-xml-reporting - pip install 'torch_xla[pallas]' + pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then pip install -r pytorch/xla/benchmarks/requirements.txt diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index 82dd7c748c1c..dc766c53a897 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -52,8 +52,8 @@ jobs: pip install fsspec pip install rich # jax and libtpu is needed for pallas tests. - pip install 'torch_xla[pallas]' - pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html + pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html' + pip install --pre 'torch_xla[tpu]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html' pip install --upgrade protobuf - name: Run Tests (${{ matrix.test_script }}) if: inputs.has_code_changes == 'true' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d01f657c7609..a6eb0af8a54d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -162,7 +162,7 @@ commands on your Linux machine directly, outside of the container. -f https://storage.googleapis.com/libtpu-releases/index.html # Optional: if you're using custom kernels, install pallas dependencies - pip install torch_xla[pallas] + pip install --pre torch_xla[pallas] --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` 1. If you are running on a TPU VM, ensure `torch` and `torch_xla` were built and diff --git a/README.md b/README.md index 8e31f4800c2a..11d5a5712560 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Note: Builds are available for Python 3.8 to 3.11; please use one of the support pip install torch==2.8.0 'torch_xla[tpu]==2.8.0' # Optional: if you're using custom kernels, install pallas dependencies -pip install 'torch_xla[pallas]' +pip install --pre torch_xla[pallas] --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` **As of 07/16/2025 and starting from Pytorch/XLA 2.8 release, PyTorch/XLA will provide nightly and release wheels for Python 3.11 to 3.13** diff --git a/WORKSPACE b/WORKSPACE index f05f70023301..8222c5797bba 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -46,7 +46,7 @@ new_local_repository( # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to # the openxla git commit hash and note the date of the commit. -xla_hash = '3d5ece64321630dade7ff733ae1353fc3c83d9cc' # Committed on 2025-06-17. +xla_hash = '92f7b5952dd585c5be17c9a5caad27407005b513' # Committed on 2025-08-15. http_archive( name = "xla", @@ -58,7 +58,7 @@ http_archive( patches = [ "//openxla_patches:gpu_nvml.diff", "//openxla_patches:gpu_race_condition.diff", - "//openxla_patches:count_down.diff", + "//openxla_patches:no_fortify.diff", ], strip_prefix = "xla-" + xla_hash, urls = [ @@ -81,6 +81,19 @@ http_archive( # path = "/path/to/openxla", # ) +# Initialize OpenXLA's external dependencies. There is an specific order +# which those dependencies are initialized, because for bazel it's the +# first definition that takes precedence. +# We follow what openxla/xla does exactly: +# https://github.com/openxla/xla/blob/main/WORKSPACE#L37 +load("@xla//:workspace4.bzl", "xla_workspace4") + +xla_workspace4() + +load("@xla//:workspace3.bzl", "xla_workspace3") + +xla_workspace3() + # Initialize hermetic Python load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") @@ -115,14 +128,6 @@ install_deps() -# Initialize OpenXLA's external dependencies. -load("@xla//:workspace4.bzl", "xla_workspace4") - -xla_workspace4() - -load("@xla//:workspace3.bzl", "xla_workspace3") - -xla_workspace3() load("@xla//:workspace2.bzl", "xla_workspace2") diff --git a/bazel/rules_def.bzl b/bazel/rules_def.bzl index 3a089bb79405..a5053f06d395 100644 --- a/bazel/rules_def.bzl +++ b/bazel/rules_def.bzl @@ -1,10 +1,4 @@ """Rules that simplify deps and compiler configuration for PyTorch/XLA.""" - -load( - "@xla//xla:xla.default.bzl", - "xla_cc_test", -) - def ptxla_cc_library( deps = [], copts = [], @@ -22,7 +16,7 @@ def ptxla_cc_test( deps, copts = [], **kwargs): - xla_cc_test( + native.cc_test( linkstatic = True, copts = copts + [ "-isystemexternal/torch", # Required for system includes. diff --git a/openxla_patches/count_down.diff b/openxla_patches/count_down.diff deleted file mode 100644 index b46d3907752f..000000000000 --- a/openxla_patches/count_down.diff +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/xla/backends/cpu/runtime/convolution_thunk_internal.h b/xla/backends/cpu/runtime/convolution_thunk_internal.h -index 84fed6bb78..9835f12e4e 100644 ---- a/xla/backends/cpu/runtime/convolution_thunk_internal.h -+++ b/xla/backends/cpu/runtime/convolution_thunk_internal.h -@@ -342,7 +342,8 @@ void EigenGenericConv2D( - Eigen::Index start = task_index * task_size; - Eigen::Index end = std::min(start + task_size, feature_group_count); - for (Eigen::Index i = start; i < end; ++i) { -- auto on_done = [count_down]() mutable { count_down.CountDown(); }; -+ // auto on_done = [count_down]() mutable { count_down.CountDown(); }; -+ auto on_done = [count_down]() mutable { const_cast(&count_down)->CountDown(); }; - auto [output, convolved] = convolve_group(i); - output.device(device, std::move(on_done)) = convolved; - } diff --git a/openxla_patches/no_fortify.diff b/openxla_patches/no_fortify.diff new file mode 100644 index 000000000000..1fbd93af63f6 --- /dev/null +++ b/openxla_patches/no_fortify.diff @@ -0,0 +1,40 @@ +diff --git a/tools/toolchains/cross_compile/cc/BUILD b/tools/toolchains/cross_compile/cc/BUILD +index be6e8968d5..51d15fa3c9 100644 +--- a/tools/toolchains/cross_compile/cc/BUILD ++++ b/tools/toolchains/cross_compile/cc/BUILD +@@ -80,7 +80,6 @@ cc_toolchain_config( + opt_compile_flags = [ + "-g0", + "-O2", +- "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", +@@ -166,7 +165,6 @@ cc_toolchain_config( + opt_compile_flags = [ + "-g0", + "-O2", +- "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", +@@ -263,7 +261,6 @@ cc_toolchain_config( + opt_compile_flags = [ + "-g0", + "-O2", +- "-D_FORTIFY_SOURCE=1", + "-DNDEBUG", + "-ffunction-sections", + "-fdata-sections", +diff --git a/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl b/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl +index de638c0159..f25c0cc0fe 100644 +--- a/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl ++++ b/tools/toolchains/cross_compile/cc/cc_toolchain_config.bzl +@@ -218,7 +218,6 @@ def _impl(ctx): + flag_group( + # Security hardening requires optimization. + # We need to undef it as some distributions now have it enabled by default. +- flags = ["-U_FORTIFY_SOURCE"], + ), + ], + with_features = [ diff --git a/scripts/build_torch_wheels.sh b/scripts/build_torch_wheels.sh index 5e3ada94cd2d..25abe3bea010 100755 --- a/scripts/build_torch_wheels.sh +++ b/scripts/build_torch_wheels.sh @@ -139,9 +139,9 @@ function install_llvm_clang() { sudo update-alternatives --install /usr/bin/clang++ clang++ $(which clang++-8) 70 } -function install_gcc10() { - sudo apt-get -y install gcc-10 g++-10 - export CC=/usr/bin/gcc-10 export CXX=/usr/bin/g++-10 +function install_gcc() { + sudo apt-get -y install gcc-11 g++-11 + export CC=/usr/bin/gcc-10 export CXX=/usr/bin/g++-11 sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100 } @@ -332,7 +332,7 @@ function main() { if [[ $(uname -m) == "x86_64" ]]; then install_llvm_clang elif [[ $(uname -m) == "aarch64" ]]; then - install_gcc10 + install_gcc fi install_and_setup_conda build_and_install_torch diff --git a/scripts/update_deps.py b/scripts/update_deps.py index b737e6317ce1..6c369406d699 100755 --- a/scripts/update_deps.py +++ b/scripts/update_deps.py @@ -162,8 +162,8 @@ def get_latest_stable_jax_info() -> tuple[str, str, str] | None: published_at = data['published_at'] # e.g., "2024-04-26T22:58:34Z" release_date = published_at.split('T')[0] # e.g., "2024-04-26" - # The XLA commit is in third_party/xla/workspace.bzl in the JAX repo. - workspace_bzl_url = f'https://raw.githubusercontent.com/google/jax/{tag_name}/third_party/xla/workspace.bzl' + # The XLA commit is in third_party/xla/revision.bzl in the JAX repo. + workspace_bzl_url = f'https://raw.githubusercontent.com/google/jax/{tag_name}/third_party/xla/revision.bzl' try: with urllib.request.urlopen(workspace_bzl_url) as response: workspace_content = response.read().decode() diff --git a/setup.py b/setup.py index 11824cd08a47..76ce042c6fee 100644 --- a/setup.py +++ b/setup.py @@ -115,18 +115,18 @@ USE_NIGHTLY = True # Whether to use nightly or stable libtpu and JAX. -_libtpu_version = '0.0.18' -_libtpu_date = '20250617' +_libtpu_version = '0.0.21' +_libtpu_date = '20250813' -_jax_version = '0.6.2' -_jaxlib_version = '0.6.2' -_jax_date = '20250617' # Date for jax and jaxlib. +_jax_version = '0.7.1' +_jaxlib_version = '0.7.1' +_jax_date = '20250813' # Date for jax and jaxlib. if USE_NIGHTLY: - _libtpu_version += f".dev{_libtpu_date}" + _libtpu_version += f".dev{_libtpu_date}+nightly" _jax_version += f'.dev{_jax_date}' _jaxlib_version += f'.dev{_jax_date}' - _libtpu_wheel_name = f'libtpu-{_libtpu_version}.dev{_libtpu_date}+nightly-py3-none-manylinux_2_31_{platform_machine}' + _libtpu_wheel_name = f'libtpu-{_libtpu_version}-py3-none-manylinux_2_31_{platform_machine}' _libtpu_storage_directory = 'libtpu-nightly-releases' else: # The postfix can be changed when the version is updated. Check @@ -134,8 +134,8 @@ # versioning. _libtpu_wheel_name = f'libtpu-{_libtpu_version}-py3-none-manylinux_2_31_{platform_machine}' _libtpu_storage_directory = 'libtpu-lts-releases' - -_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}.whl' +#https://us-python.pkg.dev/ml-oss-artifacts-published/jax/libtpu/libtpu-0.0.19.1-py3-none-manylinux_2_31_x86_64.whl +_libtpu_storage_path = f'https://us-python.pkg.dev/ml-oss-artifacts-published/jax/libtpu/{_libtpu_wheel_name}.whl' def _get_build_mode(): @@ -423,22 +423,10 @@ def link_packages(self): def _get_jax_install_requirements(): - if not USE_NIGHTLY: - # Stable versions of JAX can be directly installed from PyPI. - return [ - f'jaxlib=={_jaxlib_version}', - f'jax=={_jax_version}', - ] - - # Install nightly JAX libraries from the JAX package registries. - jax = f'jax @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jax/jax-{_jax_version}-py3-none-any.whl' - - jaxlib = [] - for python_minor_version in [9, 10, 11, 12]: - jaxlib.append( - f'jaxlib @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jaxlib/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' - ) - return [jax] + jaxlib + return [ + f'jaxlib=={_jaxlib_version}', + f'jax=={_jax_version}', + ] setup( @@ -496,12 +484,12 @@ def _get_jax_install_requirements(): }, extras_require={ # On Cloud TPU VM install with: - # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html + # pip install torch_xla[tpu] --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html 'tpu': [ f'libtpu=={_libtpu_version}', 'tpu-info', ], - # pip install torch_xla[pallas] + # pip install torch_xla[pallas] --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html 'pallas': [*_get_jax_install_requirements(),] }, cmdclass={ diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index 30c1945eb269..232e40cb71b7 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -43,7 +43,7 @@ spec: - | pip install expecttest==0.1.6 pip install rich - pip install 'torch_xla[pallas]' + pip install --pre torch_xla[pallas] --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html cd /src/pytorch/xla volumeMounts: diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index 04f37f8c0a00..d2d665704184 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -63,6 +63,8 @@ def maybe_get_jax(): jax_import_guard() with jax_env_context(): import jax + # TorchXLA still expects SPMD style sharding + jax.config.update('jax_use_shardy_partitioner', False) return jax except (ModuleNotFoundError, ImportError): logging.warn('You are trying to use a feature that requires jax/pallas.' diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a2099f7d4ec1..d605c07406ba 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1361,7 +1361,7 @@ class PyLoweringContext { std::string GetHloJsonText() { const xla::HloModuleProto& proto = computation.proto(); std::string result; - google::protobuf::util::MessageToJsonString(proto, &result); + XLA_CHECK_OK(google::protobuf::util::MessageToJsonString(proto, &result)); return result; } diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index d6337503508e..0325d1440fc8 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -218,7 +218,8 @@ std::vector IfrtComputationClient::GetDataShards( std::vector> arrays = ifrt_data->buffer ->DisassembleIntoSingleDeviceArrays( - xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + xla::ifrt::ArrayCopySemantics::kAlwaysCopy, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards) .value(); for (auto array : arrays) { @@ -308,8 +309,7 @@ std::vector IfrtComputationClient::TransferToDevice( ifrt_device, xla::ifrt::MemoryKind()), xla::ifrt::Client::HostBufferSemantics:: kImmutableUntilTransferCompletes, - [tensor, timed]() { /* frees tensor and timer */ }, - client_->CreateUserContext()) + [tensor, timed]() { /* frees tensor and timer */ }) .value(); ComputationClient::DataPtr data = diff --git a/torch_xla/csrc/runtime/profiler.cpp b/torch_xla/csrc/runtime/profiler.cpp index 8d48046e723d..4da3d9d628c6 100644 --- a/torch_xla/csrc/runtime/profiler.cpp +++ b/torch_xla/csrc/runtime/profiler.cpp @@ -1,5 +1,7 @@ #include "torch_xla/csrc/runtime/profiler.h" +#include + #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "torch_xla/csrc/runtime/tf_logging.h" @@ -76,10 +78,23 @@ absl::Status Trace( int num_tracing_attempts, const absl::flat_hash_map>& options) { + // by 20250815 Upstream CaptureRemoteTrace changed signature of options to + // include bool option. For backward compatibility we don't change signature + // of Trace, but instead we make an adaptor to adapt the new function + // signature. + absl::flat_hash_map> + updated_options; + for (const auto& item : options) { + if (std::holds_alternative(item.second)) { + updated_options[item.first] = std::get(item.second); + } else { + updated_options[item.first] = std::get(item.second); + } + } return tsl::profiler::CaptureRemoteTrace( service_addr, logdir, /*worker_list=*/"", /*include_dataset_ops=*/false, duration_ms, num_tracing_attempts, - options); + updated_options); } void RegisterProfilerForPlugin(const PJRT_Api* c_api) { diff --git a/torch_xla/csrc/runtime/xla_coordinator.cpp b/torch_xla/csrc/runtime/xla_coordinator.cpp index 7ed14e410f52..ff31bc4df729 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cpp +++ b/torch_xla/csrc/runtime/xla_coordinator.cpp @@ -19,10 +19,8 @@ absl::Status XlaCoordinator::Initialize(int global_rank, int world_size, // Default value can be found in // https://github.com/openxla/xla/blob/4b88636002bc5834d7fe3f862997c66a490987bc/xla/pjrt/distributed/client.h#L63-L72. int heartbeat_interval_sec = - sys_util::GetEnvInt(env::kEnvDistSvcHeartbeatIntervalInSec, 10); - service_options.heartbeat_interval = absl::Seconds(heartbeat_interval_sec); - service_options.max_missing_heartbeats = - sys_util::GetEnvInt(env::kEnvDistSvcMaxMissingHeartbeats, 10); + sys_util::GetEnvInt(env::kEnvDistSvcHeartbeatIntervalInSec, 100); + service_options.heartbeat_timeout = absl::Seconds(heartbeat_interval_sec); int shutdown_timeout = sys_util::GetEnvInt(env::kEnvDistSvcShutdownTimeoutInMin, 5); service_options.shutdown_timeout = absl::Minutes(shutdown_timeout); diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index dc9892d7d572..5b86927b72db 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -8,6 +8,7 @@ #include #include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" #include "tsl/platform/stacktrace.h" namespace torch_xla { From f8b44e2a1ddc3cf40beba8aafb3186c818480dff Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Fri, 22 Aug 2025 15:09:39 -0700 Subject: [PATCH 059/133] Create mapping for FP8 torch dtypes (#9573) Fix a bug when using `t2j` with fp8 dtypes. --- torchax/torchax/ops/mappings.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py index 409a6d8350be..4eb7c6996159 100644 --- a/torchax/torchax/ops/mappings.py +++ b/torchax/torchax/ops/mappings.py @@ -6,6 +6,14 @@ import torch.utils.dlpack as torchdl import torch.utils._mode_utils as mode_utils +NUMPY_UNSUPPORTED_DTYPES = { + torch.bfloat16: jnp.bfloat16, + torch.float8_e4m3fn: jnp.float8_e4m3fn, + torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz, + torch.float8_e5m2: jnp.float8_e5m2, + torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz, +} + def t2j(t, use_dlpack=True): is_bool = False @@ -28,14 +36,14 @@ def t2j(t, use_dlpack=True): if res is None: # https://github.com/google/jax/issues/7657 # https://github.com/google/jax/issues/17784 - if t.dtype == torch.bfloat16: + if t.dtype in NUMPY_UNSUPPORTED_DTYPES: nparray = (t.cpu().detach().to(torch.float32).numpy() - ) # numpy don't support bfloat16 + ) # handle dtypes not supported by numpy else: nparray = t.cpu().detach().numpy() res = jnp.asarray(nparray) - if t.dtype == torch.bfloat16: - res = res.astype(jnp.bfloat16) + if t.dtype in NUMPY_UNSUPPORTED_DTYPES: + res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype]) if is_bool: res = res.astype(jnp.bool_) From b098be87dde58fe48e5effe72c0bb6b9b4ba5b6e Mon Sep 17 00:00:00 2001 From: aws-cph Date: Fri, 22 Aug 2025 21:18:22 -0700 Subject: [PATCH 060/133] refactor: DTensor inheritance for XLAShardedTensor (#9576) Changing XLAShardedTensor to inherit from DTensor and not torch.tensor in regards to https://github.com/pytorch/xla/issues/9418. --- test/neuron/run_tests.sh | 1 + test/run_tests.sh | 1 + test/spmd/test_xla_sharded_tensor.py | 38 +++++++++++++++++++ test/tpu/run_tests.sh | 1 + .../distributed/spmd/xla_sharded_tensor.py | 5 ++- 5 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 test/spmd/test_xla_sharded_tensor.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index d8ee9a39b03e..a68e0671a3b6 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -257,6 +257,7 @@ function run_xla_op_tests3 { run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index 85ae9d8691ce..8cfe37e29f56 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -238,6 +238,7 @@ function run_xla_op_tests3 { run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py" diff --git a/test/spmd/test_xla_sharded_tensor.py b/test/spmd/test_xla_sharded_tensor.py new file mode 100644 index 000000000000..a101fb9bcd7e --- /dev/null +++ b/test/spmd/test_xla_sharded_tensor.py @@ -0,0 +1,38 @@ +import sys +import unittest +import test_xla_sharding_base +from torch.distributed.tensor import DTensor +from torch_xla.distributed.spmd import XLAShardedTensor + +import torch + + +class XlaShardedTensorTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_xlashardedtensor_is_dtensor(self): + """Test that XLAShardedTensor is a subclass of DTensor.""" + xt = torch.randn(128, 128).to('xla') + xla_tensor = XLAShardedTensor(xt) + self.assertIsInstance(xla_tensor, DTensor) + + def test_xlashardedtensor_gradient(self): + """Test accessing gradients of an XLAShardedTensor (triggers __torch_function__).""" + xt = torch.randn(128, 128).to('xla') + xla_tensor = XLAShardedTensor(xt, requires_grad=True) + result = xla_tensor.sum() + result.backward() + + # this should trigger __torch_function__ + grad = xla_tensor.grad + + self.assertIsNotNone(grad) + self.assertEqual(grad.shape, xla_tensor.shape) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 440db8bd28ad..e1ad7c0023a4 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -63,6 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py" +run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v run_test "$_TEST_DIR/test_autocast.py" diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 5a049b5864e3..652a2011cbda 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -11,6 +11,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial from torch.utils._pytree import tree_map_only +from torch.distributed.tensor import DTensor @dataclass @@ -63,7 +64,7 @@ def no_dispatch() -> Iterator[None]: del guard -class XLAShardedTensor(torch.Tensor): +class XLAShardedTensor(DTensor): """ A wrapper around `torch.Tensor` with sharding annotation for XLA SPMD auto-sharding. The wrapped tensors are unwrapped @@ -300,4 +301,4 @@ def redistribute(self, device_mesh, placements, *, async_op: bool = False): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - return super().__torch_function__(func, types, args, kwargs) + return super(DTensor, cls).__torch_function__(func, types, args, kwargs) \ No newline at end of file From 147d2c254724dcd1a44980ff9b9e94b4d96764af Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 23 Aug 2025 11:39:55 -0300 Subject: [PATCH 061/133] `full`: improve error handling and error messages. (#9564) This PR refactors the `tensor_methods::full` and `tensor_methods::full_symint` implementation by improving their error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::full` and `tensor_methods::full_symint` return `StatusOr` - Improve error message on invalid arguments --- test/test_operations.py | 10 ++++++ torch_xla/csrc/aten_xla_type.cpp | 38 ++++++++++---------- torch_xla/csrc/ops/index_ops.cpp | 3 +- torch_xla/csrc/tensor_methods.cpp | 58 ++++++++++++++++++++++--------- torch_xla/csrc/tensor_methods.h | 13 ++++--- torch_xla/csrc/tensor_ops.cpp | 12 +++---- 6 files changed, 83 insertions(+), 51 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index cb790a074148..a544d9ba19a7 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2511,6 +2511,16 @@ def test_flip_raises_error_on_duplicated_dims(self): f"from {dims} to {dims_suggestion}.") self.assertEqual(str(e), expected_error) + def test_full_raises_error_on_negative_size(self): + shape = [2, -2, 2] + try: + torch.full(shape, 1.5, device="xla") + except RuntimeError as e: + expected_error = ( + "full(): expected concrete sizes (i.e. non-symbolic) to be " + f"positive values. However found negative ones: {shape}.") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 005e0e98dcc7..ceacd59603e8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1702,16 +1702,14 @@ at::Tensor XLANativeFunctions::empty_symint( // does not actually end up doing any memory initialization, we use that and // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus // s_copy_(). - XLATensorPtr xla_tensor; - if (all_dims_static) { - xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0, - GetXlaDeviceOrCurrent(device), - at::dtype_or_default(dtype)); - } else { - xla_tensor = - tensor_methods::full_symint(sym_size, 0, GetXlaDeviceOrCurrent(device), - at::dtype_or_default(dtype)); - } + XLATensorPtr xla_tensor = GetValueOrThrow( + all_dims_static + ? tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0, + GetXlaDeviceOrCurrent(device), + at::dtype_or_default(dtype)) + : tensor_methods::full_symint(sym_size, 0, + GetXlaDeviceOrCurrent(device), + at::dtype_or_default(dtype))); // `tensor.to` will trigger an `empty` + `_to_copy`. In the egaer mode, the // `full` will be evulated eagerly and got a replicated sharding. We should // leave the sharding to be empty. @@ -1858,9 +1856,9 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size, } else { intend_dtype = fill_value.type(); } - return bridge::AtenFromXlaTensor( + return bridge::AtenFromXlaTensor(GetValueOrThrow( tensor_methods::full(absl::Span(size), fill_value, - GetXlaDeviceOrCurrent(device), intend_dtype)); + GetXlaDeviceOrCurrent(device), intend_dtype))); } at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, @@ -2681,8 +2679,8 @@ std::tuple XLANativeFunctions::nll_loss2d_forward( int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr total_weight = tensor_methods::full( - {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); + XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full( + {}, 1, self_tensor->GetDevice(), self_tensor->dtype())); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d( self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), @@ -2716,8 +2714,8 @@ std::tuple XLANativeFunctions::nll_loss_forward( int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr total_weight = tensor_methods::full( - {}, 1, self_tensor->GetDevice(), self_tensor->dtype()); + XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full( + {}, 1, self_tensor->GetDevice(), self_tensor->dtype())); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss( self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), @@ -4038,10 +4036,10 @@ std::tuple XLANativeFunctions::_linalg_svd( if (!compute_uv) { // When compute_uv is false, torch::_linalg_svd returns an empty tensor for // u and vh. - u = tensor_methods::full({0}, 0, self_tensor->GetDevice(), - self_tensor->dtype()); - vh = tensor_methods::full({0}, 0, self_tensor->GetDevice(), - self_tensor->dtype()); + u = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(), + self_tensor->dtype())); + vh = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(), + self_tensor->dtype())); } return std::make_tuple(bridge::AtenFromXlaTensor(u), bridge::AtenFromXlaTensor(s), diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index a189446bbefe..25b0b4f4b852 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -315,7 +315,8 @@ XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base, base_dimensions.begin() + start_dim + indices.size(), base_dimensions.end()); - return tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype()); + return GetValueOrThrow( + tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype())); } XLATensorPtr IndexByTensors(const XLATensorPtr& base, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 4a749d50ac79..7534191f0422 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -409,6 +409,39 @@ absl::Status CheckFlipDimensionsAreUnique( return absl::OkStatus(); } +template +absl::Status CheckFullSizesArePositiveImpl(absl::Span sizes, + const F& original_sizes_as_str) { + const bool has_concrete_negative_size = std::any_of( + sizes.begin(), sizes.end(), [](const int64_t size) { return size < 0; }); + if (has_concrete_negative_size) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("full(): expected concrete sizes (i.e. non-symbolic) to " + "be positive values. However found negative ones: [", + original_sizes_as_str(), "]."))); + } + return absl::OkStatus(); +} + +absl::Status CheckFullSizesArePositive(absl::Span sizes) { + return CheckFullSizesArePositiveImpl( + sizes, [&]() { return absl::StrJoin(sizes, /* sep= */ ", "); }); +} + +absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) { + std::vector concrete_sizes_or_zero; + std::transform(sym_sizes.begin(), sym_sizes.end(), + std::back_inserter(concrete_sizes_or_zero), + [](at::SymInt sym) { return sym.maybe_as_int().value_or(0); }); + return CheckFullSizesArePositiveImpl(concrete_sizes_or_zero, [&]() { + return absl::StrJoin(sym_sizes.begin(), sym_sizes.end(), /* sep= */ ", ", + [](std::string* out, at::SymInt sym) { + absl::StrAppendFormat(out, "%s", + absl::FormatStreamed(sym)); + }); + }); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1767,10 +1800,10 @@ XLATensorPtr fmod(const XLATensorPtr& input, const at::Scalar& other, logical_element_type); } -XLATensorPtr full(absl::Span size, const at::Scalar& fill_value, - const torch::lazy::BackendDevice& device, - at::ScalarType scalar_type) { - CheckShapeDimensions(size); +absl::StatusOr full( + absl::Span size, const at::Scalar& fill_value, + const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) { + XLA_RETURN_IF_ERROR(CheckFullSizesArePositive(size)); xla::Shape shape = MakeArrayShapeFromDimensions(size, /*dynamic_dimensions=*/{}, MakeXlaPrimitiveType(scalar_type, &device), @@ -1794,19 +1827,10 @@ XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value, device, *scalar_type); } -XLATensorPtr full_symint(at::SymIntArrayRef sym_size, - const at::Scalar& fill_value, - const torch::lazy::BackendDevice& device, - at::ScalarType scalar_type) { - XLA_CHECK(std::all_of(sym_size.begin(), sym_size.end(), [](at::SymInt dim) { - // TODO: It should be OK to perform this test on symbolic ints too, not - // sure why you conditionalized it. - if (auto c = dim.maybe_as_int()) { - return *c >= 0; - } - return true; - })) << "Dimensions cannot be negative numbers"; - +absl::StatusOr full_symint( + at::SymIntArrayRef sym_size, const at::Scalar& fill_value, + const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) { + XLA_RETURN_IF_ERROR(CheckFullConcreteSizesArePositive(sym_size)); return XLATensor::Create( XLAGraphExecutor::Get()->GetIrValueForScalar( fill_value, MakeXlaPrimitiveType(scalar_type, &device), sym_size, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index fb7eae93f8db..869dcaa8dffb 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -460,16 +460,15 @@ XLATensorPtr fmod( const XLATensorPtr& input, const at::Scalar& other, std::optional logical_element_type = std::nullopt); -XLATensorPtr full(absl::Span size, const at::Scalar& fill_value, - const torch::lazy::BackendDevice& device, - at::ScalarType scalar_type); +absl::StatusOr full( + absl::Span size, const at::Scalar& fill_value, + const torch::lazy::BackendDevice& device, at::ScalarType scalar_type); XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value, const torch::lazy::BackendDevice& device, std::optional scalar_type); -XLATensorPtr full_symint(at::SymIntArrayRef sym_size, - const at::Scalar& fill_value, - const torch::lazy::BackendDevice& device, - at::ScalarType scalar_type); +absl::StatusOr full_symint( + at::SymIntArrayRef sym_size, const at::Scalar& fill_value, + const torch::lazy::BackendDevice& device, at::ScalarType scalar_type); XLATensorPtr gather(const XLATensorPtr& input, int64_t dim, const XLATensorPtr& index); diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index 2b925d7c381a..edb7d22297c4 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -207,16 +207,16 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output, int64_t numel = xla::ShapeUtil::ElementsIn(indices_shape_ref.get()); XLATensorPtr grad = tensor_methods::view(grad_output, {numel, grad_output->size(-1)}); - XLATensorPtr grad_weight = + XLATensorPtr grad_weight = GetValueOrThrow( tensor_methods::full({num_weights, grad_output->size(-1)}, 0, - grad_output->GetDevice(), grad_output->dtype()); + grad_output->GetDevice(), grad_output->dtype())); XLATensorPtr indices_rank1 = tensor_methods::view(indices, {numel}); if (scale_grad_by_freq) { // Compute the histogram of index values. - XLATensorPtr counts = tensor_methods::full( - {num_weights}, 0, indices->GetDevice(), indices->dtype()); - XLATensorPtr ones = tensor_methods::full({numel}, 1, indices->GetDevice(), - indices->dtype()); + XLATensorPtr counts = GetValueOrThrow(tensor_methods::full( + {num_weights}, 0, indices->GetDevice(), indices->dtype())); + XLATensorPtr ones = GetValueOrThrow(tensor_methods::full( + {numel}, 1, indices->GetDevice(), indices->dtype())); tensor_methods::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0, /*values=*/ones, /*accumulate=*/true, /*result_permutation=*/{0}); From 8243a25bdd5d0527a9fee8a45fa6a44c8e8fba46 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 23 Aug 2025 14:25:28 -0300 Subject: [PATCH 062/133] `gather`: improve error handling and error messages. (#9566) This PR refactors the `tensor_methods::gather` implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::gather` return `StatusOr` - Improve error message on incompatible tensor shapes --- test/test_operations.py | 33 +++++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 4 +-- torch_xla/csrc/tensor_methods.cpp | 55 ++++++++++++++++++++++++------- torch_xla/csrc/tensor_methods.h | 5 +-- 4 files changed, 82 insertions(+), 15 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a544d9ba19a7..a4e5b2e10449 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2521,6 +2521,39 @@ def test_full_raises_error_on_negative_size(self): f"positive values. However found negative ones: {shape}.") self.assertEqual(str(e), expected_error) + def test_gather_raises_error_on_rank_mismatch(self): + S = 2 + + input = torch.arange(4, device=torch_xla.device()).view(S, S) + index = torch.randint(0, S, (S, S, S), device=torch_xla.device()) + dim = 1 + + try: + torch.gather(input, dim, index) + except RuntimeError as e: + expected_error = ( + "gather(): expected rank of input (2) and index (3) tensors " + "to be the same.") + self.assertEqual(str(e), expected_error) + + def test_gather_raises_error_on_invalid_index_size(self): + S = 2 + X = S + 2 + + input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S) + index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device()) + dim = 1 + + try: + torch.gather(input, dim, index) + except RuntimeError as e: + expected_error = ( + f"gather(): expected sizes of index [{X}, {S}, {X}, {S}] to be " + f"smaller or equal those of input [{S}, {S}, {S}, {S}] on all " + f"dimensions, except on dimension {dim}. " + "However, that's not true on dimensions [0, 2].") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index ceacd59603e8..5a75936b0c35 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1865,9 +1865,9 @@ at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, const at::Tensor& index, bool /* sparse_grad */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( + return bridge::AtenFromXlaTensor(GetValueOrThrow( tensor_methods::gather(GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - GetValueOrThrow(bridge::GetXlaTensor(index)))); + GetValueOrThrow(bridge::GetXlaTensor(index))))); } at::Tensor XLANativeFunctions::gelu(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 7534191f0422..9c50db2f7bd2 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -442,6 +442,43 @@ absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) { }); } +absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input, + const XLATensorPtr& index) { + int64_t input_rank = input->shape().get().dimensions_size(); + int64_t index_rank = index->shape().get().dimensions_size(); + if (input_rank != index_rank) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "gather(): expected rank of input (", input_rank, ") and index (", + index_rank, ") tensors to be the same."))); + } + return absl::OkStatus(); +} + +// Checks that all index dimensions are smaller or equal to those of input, +// except on dimension canonical_dim. +absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input, + const XLATensorPtr& index, + int64_t canonical_dim) { + // Dimensions that fail the "smaller or equal" condition. + std::vector bad_dims; + for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) { + if (dim != canonical_dim && input->size(dim) < index->size(dim)) { + bad_dims.push_back(dim); + } + } + if (!bad_dims.empty()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "gather(): expected sizes of index [", + absl::StrJoin(index->shape().get().dimensions(), /* sep= */ ", "), + "] to be smaller or equal those of input [", + absl::StrJoin(input->shape().get().dimensions(), /* sep= */ ", "), + "] on all dimensions, except on dimension ", canonical_dim, + ". However, that's not true on dimensions [", + absl::StrJoin(bad_dims, /* sep= */ ", "), "]."))); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1838,18 +1875,14 @@ absl::StatusOr full_symint( device, scalar_type); } -XLATensorPtr gather(const XLATensorPtr& input, int64_t dim, - const XLATensorPtr& index) { - xla::Shape input_shape = input->shape(); - xla::Shape index_shape = index->shape(); - XLA_CHECK_EQ(input_shape.dimensions_size(), index_shape.dimensions_size()); +absl::StatusOr gather(const XLATensorPtr& input, + int64_t dim, + const XLATensorPtr& index) { int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex( - dim, input_shape.dimensions_size()); - for (size_t dim = 0; dim < input_shape.dimensions_size(); dim++) { - if (dim != canonical_dim) { - XLA_CHECK_LE(index->size(dim), input->size(dim)); - } - } + dim, input->shape().get().dimensions_size()); + XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index)); + XLA_RETURN_IF_ERROR( + CheckGatherDimensionsAreCompatible(input, index, canonical_dim)); return input->CreateFrom(torch_xla::MakeNode( input->GetIrValue(), canonical_dim, index->GetIrValue())); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 869dcaa8dffb..3c9570833573 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -470,8 +470,9 @@ absl::StatusOr full_symint( at::SymIntArrayRef sym_size, const at::Scalar& fill_value, const torch::lazy::BackendDevice& device, at::ScalarType scalar_type); -XLATensorPtr gather(const XLATensorPtr& input, int64_t dim, - const XLATensorPtr& index); +absl::StatusOr gather(const XLATensorPtr& input, + int64_t dim, + const XLATensorPtr& index); XLATensorPtr ge(const XLATensorPtr& input, const at::Scalar& other); From 49ac22a612c4a848dcb976c9e1303b8da5622ca3 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 25 Aug 2025 08:20:21 -0300 Subject: [PATCH 063/133] `random_`: improve error handling and error messages. (#9567) This PR refactors the `random_` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::random_` return `Status` - Replace `CheckRangeValues` by `CheckValueWithinTypeRange`, and make it return `Status` - Refactor `XLANativeFunctions::random_` overloads to handle the status values - Improve error messages --- test/test_operations.py | 27 ++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 54 ++++++++++++++++++++----------- torch_xla/csrc/tensor_methods.cpp | 9 ++++-- torch_xla/csrc/tensor_methods.h | 2 +- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a4e5b2e10449..9d377083da54 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2554,6 +2554,33 @@ def test_gather_raises_error_on_invalid_index_size(self): "However, that's not true on dimensions [0, 2].") self.assertEqual(str(e), expected_error) + def test_random__raises_error_on_empty_interval(self): + a = torch.empty(10, device=torch_xla.device()) + from_ = 3 + to_ = 1 + + try: + a.random_(from_, to_) + except RuntimeError as e: + expected_error = ( + f"random_(): expected `from` ({from_}) to be smaller than " + f"`to` ({to_}).") + self.assertEqual(str(e), expected_error) + + def test_random__raises_error_on_value_out_of_type_value_range(self): + a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16) + from_ = 3 + to_ = 65504 + 1 + + try: + a.random_(from_, to_) + except RuntimeError as e: + expected_error = ( + f"random_(): expected `to` to be within the range " + f"[-65504, 65504]. However got value {to_}, which is greater " + "than the upper bound.") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 5a75936b0c35..9606a989f831 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -18,6 +18,7 @@ #include #include "absl/log/absl_check.h" +#include "status.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/shape_inference.h" #include "torch/csrc/lazy/core/tensor_util.h" @@ -317,18 +318,27 @@ int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) { } } -void CheckRangeValues(torch::ScalarType dtype, int64_t from, int64_t to) { - XlaHelpers::MinMax min_max; - // Bound the min_max by int64_t since types of "from" and "to" are int64. - if (IsTypeWithLargerRangeThanLong(dtype)) { - min_max = XlaHelpers::MinMaxValues(xla::PrimitiveType::S64); - } else { - min_max = XlaHelpers::MinMaxValues(XlaTypeFromTorchType(dtype)); +absl::Status CheckValueWithinTypeRange(const std::string_view op, + const std::string_view arg, + torch::ScalarType dtype, int64_t value) { + xla::PrimitiveType type = IsTypeWithLargerRangeThanLong(dtype) + ? xla::PrimitiveType::S64 + : XlaTypeFromTorchType(dtype); + + XlaHelpers::MinMax mm = XlaHelpers::MinMaxValues(type); + int64_t min = mm.min.toLong(); + int64_t max = mm.max.toLong(); + + if (value < min || value > max) { + const std::string_view comparison = value < min ? "lower" : "greater"; + const std::string_view bound = value < min ? "lower bound" : "upper bound"; + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat(op, "(): expected `", arg, "` to be within the range [", + min, ", ", max, "]. However got value ", value, + ", which is ", comparison, " than the ", bound, "."))); } - XLA_CHECK_GE(from, min_max.min.toLong()); - XLA_CHECK_LE(from, min_max.max.toLong()); - XLA_CHECK_GE(to, min_max.min.toLong()); - XLA_CHECK_LE(to, min_max.max.toLong()); + + return absl::OkStatus(); } std::pair GetBinaryOperands( @@ -3025,12 +3035,14 @@ at::Tensor& XLANativeFunctions::random_( } XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); at::ScalarType dtype = self_tensor->dtype(); + // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType(dtype) + inc; - XLA_CHECK_LE(from, to_val); - CheckRangeValues(self_tensor->dtype(), from, to_val - 1); - tensor_methods::random_(self_tensor, from, to_val); + + OkOrThrow(CheckValueWithinTypeRange("random_", "from", dtype, from)); + OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to_val - 1)); + OkOrThrow(tensor_methods::random_(self_tensor, from, to_val)); return self; } @@ -3043,10 +3055,12 @@ at::Tensor& XLANativeFunctions::random_( ATEN_OP2(random_, to)>::call(self, to, generator); } + XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLA_CHECK_GT(to, 0); - CheckRangeValues(self_tensor->dtype(), 0, to - 1); - tensor_methods::random_(self_tensor, 0, to); + at::ScalarType dtype = self_tensor->dtype(); + + OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to - 1)); + OkOrThrow(tensor_methods::random_(self_tensor, 0, to)); return self; } @@ -3060,10 +3074,12 @@ at::Tensor& XLANativeFunctions::random_( } XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); at::ScalarType dtype = self_tensor->dtype(); + // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; - tensor_methods::random_(self_tensor, 0, - GetIntegerUpperLimitForType(dtype) + inc); + int64_t to_val = GetIntegerUpperLimitForType(dtype) + inc; + + OkOrThrow(tensor_methods::random_(self_tensor, 0, to_val)); return self; } diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 9c50db2f7bd2..2786ca1718bd 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2922,8 +2922,12 @@ XLATensorPtr dynamic_view(const XLATensorPtr& input, ////////////////////////////////////////////////////////////////////////////// -void random_(XLATensorPtr& input, int64_t from, int64_t to) { - XLA_CHECK_LE(from, to); +absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to) { + if (from >= to) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("random_(): expected `from` (", from, + ") to be smaller than `to` (", to, ")."))); + } auto input_shape = input->shape(); input->SetInPlaceIrValue(torch_xla::MakeNode( XLAGraphExecutor::Get()->GetIrValueForScalar( @@ -2931,6 +2935,7 @@ void random_(XLATensorPtr& input, int64_t from, int64_t to) { XLAGraphExecutor::Get()->GetIrValueForScalar(to, xla::PrimitiveType::S64, input->GetDevice()), XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape)); + return absl::OkStatus(); } XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 3c9570833573..c28d7f2165e6 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -776,7 +776,7 @@ void put_(XLATensorPtr& input, const XLATensorPtr& index, std::tuple qr(const XLATensorPtr& input, bool some); -void random_(XLATensorPtr& input, int64_t from, int64_t to); +absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to); XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device, at::ScalarType scalar_type); From aada9fcdea413feaffe473ce8fc332416a528fda Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 25 Aug 2025 12:13:25 -0300 Subject: [PATCH 064/133] Remove `XLA_CUDA` and other CUDA build flags. (#9582) This PR removes the uses and mentions to `XLA_CUDA` and `TF_CUDA_COMPUTE_CAPABILITIES` flags. They are related to the now deprecated CUDA build. This PR also removes **Key Changes:** - (_.bazelrc_) Removed CUDA bazel configuration - (_build_util.py_) Removed the translation of `XLA_CUDA` environment variable to `--config=cuda` bazel argument - Removed uses of `XLA_CUDA` and `TF_CUDA_COMPUTE_CAPABILITIES` throughout the codebase - Removed some logic for compiling PyTorch/XLA with CUDA support --- .bazelrc | 13 ------------- .circleci/build.sh | 1 - .github/upstream/Dockerfile | 5 ----- benchmarks/nightly.sh | 2 +- build_util.py | 2 -- configuration.yaml | 7 +------ docker/Dockerfile | 4 ---- infra/ansible/config/env.yaml | 8 -------- scripts/build_torch_wheels.sh | 23 ----------------------- setup.py | 3 --- test/cpp/run_tests.sh | 3 --- test/run_tests.sh | 8 ++++---- 12 files changed, 6 insertions(+), 73 deletions(-) diff --git a/.bazelrc b/.bazelrc index 27e84d37729b..3dec0dc40643 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,18 +79,6 @@ build:native_arch_posix --host_copt=-march=native build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 -build:cuda --repo_env TF_NEED_CUDA=1 -# "sm" means we emit only cubin, which is forward compatible within a GPU generation. -# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain -build:cuda --@local_config_cuda//:enable_cuda -build:cuda --define=xla_python_enable_gpu=true -build:cuda --cxxopt=-DXLA_CUDA=1 - -# Coverage with cuda/gcc/nvcc requires manually setting coverage flags. -coverage:cuda --per_file_copt=third_party/.*,torch_xla/.*@--coverage -coverage:cuda --linkopt=-lgcov - build:acl --define==build_with_acl=true build:nonccl --define=no_nccl_support=true @@ -105,7 +93,6 @@ build:tpu --define=with_tpu_support=true # Run tests serially with TPU and GPU (only 1 device is available). test:tpu --local_test_jobs=1 -test:cuda --local_test_jobs=1 ######################################################################### # RBE config options below. diff --git a/.circleci/build.sh b/.circleci/build.sh index 79178fcb4183..cfb7625f4d36 100755 --- a/.circleci/build.sh +++ b/.circleci/build.sh @@ -50,7 +50,6 @@ source $XLA_DIR/xla_env export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json" export SILO_NAME='cache-silo-ci-dev-3.8_cuda_12.1' # cache bucket for CI export BUILD_CPP_TESTS='1' -export TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_70,sm_75,compute_80,$TF_CUDA_COMPUTE_CAPABILITIES" build_torch_xla $XLA_DIR popd diff --git a/.github/upstream/Dockerfile b/.github/upstream/Dockerfile index 9f617d257cb0..bdfd012cbfd6 100644 --- a/.github/upstream/Dockerfile +++ b/.github/upstream/Dockerfile @@ -15,11 +15,6 @@ ARG tpuvm="" # Disable CUDA for PyTorch ENV USE_CUDA "0" -# Enable CUDA for XLA -ENV XLA_CUDA "${cuda}" -ENV TF_CUDA_COMPUTE_CAPABILITIES "${cuda_compute}" -ENV TF_CUDA_PATHS "/usr/local/cuda,/usr/include,/usr" - # CUDA build guidance ENV NVIDIA_VISIBLE_DEVICES all ENV NVIDIA_DRIVER_CAPABILITIES compute,utility diff --git a/benchmarks/nightly.sh b/benchmarks/nightly.sh index 7817d02496e8..64b34055cbf9 100755 --- a/benchmarks/nightly.sh +++ b/benchmarks/nightly.sh @@ -99,7 +99,7 @@ if [[ ${IS_FRESH_RUN?} ]]; then # Query local compute capability. If that fails, assign a sane default. LOCAL_CAP=compute_$(nvidia-smi --query-gpu=compute_cap --format=csv | \ tail -1 | sed 's/\.//g' | grep -E '^[0-9]{2}$' || echo '80') - XLA_CUDA=1 TF_CUDA_COMPUTE_CAPABILITIES=${LOCAL_CAP:?} python setup.py develop + python setup.py develop cd ../.. # Set up torchbench deps. diff --git a/build_util.py b/build_util.py index 487f5116323e..ebc6a96a9215 100644 --- a/build_util.py +++ b/build_util.py @@ -43,8 +43,6 @@ def bazel_options_from_env() -> Iterable[str]: # Build configuration. if check_env_flag('BAZEL_VERBOSE'): bazel_flags.append('-s') - if check_env_flag('XLA_CUDA'): - bazel_flags.append('--config=cuda') if check_env_flag('XLA_CPU_USE_ACL'): bazel_flags.append('--config=acl') diff --git a/configuration.yaml b/configuration.yaml index a66a0399aa77..8231171a2c23 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -4,7 +4,7 @@ variables: PJRT_DEVICE: description: - Indicates which device is being used with PJRT. It can be either CPU, - TPU, or CUDA + or TPU type: string PJRT_SELECT_DEFAULT_DEVICE: description: @@ -36,11 +36,6 @@ variables: - Verbosity level for GRPC, e.g. INFO, ERROR, etc. type: string default_value: "ERROR" - XLA_CUDA: - description: - - Build the xla client with CUDA enabled. - type: bool - default_value: false GIT_VERSIONED_XLA_BUILD: description: - Creates a versioned build. In particular, appends a git sha to the diff --git a/docker/Dockerfile b/docker/Dockerfile index 7945ef232638..a01b778a78c5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -29,10 +29,6 @@ RUN git clone https://github.com/pytorch/pytorch ENV USE_CUDA "0" ENV USE_MPI "0" -# Enable CUDA for XLA -ENV XLA_CUDA "${cuda}" -ENV TF_CUDA_COMPUTE_CAPABILITIES "${cuda_compute}" - # Whether to build for TPUVM mode ENV TPUVM_MODE "${tpuvm}" ENV BUNDLE_LIBTPU "${tpuvm}" diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index a5047ff1a0db..4fd733b494f7 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -13,10 +13,6 @@ release_env: ACCELERATOR: tpu TPUVM_MODE: 1 - cuda: - TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}" - XLA_CUDA: 1 - # Variables that will be passed to shell environment only for building PyTorch and XLA libs. build_env: common: @@ -41,10 +37,6 @@ build_env: aarch64: - cuda: - TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}" - XLA_CUDA: 1 - tpu: ACCELERATOR: tpu TPUVM_MODE: 1 diff --git a/scripts/build_torch_wheels.sh b/scripts/build_torch_wheels.sh index 25abe3bea010..0f30a1e4e623 100755 --- a/scripts/build_torch_wheels.sh +++ b/scripts/build_torch_wheels.sh @@ -56,28 +56,6 @@ function install_cudnn { rm -f "$CUDNN_FILE" } -function maybe_install_cuda { - if [ "$XLA_CUDA" == "1" ]; then - if [ ! -d "/usr/local/cuda" ]; then - local CUDA_VER="10.2" - local CUDA_SUBVER="89_440.33.01" - local CUDA_FILE="cuda_${CUDA_VER}.${CUDA_SUBVER}_linux.run" - wget "http://developer.download.nvidia.com/compute/cuda/${CUDA_VER}/Prod/local_installers/${CUDA_FILE}" - sudo sh "${CUDA_FILE}" --silent --toolkit - rm -f "${CUDA_FILE}" - fi - if [ ! -f "/usr/local/cuda/include/cudnn.h" ] && [ ! -f "/usr/include/cudnn.h" ]; then - install_cudnn - fi - export TF_CUDA_PATHS="/usr/local/cuda,/usr/include,/usr" - maybe_append 'export TF_CUDA_PATHS="/usr/local/cuda,/usr/include,/usr"' ~/.bashrc - if [ "$TF_CUDA_COMPUTE_CAPABILITIES" == "" ]; then - export TF_CUDA_COMPUTE_CAPABILITIES="7.0" - fi - maybe_append "export TF_CUDA_COMPUTE_CAPABILITIES=\"$TF_CUDA_COMPUTE_CAPABILITIES\"" ~/.bashrc - fi -} - function maybe_install_sources { if [[ $(uname -m) == "aarch64" && ! -d "$HOME/ComputeLibrary" ]]; then # install arm compute library @@ -148,7 +126,6 @@ function install_gcc() { function install_req_packages() { sudo apt-get -y install python3-pip git curl libopenblas-dev vim apt-transport-https ca-certificates wget procps - maybe_install_cuda install_bazel install_ninja } diff --git a/setup.py b/setup.py index 76ce042c6fee..72a7ac9ca5bd 100644 --- a/setup.py +++ b/setup.py @@ -19,9 +19,6 @@ # BAZEL_VERBOSE=0 # turn on verbose messages during the bazel build of the xla/xrt client # -# XLA_CUDA=0 -# build the xla/xrt client with CUDA enabled -# # XLA_CPU_USE_ACL=0 # whether to use ACL # diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 2da0ccb55699..371f2d83084f 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -78,9 +78,6 @@ if [[ "$BAZEL_REMOTE_CACHE" == "1" ]]; then EXTRA_FLAGS="$EXTRA_FLAGS --remote_default_exec_properties=cache-silo-key=$SILO_NAME" fi fi -if [[ "$XLA_CUDA" == "1" ]]; then - EXTRA_FLAGS="$EXTRA_FLAGS --config=cuda" -fi if [[ "$BAZEL_VERB" == "coverage" ]]; then EXTRA_FLAGS="$EXTRA_FLAGS --remote_download_outputs=all" # for lcov symlink fi diff --git a/test/run_tests.sh b/test/run_tests.sh index 8cfe37e29f56..033089d651f5 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -164,8 +164,8 @@ function run_xla_op_tests1 { run_test "$_TEST_DIR/pjrt/test_runtime_multi_cpu.py" run_test "$_TEST_DIR/pjrt/test_internal_tpu.py" - PJRT_DEVICE=CPU XLA_CUDA=0 run_test "$_TEST_DIR/pjrt/test_ddp.py" - PJRT_DEVICE=CPU XLA_CUDA=0 run_test "$_TEST_DIR/pjrt/test_mesh_service.py" + PJRT_DEVICE=CPU run_test "$_TEST_DIR/pjrt/test_ddp.py" + PJRT_DEVICE=CPU run_test "$_TEST_DIR/pjrt/test_mesh_service.py" run_test "$_TEST_DIR/test_python_ops.py" run_test "$_TEST_DIR/test_ops.py" @@ -199,7 +199,7 @@ function run_xla_op_tests2 { run_test "$_TEST_DIR/eager/test_eager_with_xla_compile.py" run_test "$_TEST_DIR/eager/test_eager_with_torch_compile.py" - PJRT_DEVICE=CPU XLA_CUDA=0 run_test "$_TEST_DIR/eager/test_eager_all_reduce_in_place.py" + PJRT_DEVICE=CPU run_test "$_TEST_DIR/eager/test_eager_all_reduce_in_place.py" run_test "$_TEST_DIR/eager/test_eager_spmd.py" run_test "$_TEST_DIR/test_callback.py" @@ -332,7 +332,7 @@ function run_tests { elif [[ "$RUN_TORCH_MP_OP_TESTS" == "torch_mp_op" ]]; then echo "Running torch op tests..." - PJRT_DEVICE=CPU XLA_CUDA=0 run_mp_op_tests + PJRT_DEVICE=CPU run_mp_op_tests else # Run full tests without sharding, respects XLA_SKIP_* if [[ "$XLA_SKIP_XLA_OP_TESTS" != "1" ]]; then From e9a1c5f90b04df971e6ad6fb7eae0fa41541391c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 25 Aug 2025 12:16:17 -0300 Subject: [PATCH 065/133] Remove OpenXLA CUDA fallback and `_XLAC_cuda_functions.so` extension. (#9581) This PR removes the OpenXLA CUDA fallback implementation, and also the `_XLA_cuda_functions.so` extension, completely. Starting on this PR, the fallback shall be run only on CPU. **Key Changes:** - Remove _aten_cuda_functions.cpp_ and _aten_cuda_functions.h_ - Remove the OpenXLA CUDA fallback functions from _aten_fallback.cpp_ - Remove the `_XLAC_cuda_functions.so` library from _BUILD_ - Remove the Python `_XLAC_cuda_functions.so` extension from _setup.py_ - Remove the conditional loading of `_XLAC_cuda_functions.so` Python extension from _torch_xla/__init__.py_ --- BUILD | 16 - configuration.yaml | 5 - setup.py | 1 - test/cpp/BUILD | 9 +- torch_xla/__init__.py | 10 - torch_xla/csrc/BUILD | 11 - torch_xla/csrc/aten_cuda_functions.cpp | 51 --- torch_xla/csrc/aten_cuda_functions.h | 34 -- torch_xla/csrc/aten_fallback.cpp | 402 +-------------------- torch_xla/csrc/aten_fallback.h | 8 +- torch_xla/csrc/xla_manual_registration.cpp | 1 + 11 files changed, 11 insertions(+), 537 deletions(-) delete mode 100644 torch_xla/csrc/aten_cuda_functions.cpp delete mode 100644 torch_xla/csrc/aten_cuda_functions.h diff --git a/BUILD b/BUILD index 900dfa4bc3b2..1b82e9d4b975 100644 --- a/BUILD +++ b/BUILD @@ -46,22 +46,6 @@ cc_binary( ]), ) -cc_binary( - name = "_XLAC_cuda_functions.so", - copts = [ - "-fopenmp", - "-fPIC", - ], - linkopts = [ - "-Wl,-soname,_XLAC_cuda_functions.so", - ], - linkshared = 1, - visibility = ["//visibility:public"], - deps = [ - "//torch_xla/csrc:aten_cuda_functions", - ], -) - test_suite( name = "cpp_tests", # testonly = True, diff --git a/configuration.yaml b/configuration.yaml index 8231171a2c23..c1760d608ae9 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -397,8 +397,3 @@ variables: your code. type: bool default_value: false - XLA_FALLBACK_CPU: - description: - - Forces CPU OpenXLA fallback. By default, PyTorch/XLA will run any operation - that doesn't have a lowering using PyTorch CUDA as fallback. Setting this - flag will force PyTorch/XLA to use PyTorch CPU as fallback. diff --git a/setup.py b/setup.py index 72a7ac9ca5bd..39bc9129e6a3 100644 --- a/setup.py +++ b/setup.py @@ -455,7 +455,6 @@ def _get_jax_install_requirements(): package_dir=package_dir_mapping, ext_modules=[ BazelExtension('//:_XLAC.so'), - BazelExtension('//:_XLAC_cuda_functions.so'), ], install_requires=[ 'absl-py>=1.0.0', diff --git a/test/cpp/BUILD b/test/cpp/BUILD index 00568e8573f5..77d40a10d549 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -54,7 +54,6 @@ ptxla_cc_test( ":cpp_test_util", ":torch_xla_test", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "@com_google_googletest//:gtest_main", ], ) @@ -65,7 +64,6 @@ ptxla_cc_test( deps = [ ":torch_xla_test", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "@com_google_googletest//:gtest_main", "@xla//xla:shape_util", ], @@ -81,7 +79,6 @@ ptxla_cc_test( "//torch_xla/csrc/runtime:debug_macros", "//torch_xla/csrc:status", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "//torch_xla/csrc:thread_pool", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", @@ -100,7 +97,6 @@ ptxla_cc_test( ":cpp_test_util", ":torch_xla_test", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "@com_google_googletest//:gtest_main", ], ) @@ -127,7 +123,6 @@ ptxla_cc_test( "//torch_xla/csrc/runtime:sys_util", "//torch_xla/csrc:status", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "@com_google_googletest//:gtest_main", "@xla//xla:xla_data_proto_cc", "@xla//xla/tsl/profiler/utils:session_manager", @@ -146,7 +141,6 @@ ptxla_cc_test( ":torch_xla_test", "//torch_xla/csrc/runtime:metrics", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "@com_google_googletest//:gtest_main", "@xla//xla:permutation_util", ], @@ -212,7 +206,6 @@ ptxla_cc_test( ":cpp_test_util", ":torch_xla_test", "//torch_xla/csrc:tensor", - "//torch_xla/csrc:aten_cuda_functions", "@com_google_googletest//:gtest_main", ], -) \ No newline at end of file +) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index abafafb01cad..3f5f71ba6e5c 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -7,16 +7,6 @@ import torch -if not torch.cuda.is_available(): - # Load _XLAC_cuda_functions to RTLD_GLOBAL, so that it can be used by _XLAC. - flags = sys.getdlopenflags() - sys.setdlopenflags(flags | os.RTLD_NOW | os.RTLD_GLOBAL) - - import _XLAC_cuda_functions - - # Then, restore the original flags. - sys.setdlopenflags(flags) - import _XLAC from ._internal import tpu from .version import __version__ diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index f99dca0a74ef..8132ae733160 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -77,7 +77,6 @@ ptxla_cc_library( hdrs = [ "aten_autograd_ops.h", "aten_fallback.h", - "aten_cuda_functions.h", "aten_xla_bridge.h", "batch_norm.h", "convert_ops.h", @@ -364,16 +363,6 @@ ptxla_cc_library( ], ) -ptxla_cc_library( - name = "aten_cuda_functions", - srcs = ["aten_cuda_functions.cpp"], - hdrs = ["aten_cuda_functions.h"], - deps = [ - "@local_config_python//:python_headers", - "@pybind11//:pybind11_embed", - ], -) - cc_library( name = "status", srcs = ["status.cpp"], diff --git a/torch_xla/csrc/aten_cuda_functions.cpp b/torch_xla/csrc/aten_cuda_functions.cpp deleted file mode 100644 index 45a5a4a356e0..000000000000 --- a/torch_xla/csrc/aten_cuda_functions.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "torch_xla/csrc/aten_cuda_functions.h" - -#include - -#include -#include - -// Context -// ======= -// aten_fallback.cpp (compiled into _XLAC.so library) uses these functions -// for providing OpenXLA fallback on CUDA. Therefore, they must be defined at -// some point, somewhere. -// -// Problem -// ======= -// These functions' definition should be provided by PyTorch. However, whenever -// it's not compiled with CUDA support, it doesn't. -// -// Solution -// ======== -// Load these backup definitions, whenever we detect that the PyTorch being -// used doesn't support CUDA. -// -// More Context -// ============ -// Our CI currently compiles PyTorch/XLA only once. However, the CI uses it -// with two versions of PyTorch: compiled with and without CUDA support. -// Thus, we must find a way of, at runtime, conditionally load these functions' -// definition, given the current used PyTorch. -// -// This file is compiled into a additional Python library that is conditionally -// loaded given the result of torch.cuda.is_available(). - -static void fail(const char* name) { - throw std::runtime_error("PyTorch was compiled without CUDA support."); -} - -namespace c10::cuda { - -// Returning 0 in this function forces OpenXLA fallback execution on CPU. -DeviceIndex device_count() noexcept { return 0; } - -c10::DeviceIndex current_device() { fail("c10::cuda::current_device()"); } - -void set_device(c10::DeviceIndex) { fail("c10::cuda::set_device()"); } - -void device_synchronize() { fail("c10::cuda::device_synchronize()"); } - -} // namespace c10::cuda - -PYBIND11_MODULE(_XLAC_cuda_functions, m) {} diff --git a/torch_xla/csrc/aten_cuda_functions.h b/torch_xla/csrc/aten_cuda_functions.h deleted file mode 100644 index 50cb3de1aa26..000000000000 --- a/torch_xla/csrc/aten_cuda_functions.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_ATEN_CUDA_FUNCTIONS_H_ -#define XLA_TORCH_XLA_CSRC_ATEN_CUDA_FUNCTIONS_H_ - -#include - -// Forward declaration of PyTorch CUDA functions. -// Source: c10/cuda/CUDAFunctions.h -// -// These are needed in order to synchronize the CUDA device after running -// the operation in PyTorch eager mode. -// -// It would be better to include the actual header. However, if we build -// PyTorch/XLA in an environment where PyTorch wasn't compiled with CUDA -// (i.e. our CI), the build would fail. - -namespace c10 { - -// Type alias used inside PyTorch. -using DeviceIndex = int8_t; - -namespace cuda { - -DeviceIndex device_count() noexcept; - -c10::DeviceIndex current_device(); - -void set_device(c10::DeviceIndex); - -void device_synchronize(); - -} // namespace cuda -} // namespace c10 - -#endif // XLA_TORCH_XLA_CSRC_ATEN_CUDA_FUNCTIONS_H_ diff --git a/torch_xla/csrc/aten_fallback.cpp b/torch_xla/csrc/aten_fallback.cpp index 45f1c64980a9..cb4436b43784 100644 --- a/torch_xla/csrc/aten_fallback.cpp +++ b/torch_xla/csrc/aten_fallback.cpp @@ -1,38 +1,20 @@ #include "torch_xla/csrc/aten_fallback.h" #include +#include #include #include #include -#include #include #include -#include "torch_xla/csrc/aten_cuda_functions.h" -#include "torch_xla/csrc/aten_xla_bridge.h" -#include "torch_xla/csrc/dl_convertor.h" #include "torch_xla/csrc/function_call_tracker.h" -#include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/metrics.h" -#include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/tf_logging.h" -#include "torch_xla/csrc/xla_graph_executor.h" namespace torch_xla { -// List of operations that should be fallbacked to CPU instead of GPU. -static std::unordered_set _force_fallback_on_cpu{ - // This operation is a simple memory access that transforms the given - // 1-element tensor into a Scalar. - // - // Although it makes sense to run this operation on CPU (since the - // output will get copied back to CPU anyway), this also fixes a - // particular issue with moco benchmark. - // More details: https://github.com/pytorch/xla/issues/7647 - "aten::_local_scalar_dense", -}; - // TODO(jwtan): Replace this with torch::lazy::Counter. We need // _fallback_counters to remain as torch_xla::runtime::metrics::Counter to // support torch_xla::runtime::metrics::CreatePerformanceReport(). For more @@ -52,376 +34,6 @@ std::vector GetFallbackOperations() { return fallback; } -// Most of the functions for the CUDA fallback are a modified version of -// PyTorch's at::native::cpu_fallback function. -// -// Source: aten/src/ATen/native/CPUFallback.cpp -// -// While a better solution would be to adapt PyTorch's function to be device -// agnostic, the changes are not small enough that would make sense for adding -// just one more device. Therefore, we copied the needed functions in this file. -// -// Before each modified function below, we shall specify what has changed, -// if there was any. - -// Decide whether to run OpenXLA fallback operations on CUDA. -bool UseOpenXLAFallbackOnCUDA(const c10::OperatorHandle& op) { - // In order to run OpenXLA fallback operations on CUDA, the conditions below - // must be true: - - // 1. XLA_FALLBACK_CPU environment variable is NOT set - bool dont_fallback_on_cpu = - !runtime::sys_util::GetEnvBool("XLA_FALLBACK_CPU", false); - - // 2. The current ComputationClient DeviceType is CUDA. Basically, we don't - // support running OpenXLA fallback operations on CUDA if the current - // PyTorch/XLA DeviceType is not CUDA. - bool device_is_cuda = - runtime::GetComputationClientOrDie()->GetDeviceType().getType() == - XlaDeviceType::CUDA; - - // 3. PyTorch must have been compiled with CUDA support. Otherwise, our - // phony implementation in aten_cuda_functions.cpp will return 0 for the - // call below. - bool pytorch_device_is_not_zero = c10::cuda::device_count() > 0; - - // 4. There is a kernel registered for the CUDA dispatch key, for this - // operation. - bool has_cuda_kernel = op.hasKernelForDispatchKey(c10::DispatchKey::CUDA); - - // 5. The operation is not in the set of operations that should be forcefuly - // fallbacked on CPU. - bool dont_force_fallback_on_cpu = - _force_fallback_on_cpu.find(c10::toString(op.operator_name())) == - _force_fallback_on_cpu.end(); - - return dont_fallback_on_cpu && device_is_cuda && pytorch_device_is_not_zero && - has_cuda_kernel && dont_force_fallback_on_cpu; -} - -struct DeviceInfo { - DeviceInfo(c10::Device device, c10::DeviceIndex i = -1) - : common_device(device), index(i) {} - - // Synchronizes the CUDA device being used by PyTorch. - void synchronize() { - TORCH_CHECK(index != -1, "No defined XLA tensors found for CUDA fallback."); - // Save the current PyTorch device, in case it's not the same as the - // recorded tensor device. - c10::DeviceIndex current = c10::cuda::current_device(); - c10::cuda::set_device(index); - c10::cuda::device_synchronize(); - c10::cuda::set_device(current); - } - - // Common device for all XLA tensors. - // - // CUDA OpenXLA fallback is supported only when all XLA tensors live in - // the same XLA device. This field should be updated and checked every - // time we convert an XLA tensor argument into a CUDA tensor. - c10::Device common_device; - - // CUDA device index where the tensors live in. - // - // This is used for synchronizing the device where the fallback operation - // was called. This should ensure completion of the CUDA computation, in - // order to be used by another XLA computation. - c10::DeviceIndex index; -}; - -// Change: use of std::any_of instead of iterating with a for-loop. -static bool validate_tensor_list(const c10::List& tensorlist) { - return std::any_of(tensorlist.begin(), tensorlist.end(), - [](const at::Tensor& tensor) { return tensor.defined(); }); -} - -// Retrieve the inner XLATensorPtr, and check it lives inside CUDA. -static XLATensorPtr get_xla_cuda_tensor(const at::Tensor& tensor) { - XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); - const torch::lazy::BackendDevice& device = xla_tensor->GetDevice(); - TORCH_CHECK(device.type() == static_cast(XlaDeviceType::CUDA), - "OpenXLA CUDA fallback only supports XLA:CUDA tensors. Found a " - "tensor of another device: ", - device.toString()); - return xla_tensor; -} - -static bool is_valid_xla_tensor(const at::Tensor& tensor) { - return tensor.defined() && tensor.is_xla(); -} - -static at::Tensor to_cuda_tensor(const at::Tensor& tensor, - std::optional& info) { - // Skip undefined or non-XLA tensors. - if (!is_valid_xla_tensor(tensor)) { - return tensor; - } - - // Grab the DLManagedTensor. - DLManagedTensor* managed = torch_xla::toDLPack(tensor); - c10::DeviceIndex index = managed->dl_tensor.device.device_id; - - if (info.has_value()) { - TORCH_CHECK(info->common_device == tensor.device() && info->index == index, - "fallback supports only single XLA device."); - } else { - info = std::make_optional(DeviceInfo(tensor.device(), index)); - } - - // Create the CUDA tensor. - return at::fromDLPack(managed, [=](void*) { managed->deleter(managed); }); -} - -// Former 'to_cpu'. -// In order to move tensors from XLA to CUDA, we make use of the DLPack API. -// -// 1. Synchronize the XLA tensors, so that we can access their data pointer -// 2. Use DLPack in order to create a CUDA tensor -static std::vector to_cuda(const at::TensorList& tensors, - std::optional& info) { - // Synchronize tensors, so that we are able to grab their data pointer. - std::vector xla_tensors; - for (auto& tensor : tensors) { - if (is_valid_xla_tensor(tensor)) { - xla_tensors.push_back(get_xla_cuda_tensor(tensor)); - } - } - XLAGraphExecutor::Get()->SyncTensorsGraph( - &xla_tensors, /*devices=*/{}, /*wait=*/true, /*sync_ltc_data=*/true, - /*warm_up_cache_only=*/false); - - // Use DLPack for sharing the XLA storage with a newly created CUDA tensor. - std::vector cuda_tensors(tensors.size()); - std::transform( - tensors.begin(), tensors.end(), cuda_tensors.begin(), - [&](const at::Tensor& tensor) { return to_cuda_tensor(tensor, info); }); - return cuda_tensors; -} - -// Copy back the results from CUDA to XLA. -// Assumes that we have already synchronized CUDA. -static at::Tensor to_xla(const at::Tensor& tensor) { - return torch_xla::fromDLPack(at::toDLPack(tensor)); -} - -// Former 'cpu_fallback'. -// Changes: -// -// 1. Track the device index being used. Rationale: we synchronize the device -// before crossing device borders for correctness. -// -void cuda_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, - bool error_on_views) { - auto& schema_args = op.schema().arguments(); - const auto num_arguments = schema_args.size(); - auto arguments = torch::jit::last(stack, num_arguments); - const auto arguments_begin = stack->size() - num_arguments; - - std::vector tensor_args; - std::vector tensor_args_indices; - - std::vector> tensorlist_args; - std::vector tensorlist_args_indices; - - std::vector tensorlist_cuda_args; - - // This fallback only works if all XLA:CUDA tensor arguments are - // on the same CUDA device. - // - // We keep track of said device, so that after actually running - // the operation on PyTorch CUDA eager-mode, we synchronize the - // device. - // - // This variable is updated over the course of 'to_cuda' calls. - std::optional info; - - // Initialize CUDA device. - torch::utils::device_lazy_init(at::kCUDA); - - // Step 1: Convert all non-CUDA tensor inputs into CUDA tensors - // and put them on the stack at the correct indices. - for (const auto idx : c10::irange(arguments.size())) { - const auto& ivalue = arguments[idx]; - if (ivalue.isTensor()) { - tensor_args.push_back(ivalue.toTensor()); - tensor_args_indices.push_back(idx); - } else if (ivalue.isTensorList()) { - // Note: we copy each TensorList argument to CUDA individually out of - // convenience, but XLA would benefit from materializing all tensor and - // TensorList args onto the CUDA at the same time. We can improve this if - // we need better perf for XLA's CUDA fallbacks. - tensorlist_args.push_back(ivalue.toTensorList()); - tensorlist_args_indices.push_back(idx); - auto cuda_ivalue = c10::IValue( - c10::List(to_cuda(ivalue.toTensorList().vec(), info))); - tensorlist_cuda_args.push_back(cuda_ivalue); - (*stack)[arguments_begin + idx] = std::move(cuda_ivalue); - tensorlist_args.push_back(ivalue.toTensorList()); - } else if (ivalue.isOptionalTensorList()) { - auto opt_tensors = ivalue.toOptionalTensorList().vec(); - std::vector need_convert_tensors; - std::vector need_convert_tensors_index; - for (auto i : c10::irange(opt_tensors.size())) { - if (!opt_tensors[i].has_value() || !opt_tensors[i]->defined()) continue; - need_convert_tensors.push_back(opt_tensors[i].value()); - need_convert_tensors_index.push_back(i); - } - auto cuda_tensors = to_cuda(need_convert_tensors, info); - for (const auto i : c10::irange(need_convert_tensors_index.size())) { - auto idx = need_convert_tensors_index[i]; - opt_tensors[idx] = cuda_tensors[i]; - } - (*stack)[arguments_begin + idx] = c10::IValue(opt_tensors); - } else if (ivalue.isDevice()) { - c10::Device device = ivalue.toDevice(); - if (info.has_value()) { - TORCH_CHECK(info->common_device == device, "XLA tensors live in ", - info->common_device, " but found target device: ", device); - } else { - info->common_device = device; - } - (*stack)[arguments_begin + idx] = c10::IValue(c10::Device(at::kCUDA)); - } - } - // XLA requires all of the tensor arguments to be gathered up and converted to - // CUDA together. - auto cuda_tensors = to_cuda(tensor_args, info); - - for (const auto i : c10::irange(tensor_args_indices.size())) { - auto idx = tensor_args_indices[i]; - (*stack)[arguments_begin + idx] = c10::IValue(cuda_tensors[i]); - } - - // Step 2: Call the underlying CUDA implementation of the operator - op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CUDA), stack); - - // Synchronize the device before actually converting back to XLA. - TORCH_CHECK(info.has_value()); - info->synchronize(); - - // Step 3: We need to take special care to handle mutable aliases properly: - // If any input tensors are mutable aliases, we need to - // directly copy the updated data on the CUDA tensors back to the original - // inputs. - for (const auto i : c10::irange(tensor_args_indices.size())) { - auto tensor_idx = tensor_args_indices[i]; - const c10::AliasInfo* alias_info = schema_args[tensor_idx].alias_info(); - if (alias_info != nullptr && alias_info->isWrite()) { - at::_copy_from_and_resize(cuda_tensors[i], tensor_args[i]); - } - } - - // We also need to explicit reapply input mutations to inputs that are lists - // of tensors - for (const auto i : c10::irange(tensorlist_args_indices.size())) { - auto tensorlist_idx = tensorlist_args_indices[i]; - const c10::AliasInfo* alias_info = schema_args[tensorlist_idx].alias_info(); - if (alias_info != nullptr && alias_info->isWrite()) { - const auto& cuda_tensors = tensorlist_cuda_args[i].toTensorList().vec(); - for (const auto idx : c10::irange(tensorlist_args[i].size())) { - at::_copy_from_and_resize(cuda_tensors[idx], tensorlist_args[i][idx]); - } - } - } - - // Step 4: Convert any CUDA output tensors back to the original input device. - // For mutable alias'd outputs, we also need to take special care - // to move the ORIGINAL input tensor back onto the stack, in place of - // the temporary CUDA output tensor that we created. - // - // See [CPU Fallback Does Not Handle View Operators] - const auto& schema_returns = op.schema().returns(); - const auto& num_returns = schema_returns.size(); - auto returns = torch::jit::last(stack, num_returns); - const auto returns_begin = stack->size() - num_returns; - - for (const auto idx : c10::irange(returns.size())) { - const c10::AliasInfo* alias_info = schema_returns[idx].alias_info(); - if (alias_info != nullptr && alias_info->isWrite()) { - // Case (1): mutable alias case. - // Move the input ivalue directly onto the stack in place of - // the existing cuda output tensor. - bool found_alias = false; - if (returns[idx].isTensor() && returns[idx].toTensor().defined()) { - // We could store some extra metadata on the function schema to avoid - // the loop here if we need to improve perf. - for (const auto i : c10::irange(tensor_args_indices.size())) { - auto input_tensor_idx = tensor_args_indices[i]; - const auto& input_tensor = cuda_tensors[i]; - const c10::AliasInfo* input_alias_info = - schema_args[input_tensor_idx].alias_info(); - // Checked above; adding assert to guard against breakage of the below - // condition due to changing the above if test. - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr); - if (input_tensor.defined() && (alias_info == input_alias_info || - (input_alias_info != nullptr && - *alias_info == *input_alias_info))) { - // We've found the original input tensor that aliases with the - // current output. Wrap it in an IValue and put it directly on the - // stack. - (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]); - found_alias = true; - break; - } - } - } else if (returns[idx].isTensorList() && - validate_tensor_list(returns[idx].toTensorList())) { - for (const auto i : c10::irange(tensorlist_args_indices.size())) { - auto input_tensor_idx = tensorlist_args_indices[i]; - const c10::AliasInfo* input_alias_info = - schema_args[input_tensor_idx].alias_info(); - // Checked above; adding assert to guard against breakage of the below - // condition due to changing the above if test. - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alias_info != nullptr); - if (validate_tensor_list(tensorlist_args[i]) && - (alias_info == input_alias_info || - (input_alias_info != nullptr && - *alias_info == *input_alias_info))) { - // We've found the original input tensor that aliases with the - // current output. Wrap it in an IValue and put it directly on the - // stack. - (*stack)[returns_begin + idx] = c10::IValue(tensorlist_args[i]); - found_alias = true; - break; - } - } - } - TORCH_CHECK( - found_alias, "The operator ", op.schema().operator_name(), - " appears to have invalid alias information. ", - "Found a return tensor argument with a mismatched mutable alias: ", - schema_returns[idx]); - } else { - if (alias_info != nullptr && !alias_info->isWrite()) { - // Case (3): immutable alias (view) case. - TORCH_CHECK( - false, "The operator ", op.schema().operator_name(), - " appears to be a view operator, ", - "but it has no implementation for the backend \"xla\". ", - "View operators don't support ", - "since the tensor's storage cannot be shared across devices."); - } - // Case (2): copy case. - // Copy the CUDA output tensor to the original device. - if (returns[idx].isTensor() && returns[idx].toTensor().defined()) { - (*stack)[returns_begin + idx] = - c10::IValue(to_xla(returns[idx].toTensor())); - } else if (returns[idx].isTensorList() && - validate_tensor_list(returns[idx].toTensorList())) { - const auto& cuda_tensors = returns[idx].toTensorList().vec(); - std::vector tensors; - tensors.reserve(cuda_tensors.size()); - - for (const auto& tensor : cuda_tensors) { - tensors.push_back(to_xla(tensor)); - } - (*stack)[returns_begin + idx] = - c10::IValue(c10::List(tensors)); - } - } - } -} - void xla_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { XLA_FN_TRACK(3); const auto name = c10::toString(op.operator_name()); @@ -447,14 +59,10 @@ void xla_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { } } - if (UseOpenXLAFallbackOnCUDA(op)) { - cuda_fallback(op, stack, true); - } else { - // Call the actual boxed CPU fallback. - // Set error_on_views as XLA should take care - // of all view ops after functionalization. - at::native::cpu_fallback(op, stack, true); - } + // Call the actual boxed CPU fallback. + // Set error_on_views as XLA should take care + // of all view ops after functionalization. + at::native::cpu_fallback(op, stack, true); } TORCH_LIBRARY_IMPL(_, XLA, m) { diff --git a/torch_xla/csrc/aten_fallback.h b/torch_xla/csrc/aten_fallback.h index 58b6c1ee74ff..cb050a56e1c4 100644 --- a/torch_xla/csrc/aten_fallback.h +++ b/torch_xla/csrc/aten_fallback.h @@ -1,7 +1,7 @@ -#ifndef XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ -#define XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ +#ifndef XLA_TORCH_XLA_CSRC_ATEN_FALLBACK_H_ +#define XLA_TORCH_XLA_CSRC_ATEN_FALLBACK_H_ -#include +#include namespace torch_xla { @@ -11,4 +11,4 @@ std::vector GetFallbackOperations(); } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ +#endif // XLA_TORCH_XLA_CSRC_ATEN_FALLBACK_H_ diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp index 4a5beea20988..f439d4634bdd 100644 --- a/torch_xla/csrc/xla_manual_registration.cpp +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "torch_xla/csrc/XLANativeFunctions.h" From abf18e4b717d552836ef3c084560c95f3254d5d2 Mon Sep 17 00:00:00 2001 From: qihqi Date: Mon, 25 Aug 2025 12:54:21 -0700 Subject: [PATCH 066/133] Fix case when both device & dtype are given in .to (#9583) --- torchax/test/test_misc.py | 47 +++++++++++++++++++++++++++++++++++++++ torchax/torchax/tensor.py | 6 ++--- 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 torchax/test/test_misc.py diff --git a/torchax/test/test_misc.py b/torchax/test/test_misc.py new file mode 100644 index 000000000000..b93877a7fd64 --- /dev/null +++ b/torchax/test/test_misc.py @@ -0,0 +1,47 @@ +"""If you don't know which file a test should go, and don't want to make a new file +for a small test. PUt it here +""" +import torch +import unittest +import torchax +import jax +import jax.numpy as jnp + + +class MiscTest(unittest.TestCase): + + def test_extract_jax_kwargs(self): + + class M(torch.nn.Module): + + def forward(self, a, b): + return torch.sin(a) + torch.cos(b) + + weights, func = torchax.extract_jax(M()) + res = func( + weights, + args=(), + kwargs={ + 'a': jnp.array([1, 2, 3]), + 'b': jnp.array([3, 4, 5]) + }) + self.assertTrue( + jnp.allclose( + res, + jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5])))) + + def test_to_device(self): + env = torchax.default_env() + env.config.debug_print_each_op = True + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + self.assertEqual(step3.device.type, 'jax') + + +if __name__ == '__main__': + unittest.main() diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 67bc074177ef..a325c51dfc10 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -469,12 +469,12 @@ def _to_copy(self, the_tensor, new_dtype, new_device): arr = self.t2j_copy(the_tensor) res = Tensor(arr, self, the_tensor.requires_grad) - if new_dtype is not None and new_dtype != the_tensor.dtype: - if isinstance(the_tensor, Tensor): + if new_dtype is not None and new_dtype != res.dtype: + if isinstance(res, Tensor): res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype)) else: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return the_tensor.to(device=new_device, dtype=new_dtype) + return res.to(device=new_device, dtype=new_dtype) return res def get_and_rotate_prng_key(self, From 5522c69a58ed1f89121b9e2ca24d14569af815f8 Mon Sep 17 00:00:00 2001 From: Brendan Folie Date: Mon, 25 Aug 2025 14:12:35 -0700 Subject: [PATCH 067/133] implement send and recv using collective_permute (#9373) #9315 --- test/pjrt/test_collective_ops_tpu.py | 50 ++++++++++++++++++++++ test/test_torch_distributed_xla_backend.py | 41 ------------------ torch_xla/core/xla_model.py | 3 -- torch_xla/distributed/xla_backend.py | 39 +++++++---------- 4 files changed, 66 insertions(+), 67 deletions(-) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 1afe859a2040..b60f47d365bd 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -336,6 +336,56 @@ def test_all_to_all_single(self, use_dynamo): expected.sort().values), f"Got {val}, expected {expected}") + @staticmethod + def _send_recv_pipeline(): + dist.init_process_group("xla", init_method='xla://') + device = torch_xla.device() + world_size = xr.world_size() + cutoff = world_size // 2 + index = xr.global_ordinal() + tensor = torch.tensor([index], dtype=torch.float, device=device) + if index < cutoff: + dist.send(tensor, index + cutoff) + else: + dist.recv(tensor, index - cutoff) + return tensor.cpu() + + @staticmethod + def _send_recv_permute(): + dist.init_process_group("xla", init_method='xla://') + device = torch_xla.device() + world_size = xr.world_size() + index = xr.global_ordinal() + sending_tensor = torch.tensor([index], dtype=torch.float, device=device) + receiving_tensor = torch.tensor([-1.0], dtype=torch.float, device=device) + if index % 2 == 0: + dist.send(sending_tensor, (index + 1) % world_size) + dist.recv(receiving_tensor, (index - 1) % world_size) + else: + dist.recv(receiving_tensor, (index - 1) % world_size) + dist.send(sending_tensor, (index + 1) % world_size) + return receiving_tensor.cpu() + + @absltest.skipUnless(tpu.num_available_devices() % 2 == 0, + "Send/Recv test requires even number of devices") + def test_send_recv_pipeline(self): + """Send tensors on first N/2 devices to second N/2 devices.""" + results = pjrt.run_multiprocess(self._send_recv_pipeline) + world_size = tpu.num_expected_global_devices() + for ordinal, value in results.items(): + expected = ordinal if ordinal < world_size // 2 else ordinal - world_size // 2 + np.testing.assert_array_equal(value, [expected]) + + @absltest.skipUnless(tpu.num_available_devices() % 2 == 0, + "Send/Recv test requires even number of devices") + def test_send_recv_permute(self): + """Send tensor on device i to i + 1 (module world size).""" + results = pjrt.run_multiprocess(self._send_recv_permute) + world_size = tpu.num_expected_global_devices() + for ordinal, value in results.items(): + expected = (ordinal - 1) % world_size + np.testing.assert_array_equal(value, [expected]) + @staticmethod def _all_to_all(): dist.init_process_group("xla", init_method='xla://') diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index 3af6aaa8a080..bb0dfd3efd7f 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -166,47 +166,6 @@ def test_reduce_scatter_coalesced(self): # purge all computations attached the device. torch_xla.sync() - @patch_world(0, 6) - def test_send(self): - device = torch_xla.device() - tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() - input_list = [tensor] - - with mock.patch.object( - torch_xla.distributed.xla_backend.ProcessGroupXla, - 'make_send_channel_id', - new=lambda self, dst_rank, tag: dst_rank * 2): - dist.send(tensor, 1) - - send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2' - senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2' - hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor]) - hlo_matches(hlo, send_pattern) - hlo_matches(hlo, senddone_pattern) - - # Don't try to run Send on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) - - @patch_world(0, 6) - def test_recv(self): - device = torch_xla.device() - tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() - - with mock.patch.object( - torch_xla.distributed.xla_backend.ProcessGroupXla, - 'make_recv_channel_id', - new=lambda self, src_rank, tag: src_rank * 3): - dist.recv(tensor, 1) - - recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3' - recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3' - hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor]) - hlo_matches(hlo, recv_pattern) - hlo_matches(hlo, recvdone_pattern) - - # Don't try to run Recv on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) - @patch_world(rank=0, size=12) def test_new_group_no_ranks(self): with new_group_barrier_disabled(): diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 6b68e656d333..3dbad1a963eb 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor, pairs: List[List[int]]) -> torch.Tensor: """Performs a XLA `CollectivePermute()` operation on the input tensor. - WARNING: This function is not very reliable, may produce wrong results under - certain inputs. Use it at your own risk. - See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute Args: diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 87ed4bbd7a59..3c7848d6e327 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -337,41 +337,34 @@ def scatter(self, output_tensor_list: list[torch.Tensor], rs_opts.reduceOp = dist.ReduceOp.SUM return self.reduce_scatter(output_tensor_list, inputs, rs_opts) - # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace - # the maker with their specific one. See unit test in - # test/test_torch_distributed_xla_backend.py for an example. - def make_send_channel_id(self, dst_rank, tag): - raise NotImplementedError - # Call site e.g. # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877 def send(self, tensors, dst_rank, tag=0): + logging.warning( + "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs." + ) results = [] for t in tensors: - channel_id = self.make_send_channel_id(dst_rank, tag) - # The input will be returned as result. - input_as_result = xm.send(t, channel_id) - # Make the sent tensor depend on the token, such that the `send` - # op can actually be built into the computation graph. - with torch.no_grad(): - t.copy_(input_as_result) - results.append(input_as_result) + result_t = xm.collective_permute( + t, pairs=[[xr.global_ordinal(), dst_rank]]) + torch_xla.sync() + results.append(result_t) return _ret_work(results) - # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace - # the maker with their specific one. See unit test in - # test/test_torch_distributed_xla_backend.py for an example. - def make_recv_channel_id(self, src_rank, tag): - raise NotImplementedError - # Call site e.g. # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913 def recv(self, out_tensors, src_rank, tag=0): + logging.warning( + "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs." + ) results = [] for ot in out_tensors: - channel_id = self.make_recv_channel_id(src_rank, tag) - result = xm.recv(ot, channel_id) - results.append(result) + result_t = xm.collective_permute( + ot, pairs=[[src_rank, xr.global_ordinal()]]) + torch_xla.sync() + with torch.no_grad(): + ot.copy_(result_t) + results.append(result_t) return _ret_work(results) def recv_anysource(self, *args): From 163193ebe9715d353c10b9fb6cb629ec88e2520e Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Tue, 26 Aug 2025 02:07:47 -0700 Subject: [PATCH 068/133] Set environment variables for tpu7x (#9586) --- torch_xla/_internal/tpu.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 182c8675fbdd..2993a033c6eb 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -52,6 +52,8 @@ # Testing only '0x0056', '0x0062', + # TPU 7x + '0x0076' ] @@ -188,7 +190,10 @@ def version() -> int: except requests.HTTPError as e: raise EnvironmentError('Failed to get TPU metadata') from e - match = re.match(r'^v(\d)([A-Za-z]?){7}-(\d+)$', env[xenv.ACCELERATOR_TYPE]) + match = re.match(r'^(?:v|tpu)(\d)([A-Za-z]?){7}-(\d+)$', + env[xenv.ACCELERATOR_TYPE]) + if not match: + raise EnvironmentError('Failed to parse TPU version from metadata') return int(match.groups()[0]) @@ -254,7 +259,8 @@ def configure_topology(local_rank: int, tpu_env = get_tpu_env() accelerator_type = tpu_env[xenv.ACCELERATOR_TYPE] - if version() >= 4: + tpu_version = version() + if tpu_version >= 4: # Process bounds with 4 chips per process default_process_bounds = MeshShape.from_string( tpu_env[xenv.TPU_PROCESS_BOUNDS]) @@ -270,8 +276,11 @@ def configure_topology(local_rank: int, process_bounds = default_process_bounds * chips_per_process os.environ.setdefault(xenv.TPU_CHIPS_PER_PROCESS_BOUNDS, '1,1,1') - os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, - ','.join(str(dim) for dim in process_bounds)) + process_bounds_str = ','.join(str(dim) for dim in process_bounds) + if tpu_version == 7: + process_bounds_str += ',2' + + os.environ.setdefault(xenv.TPU_PROCESS_BOUNDS, process_bounds_str) # Assume each TPU has the same number of local processes with the same ports worker_id = int(tpu_env[xenv.WORKER_ID]) From 4c586bd00e8984cdeac44d1f3ded3ab28d2f826f Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 27 Aug 2025 12:20:12 -0300 Subject: [PATCH 069/133] Create new macros for throwing status errors. (#9588) Follow-up: #9580 This PR introduces 2 new macros for throwing status errors: `XLA_THROW_IF_ERROR()`, and `XLA_ASSIGN_OR_THROW()`. These macros are analogous to the already existing `XLA_RETURN_IF_ERROR()` and `XLA_ASSIGN_OR_RETURN()`, where instead of propagating (i.e. returning) error status, they throw an exception with the given error status. **Key Changes:** - (_status.h_ and _status.cpp_) New function: `ThrowStatusError(...)` - (_status.h_) Refactors the implementation of the existing macros, so as to use its definition for all of the aforementioned macros - `XLA_PROCESS_STATUS_IMPL_(...)` : core implementation of those macros. - `XLA_PROPAGATE_STATUS_IMPL_(var, ...)`: propagates the given status `var`. - `XLA_THROW_STATUS_IMPL_(...)`: calls the newly added `ThrowStatusError()` function, which throws an exception - `XLA_DO_IF_ERROR_IMPL_(...)`: core implementation of `XLA_*_IF_ERROR()` macros - `XLA_RETURN_IF_ERROR(...)`: combines `XLA_DO_IF_ERROR_IMPL_` with `XLA_PROPAGATE_STATUS_IMPL_` - `XLA_THROW_IF_ERROR(...)`: combines `XLA_DO_IF_ERROR_IMPL_` with `XLA_THROW_STATUS_IMPL_` - `XLA_ASSIGN_OR_DO_IMPL_(...)`: core implementation of `XLA_ASSIGN_OR_*()` macros - `XLA_ASSIGN_OR_RETURN(...)`: combines `XLA_ASSIGN_OR_DO_IMPL_` with `XLA_PROPAGATE_STATUS_IMPL_` - `XLA_ASSGIN_OR_THROW(...)`: combines `XLA_ASSIGN_OR_DO_IMPL_` with `XLA_THROW_STATUS_IMPL_` - (_test_status_common.h_) Add one test for each of the 2 new public macros --- test/cpp/test_status_common.h | 134 +++++++++++++++++++++++++++- torch_xla/csrc/status.cpp | 11 +++ torch_xla/csrc/status.h | 160 ++++++++++++++++++++++++++-------- 3 files changed, 266 insertions(+), 39 deletions(-) diff --git a/test/cpp/test_status_common.h b/test/cpp/test_status_common.h index cb917942ffe3..17b0ef29f5ff 100644 --- a/test/cpp/test_status_common.h +++ b/test/cpp/test_status_common.h @@ -80,8 +80,10 @@ class StatusTest : public testing::TestWithParam { namespace cpp_test { // Prefix of the C++ stacktrace PyTorch adds to the error message. -constexpr inline char kTorchCppStacktracePrefix[] = +constexpr inline char kTorchCppStacktracePrefixDeprecated[] = "Exception raised from OkOrThrow at torch_xla/csrc/status.cpp:"; +constexpr inline char kTorchCppStacktracePrefix[] = + "Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:"; constexpr inline char kNewMessage[] = "New test error message"; constexpr inline char kMessage[] = "Test error message"; @@ -113,7 +115,7 @@ TEST_P(StatusTest, OkOrThrowWithErrorStatus) { if (IsShowCppStacktracesMode()) { EXPECT_THAT(std::string_view(error.what()), ::testing::StartsWith(absl::StrCat( - kMessage, "\n\n", kTorchCppStacktracePrefix))); + kMessage, "\n\n", kTorchCppStacktracePrefixDeprecated))); } else { EXPECT_EQ(std::string_view(error.what_without_backtrace()), std::string_view(kMessage)); @@ -136,7 +138,7 @@ TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) { if (IsShowCppStacktracesMode()) { EXPECT_THAT(std::string_view(error.what()), ::testing::StartsWith(absl::StrCat( - kMessage, "\n\n", kTorchCppStacktracePrefix))); + kMessage, "\n\n", kTorchCppStacktracePrefixDeprecated))); } else { EXPECT_EQ(std::string_view(error.what_without_backtrace()), std::string_view(kMessage)); @@ -388,6 +390,132 @@ TEST_P(StatusTest, OkOrThrowWithErrorPropagationWithNewMessage) { oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" << errline2; oss << "\n\n"; + oss << kTorchCppStacktracePrefixDeprecated; + EXPECT_THAT(std::string_view(error.what()), + ::testing::StartsWith(oss.str())); + } else { + EXPECT_EQ(std::string_view(error.what_without_backtrace()), + std::string_view(kNewMessage)); + } + } +} + +TEST_P(StatusTest, MacroThrowIfErrorWithErrorPropagationWithNewMessage) { + int32_t errline0 = __LINE__ + 2; + auto innerfn = [&]() -> absl::Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); + }; + + int32_t errline1 = __LINE__ + 2; + auto midfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(innerfn(), kNewMessage); + return absl::OkStatus(); + }; + + int32_t errline2 = __LINE__ + 2; + auto outerfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(midfn()); + return absl::OkStatus(); + }; + + int32_t errline3 = __LINE__ + 2; + try { + XLA_THROW_IF_ERROR(outerfn()); + FAIL() << "Expected `XLA_THROW_IF_ERROR(outerfn())` to throw."; + } catch (const c10::Error& error) { + if (IsShowCppStacktracesMode()) { + // clang-format off + // + // Expected Error Message Prefix + // ============================= + // + // New test error kMessage + // + // Status Propagation Stacktrace: + // From: operator() at ./test/cpp/test_status_common.h:334 (error: Test error kMessage) + // From: operator() at ./test/cpp/test_status_common.h:339 (error: New test error kMessage) + // From: operator() at ./test/cpp/test_status_common.h:345 + // From: TestBody at ./test/cpp/test_status_common.h:350 + // + // C++ Stacktrace: + // + // clang-format on + std::ostringstream oss; + oss << kNewMessage; + oss << "\n\n"; + oss << "Status Propagation Trace:"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline0 << " (error: " << kMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline1 << " (error: " << kNewMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline2; + oss << kEntryPrefix << "From: TestBody at " << __FILE__ << ":" + << errline3; + oss << "\n\n"; + oss << kTorchCppStacktracePrefix; + EXPECT_THAT(std::string_view(error.what()), + ::testing::StartsWith(oss.str())); + } else { + EXPECT_EQ(std::string_view(error.what_without_backtrace()), + std::string_view(kNewMessage)); + } + } +} + +TEST_P(StatusTest, MacroAssignOrThrowWithErrorPropagationWithNewMessage) { + int32_t errline0 = __LINE__ + 2; + auto innerfn = [&]() -> absl::Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); + }; + + int32_t errline1 = __LINE__ + 2; + auto midfn = [&]() -> absl::Status { + XLA_RETURN_IF_ERROR(innerfn(), kNewMessage); + return absl::OkStatus(); + }; + + int32_t errline2 = __LINE__ + 2; + auto outerfn = [&]() -> absl::StatusOr { + XLA_RETURN_IF_ERROR(midfn()); + return 42; + }; + + int32_t errline3 = __LINE__ + 2; + try { + XLA_ASSIGN_OR_THROW(int ret, outerfn()); + FAIL() << "Expected `XLA_ASSIGN_OR_THROW(int ret, outerfn())` to throw."; + } catch (const c10::Error& error) { + if (IsShowCppStacktracesMode()) { + // clang-format off + // + // Expected Error Message Prefix + // ============================= + // + // New test error kMessage + // + // Status Propagation Stacktrace: + // From: operator() at ./test/cpp/test_status_common.h:393 (error: Test error kMessage) + // From: operator() at ./test/cpp/test_status_common.h:398 (error: New test error kMessage) + // From: operator() at ./test/cpp/test_status_common.h:404 + // From: TestBody at ./test/cpp/test_status_common.h:410 + // + // C++ Stacktrace: + // + // clang-format on + std::ostringstream oss; + oss << kNewMessage; + oss << "\n\n"; + oss << "Status Propagation Trace:"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline0 << " (error: " << kMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline1 << " (error: " << kNewMessage << ")"; + oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" + << errline2; + oss << kEntryPrefix << "From: TestBody at " << __FILE__ << ":" + << errline3; + oss << "\n\n"; oss << kTorchCppStacktracePrefix; EXPECT_THAT(std::string_view(error.what()), ::testing::StartsWith(oss.str())); diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 5b86927b72db..56636874f0b9 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -118,6 +118,17 @@ static std::string LineBreakIfCppStacktracesEnabled() { return torch::get_cpp_stacktraces_enabled() ? "\n" : ""; } +void status_internal::ThrowStatusError(const absl::Status& status, + const char* file, const int32_t line, + const char* function, + std::string_view message) { + ABSL_CHECK(!status.ok()); + absl::Status new_status = status_internal::MaybeWithNewMessage( + status, file, line, function, message); + TORCH_CHECK(false, absl::StrCat(BuildStatusErrorMessage(new_status), + LineBreakIfCppStacktracesEnabled())); +} + void OkOrThrow(const absl::Status& status) { TORCH_CHECK(status.ok(), absl::StrCat(BuildStatusErrorMessage(status), LineBreakIfCppStacktracesEnabled())); diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index 28f16860479a..87a9227672bb 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -62,30 +62,56 @@ constexpr char kStatusPropagationTraceKey[] = #define XLA_STATUS_VAR_ XLA_CONCAT_(status_, __LINE__) // Provides a flexible way to handle error checking with optional message -// modification. It evaluates `expr`, checks if it's OK, and either: -// 1. Returns early with an error status -// 2. Proceeds with the given `then` block if successful -#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \ - auto var = (expr); \ - if (!var.ok()) { \ - return ::torch_xla::status_internal::MaybeWithNewMessage( \ - ::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \ - __FUNCTION__, ##__VA_ARGS__); \ - } \ - then - -// Propagates `rexpr`, in case it's a non-ok status. +// modification. It evaluates `expr`, and: // -// Example: +// 1. Runs the `on_error` block, if the returned status is an error +// 2. Runs the `on_success` block, otherwise // -// XLA_RETURN_IF_ERROR( -// FnThatReturnsStatus(), -// "New error message." -// ); +#define XLA_PROCESS_STATUS_IMPL_(on_error, on_success, expr, var, ...) \ + auto var = (expr); \ + if (!var.ok()) { \ + on_error(var, ##__VA_ARGS__); \ + } \ + on_success + +// `on_error` implementation for propagating the status `var`. +// +// This macro wraps `var` (error status returned) into a new status, adding +// source location information to the status propagation trace if +// `TORCH_SHOW_CPP_STACKTRACES` is set. And then, returns the newly created +// status. +// +// It should be only used as parameter to `XLA_PROCESS_STATUS_IMPL_` macro +// defined above. +// +#define XLA_PROPAGATE_STATUS_IMPL_(var, ...) \ + return ::torch_xla::status_internal::MaybeWithNewMessage( \ + ::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \ + __FUNCTION__, ##__VA_ARGS__) + +// `on_error` implementation for throwing an exception with the status `var`. // -// If the function call results in an ok status, execution continues. Otherwise, -// we early return a non-ok status. Then, if `TORCH_SHOW_CPP_STACKTRACES` is -// set, the error shown will be: +// This macro wraps `var` (error status returned) into a new status, adding +// source location information to the status propagation trace if +// `TORCH_SHOW_CPP_STACKTRACES` is set. And then, throws an exception using the +// `ThrowStatusError()` function. +// +// It should be only used as parameter to `XLA_PROCESS_STATUS_IMPL_` macro +// defined above. +// +#define XLA_THROW_STATUS_IMPL_(var, ...) \ + ::torch_xla::status_internal::ThrowStatusError( \ + ::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \ + __FUNCTION__, ##__VA_ARGS__) + +// Macro implementation for processing an `absl::Status` value. This is the core +// definition of `XLA_*_IF_ERROR()` macros that, given that `rexpr` is an error +// status, either throws or returns (i.e. propagates) a newly created status +// with source location information. +// +// If `rexpr` results in an ok status, execution continues. Otherwise, we run +// `on_error`. Then, if `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown +// will be: // // RuntimeError: New error message. // @@ -95,18 +121,61 @@ constexpr char kStatusPropagationTraceKey[] = // ... // From: : (error: New error message.) // -#define XLA_RETURN_IF_ERROR(rexpr, ...) \ - do { \ - XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, {}, ##__VA_ARGS__) \ +#define XLA_DO_IF_ERROR_IMPL_(on_error, rexpr, ...) \ + do { \ + XLA_PROCESS_STATUS_IMPL_(on_error, /* on_success= */ {}, rexpr, \ + XLA_STATUS_VAR_, ##__VA_ARGS__) \ } while (false) -// Propagates `rexpr`, in case it's a non-ok status. Otherwise, assign -// its result to `lhs`. +// If `rexpr` returns a non-ok status, this macro propagates the returned status +// by early-returning a, possibly, new status with source location information. +// Otherwise, continues execution. +// +// Example: +// +// XLA_RETURN_IF_ERROR( +// FnThatReturnsStatus(), +// "New error message." +// ); +// +#define XLA_RETURN_IF_ERROR(rexpr, ...) \ + XLA_DO_IF_ERROR_IMPL_(XLA_PROPAGATE_STATUS_IMPL_, rexpr, ##__VA_ARGS__) + +// If `rexpr` returns a non-ok status, this macro throws an exception with the +// returned status, possibly, wrapped by a new status with source location +// information. Otherwise, continues execution. +// +// Example: +// +// XLA_THROW_IF_ERROR( +// FnThatReturnsStatus(), +// "New error message." +// ); +// +#define XLA_THROW_IF_ERROR(rexpr, ...) \ + XLA_DO_IF_ERROR_IMPL_(XLA_THROW_STATUS_IMPL_, rexpr, ##__VA_ARGS__) + +// Macro implementation for processing an `absl::Status` value. This is the core +// definition of `XLA_ASSIGN_OR_*()` macros that, given that `rexpr` is an error +// status, either throws or returns (i.e. propagates) a newly created status +// with source location information. +// +// If `rexpr` results in an ok status, we assign the value held by the status +// returned by `rexpr` to `lhs`. Otherwise, we run `on_error`. // // Note 1: `lhs` might be a variable declarate, e.g: // // Note 2: this macro will be replaced by multiple statements that live on -// the scope it was called (see XLA_RETURN_IF_ERROR_IMPL). +// the scope it was called (see `XLA_PROCESS_STATUS_IMPL_`). +// +#define XLA_ASSIGN_OR_DO_IMPL_(on_error, lhs, rexpr, ...) \ + XLA_PROCESS_STATUS_IMPL_( \ + on_error, /* on_success= */ lhs = std::move(XLA_STATUS_VAR_).value(), \ + rexpr, XLA_STATUS_VAR_, ##__VA_ARGS__) + +// If `rexpr` returns a non-ok status, this macro propagates the returned status +// by early-returning a, possibly, new status with source location information. +// Otherwise, assigns `rexpr` to `lhs`. // // Example: // @@ -116,16 +185,23 @@ constexpr char kStatusPropagationTraceKey[] = // "New error message." // ); // -// If the function call results in an ok status, execution continues with -// `result` set to `ret.value()`, where `ret` is the returned value of the -// function. Otherwise, we early return a non-ok status. Then, if -// `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be similar to -// the one above. +#define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \ + XLA_ASSIGN_OR_DO_IMPL_(XLA_PROPAGATE_STATUS_IMPL_, lhs, rexpr, ##__VA_ARGS__) + +// If `rexpr` returns a non-ok status, this macro throws an exception with the +// returned status, possibly, wrapped by a new status with source location +// information. Otherwise, assigns `rexpr` to `lhs`. +// +// Example: +// +// XLA_ASSIGN_OR_THROW( +// int result, +// FnThatReturnsStatus(), +// "New error message." +// ); // -#define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \ - XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, \ - lhs = std::move(XLA_STATUS_VAR_).value(), \ - ##__VA_ARGS__) +#define XLA_ASSIGN_OR_THROW(lhs, rexpr, ...) \ + XLA_ASSIGN_OR_DO_IMPL_(XLA_THROW_STATUS_IMPL_, lhs, rexpr, ##__VA_ARGS__) // Crashes if `status` is not an ok status. // @@ -191,6 +267,18 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, int32_t line, const char* function, std::string_view new_message = ""); +// Throws an exception from the given `status` +// +// This function wraps `status` within a new status, with the current source +// location information added to its status propagation trace payload. +// +// Then, it throws an exception by using the `TORCH_CHECK(false)` macro, which +// also displays the C++ stacktrace at the end, if `TORCH_SHOW_CPP_STACKTRACES` +// is set. +void ThrowStatusError(const absl::Status& status, const char* file, + const int32_t line, const char* function, + std::string_view message = ""); + // Checks that `status` is an ok status. // // Otherwise, it will create a new status instance with the given source From d214faffcf482a2e276da1008d922a558fa07d88 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 12:34:59 -0300 Subject: [PATCH 070/133] `test`: Use new macros for throwing exceptions. (#9590) Follow-up: #9588 and #9580 Target: `test` directory In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `test` directory, replacing every use (except for the ones in _test_status_common.h_) of those, now deprecated, functions by the newly introduced macros. --- test/cpp/cpp_test_util.cpp | 19 ++-- test/cpp/test_aten_xla_tensor_1.cpp | 2 +- test/cpp/test_replication.cpp | 9 +- test/cpp/test_tensor.cpp | 134 ++++++++++++++-------------- test/cpp/test_xla_backend_intf.cpp | 3 +- test/cpp/test_xla_sharding.cpp | 8 +- 6 files changed, 93 insertions(+), 82 deletions(-) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 4ca8e4981a91..7db1934c37d2 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -246,17 +246,17 @@ void WithAllDevices( } std::string GetTensorTextGraph(at::Tensor tensor) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return DumpUtil::ToText({xtensor->GetIrValue().node.get()}); } std::string GetTensorDotGraph(at::Tensor tensor) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return DumpUtil::ToDot({xtensor->GetIrValue().node.get()}); } std::string GetTensorHloGraph(at::Tensor tensor) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return DumpUtil::ToHlo({xtensor->GetIrValue()}, xtensor->GetDevice()); } @@ -276,9 +276,9 @@ std::vector Execute( lowering_ctx.AddResult(root); } - xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla()); - xla::ProgramShape program_shape = - GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla()); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(device.type())); @@ -295,17 +295,20 @@ std::vector Execute( std::move(instances)); torch_xla::runtime::ComputationClient::ExecuteComputationOptions options; - return GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::vector outputs, torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( *computations.front(), UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(), options)); + return outputs; } std::vector Fetch( absl::Span device_data) { - std::vector literals = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::vector literals, runtime::GetComputationClientOrDie()->TransferFromDevice(device_data)); std::vector tensors; for (auto& literal : literals) { diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index e2813b88a944..2c79925bc161 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -27,7 +27,7 @@ TEST_F(AtenXlaTensorTest, TestStorage) { torch::Tensor a = torch::tensor({0.0}); ForEachDevice([&](const torch::Device& device) { torch::Tensor xla_a = CopyToDevice(a, device); - XLATensorPtr xla_tensor_a = GetValueOrThrow(bridge::GetXlaTensor(xla_a)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor_a, bridge::GetXlaTensor(xla_a)); EXPECT_EQ(xla_a.device(), xla_tensor_a->Storage().device()); AllClose(a, xla_a); }); diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 386f9db3a9a8..00ec937bf8a1 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -25,7 +25,8 @@ xla::XlaComputation CreateCrsComputation(const xla::Shape& shape) { xla::XlaBuilder builder("CrsComputation"); xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x"); xla::CrossReplicaSum(x); - return GetValueOrThrow(builder.Build()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation crs_computation, builder.Build()); + return crs_computation; } void TestSingleReplication( @@ -65,7 +66,8 @@ void TestSingleReplication( torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options; for (size_t i = 0; i < device_strings.size(); ++i) { auto executor = [&, i]() { - results[i] = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + results[i], torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( *compiled_computations[i], {std::dynamic_pointer_cast< @@ -79,7 +81,8 @@ void TestSingleReplication( counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { - std::vector literals = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::vector literals, runtime::GetComputationClientOrDie()->TransferFromDevice(results[i])); ASSERT_EQ(literals.size(), 1); diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp index 6d962c900496..61627f242a86 100644 --- a/test/cpp/test_tensor.cpp +++ b/test/cpp/test_tensor.cpp @@ -101,8 +101,8 @@ TEST_F(TensorTest, TestAdd) { at::Tensor c = a.add(b, 1.0); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device)); - XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_a, XLATensor::Create(a, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_b, XLATensor::Create(b, device)); XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, 1.0); AllClose(c, dev_c); @@ -121,8 +121,8 @@ TEST_F(TensorTest, TestIntegerAdd) { at::isIntegralType(type) ? at::Scalar(int64_t(1)) : at::Scalar(1.0); at::Tensor c = a.add(b, one); - XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device)); - XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_a, XLATensor::Create(a, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_b, XLATensor::Create(b, device)); XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, one); EXPECT_TRUE(EqualValuesNoElementTypeCheck( @@ -135,7 +135,8 @@ TEST_F(TensorTest, TestSize) { at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat)); int rank = input.dim(); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); for (int dim = -rank; dim < rank; ++dim) { EXPECT_EQ(input.size(dim), dev_input->size(dim)); } @@ -151,10 +152,10 @@ TEST_F(TensorTest, TestRrelu) { at::Tensor noise = at::zeros_like(input); at::Tensor output = at::rrelu_with_noise(input, noise, lower, upper, training); - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); - XLATensorPtr dev_noise = - GetValueOrThrow(XLATensor::Create(noise, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_noise, + XLATensor::Create(noise, device)); XLATensorPtr dev_outputs = tensor_methods::rrelu_with_noise( dev_input, dev_noise, lower, upper, training); AllClose(output, dev_outputs); @@ -169,7 +170,8 @@ TEST_F(TensorTest, TestThreshold) { float value = 20; at::Tensor output = at::threshold(input, threshold, value); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::threshold(dev_input, threshold, value); AllClose(output, dev_output); @@ -187,10 +189,11 @@ TEST_F(TensorTest, TestAddMatMul) { at::Tensor bias = at::rand({labels}, at::TensorOptions(at::kFloat)); at::Tensor output = at::addmm(bias, input, weight); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); - XLATensorPtr dev_weight = - GetValueOrThrow(XLATensor::Create(weight, device)); - XLATensorPtr dev_bias = GetValueOrThrow(XLATensor::Create(bias, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight, + XLATensor::Create(weight, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias, XLATensor::Create(bias, device)); XLATensorPtr dev_output = tensor_methods::addmm(dev_input, dev_weight, dev_bias); AllClose(output, dev_output); @@ -201,7 +204,8 @@ TEST_F(TensorTest, TestTranspose) { at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat)); at::Tensor output = at::transpose(input, 0, 1); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::transpose(dev_input, 0, 1); AllClose(output, dev_output); }); @@ -211,7 +215,8 @@ TEST_F(TensorTest, TestView) { at::Tensor input = at::rand({32, 20, 4, 4}, at::TensorOptions(at::kFloat)); at::Tensor output = input.view({-1, 320}); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::view(dev_input, {-1, 320}); AllClose(output, dev_output); }); @@ -292,8 +297,8 @@ TEST_F(TensorTest, TestMaxPool2D) { /*padding=*/{padding, padding}, /*dilation=*/{1, 1}, /*ceil_mode=*/false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); auto dev_output = tensor_methods::max_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -317,8 +322,8 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) { /*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1}, /*ceil_mode=*/false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); auto dev_output = tensor_methods::max_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -346,8 +351,8 @@ TEST_F(TensorTest, TestAvgPool2D) { /*ceil_mode=*/false, count_include_pad, /*divisor_override=*/std::nullopt); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::avg_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -377,8 +382,8 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) { /*count_include_pad=*/count_include_pad, /*divisor_override=*/std::nullopt); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); XLATensorPtr dev_output = tensor_methods::avg_pool_nd( dev_input, /*spatial_dim_count=*/2, @@ -416,20 +421,20 @@ TEST_F(TensorTest, TestBatchNorm1D) { /*running_mean=*/running_mean, /*running_var=*/running_var, /*training=*/training, /*momentum=*/momentum, /*eps=*/eps); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr xla_input = - GetValueOrThrow(XLATensor::Create(input, device)); - XLATensorPtr xla_weight = - undef_weight_bias - ? XLATensorPtr() - : GetValueOrThrow(XLATensor::Create(weight, device)); - XLATensorPtr xla_bias = - undef_weight_bias - ? XLATensorPtr() - : GetValueOrThrow(XLATensor::Create(bias, device)); - XLATensorPtr xla_running_mean = - GetValueOrThrow(XLATensor::Create(running_mean, device)); - XLATensorPtr xla_running_var = - GetValueOrThrow(XLATensor::Create(running_var, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, + XLATensor::Create(input, device)); + XLATensorPtr xla_weight; + if (!undef_weight_bias) { + XLA_ASSIGN_OR_THROW(xla_weight, XLATensor::Create(weight, device)); + } + XLATensorPtr xla_bias; + if (!undef_weight_bias) { + XLA_ASSIGN_OR_THROW(xla_bias, XLATensor::Create(bias, device)); + } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_running_mean, + XLATensor::Create(running_mean, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_running_var, + XLATensor::Create(running_var, device)); auto xla_output = tensor_methods::native_batch_norm( /*input=*/xla_input, /*weight=*/xla_weight, /*bias=*/xla_bias, /*running_mean=*/xla_running_mean, /*running_var=*/xla_running_var, @@ -486,14 +491,14 @@ TEST_F(TensorTest, TestConv2D) { /*output_padding=*/{output_padding, output_padding}, /*groups=*/groups, false, false, false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); - XLATensorPtr dev_weight = - GetValueOrThrow(XLATensor::Create(weight, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight, + XLATensor::Create(weight, device)); XLATensorPtr dev_output; if (with_bias) { - XLATensorPtr dev_bias = - GetValueOrThrow(XLATensor::Create(bias, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias, + XLATensor::Create(bias, device)); dev_output = tensor_methods::convolution_overrideable( dev_input, dev_weight, dev_bias, /*stride=*/{stride, stride}, @@ -558,14 +563,14 @@ TEST_F(TensorTest, TestConv2DNonSquare) { /*groups=*/groups, false, false, false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); - XLATensorPtr dev_weight = - GetValueOrThrow(XLATensor::Create(weight, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight, + XLATensor::Create(weight, device)); XLATensorPtr dev_output; if (with_bias) { - XLATensorPtr dev_bias = - GetValueOrThrow(XLATensor::Create(bias, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias, + XLATensor::Create(bias, device)); dev_output = tensor_methods::convolution_overrideable( dev_input, dev_weight, dev_bias, /*stride=*/{stride, stride + 1}, @@ -634,14 +639,14 @@ TEST_F(TensorTest, TestConv3D) { {output_padding, output_padding, output_padding}, /*groups=*/groups, false, false, false); ForEachDevice([&](const torch::lazy::BackendDevice& device) { - XLATensorPtr dev_input = - GetValueOrThrow(XLATensor::Create(input, device)); - XLATensorPtr dev_weight = - GetValueOrThrow(XLATensor::Create(weight, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, + XLATensor::Create(input, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight, + XLATensor::Create(weight, device)); XLATensorPtr dev_output; if (with_bias) { - XLATensorPtr dev_bias = - GetValueOrThrow(XLATensor::Create(bias, device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias, + XLATensor::Create(bias, device)); dev_output = tensor_methods::convolution_overrideable( dev_input, dev_weight, dev_bias, /*stride=*/{stride, stride, stride}, @@ -709,15 +714,14 @@ TEST_F(TensorTest, TestConv3D) { // {output_padding, output_padding + 1, output_padding}, // /*groups=*/groups, false, false, false); // ForEachDevice([&](const torch::lazy::BackendDevice& device) { -// XLATensorPtr dev_input = -// GetValueOrThrow(XLATensor::Create(input, device)); -// XLATensorPtr dev_weight = -// GetValueOrThrow(XLATensor::Create(weight, device); -// XLATensorPtr dev_output; -// if (with_bias) { -// XLATensorPtr dev_bias = -// GetValueOrThrow(XLATensor::Create(bias, device)); -// dev_output = tensor_methods::convolution_overrideable( +// XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input, +// XLATensor::Create(input, device)); +// XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight, +// XLATensor::Create(weight, device)); XLATensorPtr +// dev_output; if (with_bias) { +// XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias, +// XLATensor::Create(bias, device)); dev_output = +// tensor_methods::convolution_overrideable( // dev_input, dev_weight, dev_bias, // /*stride=*/{stride, stride + 1, stride + 1}, // /*padding=*/{padding, padding + 1, padding + 1}, diff --git a/test/cpp/test_xla_backend_intf.cpp b/test/cpp/test_xla_backend_intf.cpp index c0d12583976a..25a39e57abb8 100644 --- a/test/cpp/test_xla_backend_intf.cpp +++ b/test/cpp/test_xla_backend_intf.cpp @@ -53,7 +53,8 @@ xla::XlaComputation CreateAddComputation(const xla::Shape& shape) { xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x"); xla::XlaOp y = xla::Parameter(&builder, 1, shape, "y"); xla::XlaOp sum = xla::Add(x, y); - return GetValueOrThrow(builder.Build()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation add_computation, builder.Build()); + return add_computation; } TEST(XLABackendTest, TestE2E) { diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index b179c6e523cc..f3b7541f6273 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -28,8 +28,8 @@ namespace { bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a, torch::lazy::BackendDataPtr b, at::ScalarType element_type) { - std::vector tensors = - GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type})); + XLA_ASSIGN_OR_THROW(std::vector tensors, + XlaDataToTensors({a, b}, {element_type, element_type})); return TensorCompare(tensors[0], tensors[1]); } } // namespace @@ -385,8 +385,8 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { auto x = xla::Parameter(&b, 0, shape, "p0"); b.ClearSharding(); auto y = xla::Add(x, xla::ConstantR0(&b, 3)); - xla::XlaComputation xla_computation = - GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false)); + XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation, + b.Build(/*remove_dynamic_dimensions=*/false)); std::vector instances; instances.push_back({std::move(xla_computation), bridge::GetDefaultDevice()->toString(), From d9a9e44e75e994233bff66e9826574a65973e088 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 12:35:36 -0300 Subject: [PATCH 071/133] `runtime`: Use new macros for throwing exceptions. (#9591) Follow-up: #9588 and #9580 Target: `torch_xla/csrc/runtime` directory In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `torch_xla/csrc/runtime` directory, replacing every use of those, now deprecated, functions by the newly introduced macros. --- torch_xla/csrc/runtime/computation_client.h | 7 ++-- .../csrc/runtime/ifrt_computation_client.cpp | 25 +++++++----- .../runtime/ifrt_computation_client_test.cpp | 11 +++-- .../csrc/runtime/pjrt_computation_client.cpp | 40 +++++++++++-------- .../runtime/pjrt_computation_client_test.cpp | 15 +++---- torch_xla/csrc/runtime/runtime.cpp | 3 +- torch_xla/csrc/runtime/tensor_source.h | 2 +- torch_xla/csrc/runtime/xla_util_test.cpp | 9 +++-- 8 files changed, 65 insertions(+), 47 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 05478dc6cb42..6d05137a89a3 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -116,9 +116,9 @@ class ComputationClient { : name_(name), computation_(std::move(computation)), devices_(std::move(devices)) { - program_shape_ = GetValueOrThrow(computation_.GetProgramShape()); + XLA_ASSIGN_OR_THROW(program_shape_, computation_.GetProgramShape()); const xla::HloModuleProto& proto = computation_.proto(); - hash_ = GetValueOrThrow(ComputeHash(proto, name)); + XLA_ASSIGN_OR_THROW(hash_, ComputeHash(proto, name)); } Computation(std::string name, xla::XlaComputation computation, @@ -191,7 +191,8 @@ class ComputationClient { const std::string to_string() const override { xla::HloModuleConfig hlo_config(program_shape()); - std::unique_ptr module = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::unique_ptr module, xla::HloModule::CreateFromProto(computation().proto(), hlo_config)); return module->ToString(); } diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index 0325d1440fc8..db9ec8dab512 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -177,7 +177,8 @@ void IfrtComputationClient::InitializeCoordinator(int global_rank, std::string port) { XLA_CHECK(coordinator_ == nullptr) << "Can only initialize the XlaCoordinator once."; - coordinator_ = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + coordinator_, XlaCoordinator::Create(global_rank, world_size, master_addr, port)); } @@ -395,10 +396,10 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( auto instruction = XlaBuilderFriend::GetInstruction(y); *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto(); - xla::XlaComputation computation = - GetValueOrThrow(builder.Build(/*remove_dynamic_dimensions=*/false)); - xla::ProgramShape program_shape = - GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, + builder.Build(/*remove_dynamic_dimensions=*/false)); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); std::string device = GetDefaultDevice(); std::vector instances; @@ -417,8 +418,9 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = GetValueOrThrow(ExecuteReplicated( - *computations.front(), {{handle}}, GetLocalDevices(), execute_options)); + XLA_ASSIGN_OR_THROW(std::vector sharded_results, + ExecuteReplicated(*computations.front(), {{handle}}, + GetLocalDevices(), execute_options)); auto replicated_output = std::dynamic_pointer_cast(sharded_results[0]) ->buffer->FullyReplicatedShard( @@ -516,14 +518,17 @@ std::vector IfrtComputationClient::Compile( mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); torch_xla::ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module); - std::shared_ptr executable = - GetValueOrThrow(client_->GetDefaultCompiler()->CompileAndLoad( + XLA_ASSIGN_OR_THROW( + std::shared_ptr executable, + client_->GetDefaultCompiler()->CompileAndLoad( std::make_unique(mlir_module), std::make_unique(compile_options, devices_list))); StableHloCompileCounter()->AddValue(1); - const auto& hlo_modules = GetValueOrThrow(executable->GetHloModules()); + XLA_ASSIGN_OR_THROW( + const std::vector>& hlo_modules, + executable->GetHloModules()); std::shared_ptr ifrt_computation = std::make_shared( diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp index d48b4337d21c..4efc9550b8de 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp @@ -36,7 +36,8 @@ absl::StatusOr MakeComputation() { TEST(PjRtComputationClientTest, Init) { // Get a CPU client. tsl::setenv("PJRT_DEVICE", "CPU", true); - auto client = GetValueOrThrow(IfrtComputationClient::Create()); + XLA_ASSIGN_OR_THROW(std::unique_ptr client, + IfrtComputationClient::Create()); std::string device = client->GetDefaultDevice(); // Compose a computation. @@ -64,14 +65,16 @@ TEST(PjRtComputationClientTest, Init) { std::make_shared(std::move(literal_y), device)}; // Execute the graph. - std::vector results = - GetValueOrThrow(client->ExecuteReplicated( + XLA_ASSIGN_OR_THROW( + std::vector results, + client->ExecuteReplicated( *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), {device}, options)); // Copy the output from device back to host and assert correctness.. ASSERT_EQ(results.size(), 1); - auto result_literals = GetValueOrThrow(client->TransferFromDevice(results)); + XLA_ASSIGN_OR_THROW(std::vector result_literals, + client->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 98ce8520da32..7e2833fc8f16 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -168,7 +168,8 @@ void PjRtComputationClient::InitializeCoordinator(int global_rank, std::string port) { XLA_CHECK(coordinator_ == nullptr) << "Can only initialize the XlaCoordinator once."; - coordinator_ = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + coordinator_, XlaCoordinator::Create(global_rank, world_size, master_addr, port)); } @@ -367,10 +368,10 @@ PjRtComputationClient::ReplicateShardedData( auto instruction = XlaBuilderFriend::GetInstruction(y); *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto(); - xla::XlaComputation computation = - GetValueOrThrow(builder.Build(/*remove_dynamic_dimensions=*/false)); - xla::ProgramShape program_shape = - GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, + builder.Build(/*remove_dynamic_dimensions=*/false)); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); std::string device = GetDefaultDevice(); std::vector @@ -386,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = - GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data}, + XLA_ASSIGN_OR_THROW(std::vector sharded_results, + ExecuteReplicated(*computations.front(), {sharded_data}, GetLocalDevices(), execute_options)); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; @@ -433,8 +434,9 @@ std::vector PjRtComputationClient::ReshardData( XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN) << "Resharding by UNKNOWN sharding type is not allowed."; - hlo_shardings.push_back( - GetValueOrThrow(xla::HloSharding::FromProto(sharding))); + XLA_ASSIGN_OR_THROW(xla::HloSharding hlo_sharding, + xla::HloSharding::FromProto(sharding)); + hlo_shardings.push_back(std::move(hlo_sharding)); xla::OpSharding fallback_sharding; fallback_sharding.set_type(xla::OpSharding::REPLICATED); @@ -457,9 +459,9 @@ std::vector PjRtComputationClient::ReshardData( root = xla::Tuple(&builder, param_ops); } - xla::XlaComputation xla_computation = GetValueOrThrow(builder.Build(root)); - xla::ProgramShape program_shape = - GetValueOrThrow(xla_computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation, builder.Build(root)); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + xla_computation.GetProgramShape()); std::string device = GetDefaultDevice(); std::vector instances; @@ -474,8 +476,9 @@ std::vector PjRtComputationClient::ReshardData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto resharded_results = GetValueOrThrow(ExecuteReplicated( - *computation, handles, GetLocalDevices(), execute_options)); + XLA_ASSIGN_OR_THROW(std::vector resharded_results, + ExecuteReplicated(*computation, handles, + GetLocalDevices(), execute_options)); return resharded_results; } @@ -660,7 +663,9 @@ std::vector PjRtComputationClient::Compile( TF_VLOG(3) << "memory usage is not availiable"; } - const auto& hlo_modules = GetValueOrThrow(executable->GetHloModules()); + XLA_ASSIGN_OR_THROW( + const std::vector>& hlo_modules, + executable->GetHloModules()); xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation(); std::shared_ptr pjrt_computation = std::make_shared( @@ -679,8 +684,9 @@ std::string PjRtComputationClient::SerializeComputation( const ComputationPtr computation) { const PjRtComputation& pjrt_computation = dynamic_cast(*computation); - - return GetValueOrThrow(pjrt_computation.executable->SerializeExecutable()); + XLA_ASSIGN_OR_THROW(std::string serialized_executable, + pjrt_computation.executable->SerializeExecutable()); + return serialized_executable; } ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation( diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp index 64496312ae4d..c19499e515a7 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp @@ -25,7 +25,7 @@ class PjRtComputationClientTest : public ::testing::Test { PjRtComputationClientTest() { // Get a CPU client. tsl::setenv("PJRT_DEVICE", "CPU", true); - client_ = GetValueOrThrow(PjRtComputationClient::Create()); + XLA_ASSIGN_OR_THROW(client_, PjRtComputationClient::Create()); device_ = client_->GetDefaultDevice(); } @@ -114,15 +114,16 @@ TEST_F(PjRtComputationClientTest, Init) { std::make_shared(std::move(literal_y), device_)}; // Execute the graph. - std::vector results = - GetValueOrThrow(client_->ExecuteComputation( - *computations[0], - client_->TransferToDevice(absl::MakeConstSpan(args)), device_, - options)); + XLA_ASSIGN_OR_THROW(std::vector results, + client_->ExecuteComputation( + *computations[0], + client_->TransferToDevice(absl::MakeConstSpan(args)), + device_, options)); // Copy the output from device back to host and assert correctness. ASSERT_EQ(results.size(), 1); - auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results)); + XLA_ASSIGN_OR_THROW(std::vector result_literals, + client_->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/runtime/runtime.cpp b/torch_xla/csrc/runtime/runtime.cpp index 69bb53fe1df4..3836f6975719 100644 --- a/torch_xla/csrc/runtime/runtime.cpp +++ b/torch_xla/csrc/runtime/runtime.cpp @@ -61,7 +61,8 @@ const absl::StatusOr& GetComputationClient() { } ComputationClient* absl_nonnull GetComputationClientOrDie() { - return GetValueOrThrow(GetComputationClient()); + XLA_ASSIGN_OR_THROW(ComputationClient * client, GetComputationClient()); + return client; } ComputationClient* GetComputationClientIfInitialized() { diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index cc8e646eee75..888396ae6783 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -31,7 +31,7 @@ class TensorSource { virtual std::vector byte_strides() const { std::vector byte_strides(shape().dimensions_size()); - OkOrThrow( + XLA_THROW_IF_ERROR( xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); return byte_strides; } diff --git a/torch_xla/csrc/runtime/xla_util_test.cpp b/torch_xla/csrc/runtime/xla_util_test.cpp index 8f2662994f14..9178fc133780 100644 --- a/torch_xla/csrc/runtime/xla_util_test.cpp +++ b/torch_xla/csrc/runtime/xla_util_test.cpp @@ -121,10 +121,10 @@ TEST(XlaUtilTest, XlaToHlo) { TEST(XlaUtilTest, TestDeterministicModuleProtoSerializationEmptyProto) { xla::HloModuleProto empty_proto; - auto result = - GetValueOrThrow(GetDeterministicSerializedModuleProto(empty_proto)); + XLA_ASSIGN_OR_THROW(std::string serialized_result, + GetDeterministicSerializedModuleProto(empty_proto)); // Verify that the result is an empty string - EXPECT_TRUE(result.empty()); + EXPECT_TRUE(serialized_result.empty()); } TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) { @@ -250,7 +250,8 @@ TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) { } } } - std::string serialized_proto = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::string serialized_proto, GetDeterministicSerializedModuleProto(hlo_module_proto)); return torch::lazy::Hash(serialized_proto); }; From 8d20a86a188fe66970ff988eb0fdb11ba3239eda Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 12:36:03 -0300 Subject: [PATCH 072/133] `ops`: Use new macros for throwing exceptions. (#9592) Follow-up: #9588 and #9580 Target: - `torch_xla/csrc/ops` directory - Files related to the tracing of tensor operations In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `torch_xla/csrc/ops` directory and other files related to the tracing of tensor operations, replacing every use of those, now deprecated, functions by the newly introduced macros. --- torch_xla/csrc/convolution.cpp | 17 +++++++++------- torch_xla/csrc/cross_replica_reduces.cpp | 10 ++++----- torch_xla/csrc/data_ops.cpp | 4 ++-- torch_xla/csrc/ops/index_ops.cpp | 5 ++++- torch_xla/csrc/ops/triangular_solve.cpp | 3 ++- torch_xla/csrc/pooling.cpp | 7 +++++-- torch_xla/csrc/reduction.cpp | 6 ++++-- torch_xla/csrc/shape_helper.cpp | 3 ++- torch_xla/csrc/tensor_methods.cpp | 16 ++++++++------- torch_xla/csrc/tensor_ops.cpp | 26 ++++++++++++++++-------- torch_xla/csrc/tensor_util.cpp | 2 +- 11 files changed, 61 insertions(+), 38 deletions(-) diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp index 8dd4d43efe8f..6312045ca9e5 100644 --- a/torch_xla/csrc/convolution.cpp +++ b/torch_xla/csrc/convolution.cpp @@ -218,9 +218,11 @@ xla::XlaOp BuildConvBackwardInput(xla::XlaOp grad_output, xla::XlaOp kernel, MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false); xla::XlaOp kernel_transposed = xla::Transpose( kernel, FilterTransposePermutation(input_shape.dimensions_size())); - return GetValueOrThrow(MakeXlaBackpropInputConvOp( - "conv_backward_input", input_shape, kernel_transposed, grad_output, - conv_op_attrs)); + XLA_ASSIGN_OR_THROW(xla::XlaOp conv_backward_input, + MakeXlaBackpropInputConvOp("conv_backward_input", + input_shape, kernel_transposed, + grad_output, conv_op_attrs)); + return conv_backward_input; } // Computes the kernel gradient for a convolution. @@ -238,14 +240,15 @@ xla::XlaOp BuildConvBackwardWeight(xla::XlaOp grad_output, xla::XlaOp input, xla::InversePermutation(transpose_permutation); xla::Shape transposed_weight_shape = xla::ShapeUtil::PermuteDimensions(transpose_permutation, kernel_shape); - xla::XlaOp conv = GetValueOrThrow(MakeXlaBackpropFilterConvOp( - "conv_backward_weight", input, transposed_weight_shape, grad_output, - conv_op_attrs)); + XLA_ASSIGN_OR_THROW(xla::XlaOp conv_backward_weight, + MakeXlaBackpropFilterConvOp("conv_backward_weight", input, + transposed_weight_shape, + grad_output, conv_op_attrs)); // Reorder the dimensions of the filter gradient to match the NCHW convention // of PyTorch. The original result of the convolution has the spatial and // feature dimensions swapped and the spatial dimensions reversed. - return xla::Transpose(conv, inv_transpose_permutation); + return xla::Transpose(conv_backward_weight, inv_transpose_permutation); } xla::XlaOp BuildGradBias(xla::XlaOp grad_output) { diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index aaa982e6ebd3..77519c03cfc2 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -116,7 +116,7 @@ std::shared_ptr CreateToken( at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp, std::string /*group_name*/) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr self_tensor, bridge::GetXlaTensor(self)); // TODO(alanwaketan): Use group_name to generate groups. Currently we just // use {} as a workaround. Scale is always 1.0 here, and we always pin // layout. @@ -270,7 +270,7 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim, at::Tensor all_gather_into_tensor(const at::Tensor& self, int64_t group_size, std::string group_name) { TORCH_LAZY_FN_COUNTER("xla::"); - auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr self_tensor, bridge::GetXlaTensor(self)); std::vector all_groups(group_size); std::iota(all_groups.begin(), all_groups.end(), 0); auto result = tensor_methods::all_gather(self_tensor, 0, group_size, @@ -349,9 +349,9 @@ at::Tensor all_to_all_single(const at::Tensor& input, } XLATensorPtr result_ptr; torch::lazy::Value new_token; + XLA_ASSIGN_OR_THROW(XLATensorPtr input_tensor, bridge::GetXlaTensor(input)); std::tie(result_ptr, new_token) = tensor_methods::all_to_all( - GetValueOrThrow(bridge::GetXlaTensor(input)), token, 0, 0, split_count, - {all_groups}, pin_layout); + input_tensor, token, 0, 0, split_count, {all_groups}, pin_layout); at::Tensor result = bridge::AtenFromXlaTensor(std::move(result_ptr)); at::Tensor result_with_grad = torch::autograd::make_variable( @@ -481,7 +481,7 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input, at::Tensor reduce_scatter_tensor(const at::Tensor& input, std::string reduce_op, int64_t group_size, std::string group_name) { TORCH_LAZY_FN_COUNTER("xla::"); - auto self = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr self, bridge::GetXlaTensor(input)); std::vector all_groups(group_size); std::iota(all_groups.begin(), all_groups.end(), 0); int64_t shard_count = group_size; diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index 3953ef7ca11a..166f69f7659d 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -197,8 +197,8 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask); if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) { - xla::Shape shape = - GetValueOrThrow(XlaHelpers::GetPromotedShape(input_shape, mask_shape)); + XLA_ASSIGN_OR_THROW(xla::Shape shape, + XlaHelpers::GetPromotedShape(input_shape, mask_shape)); input = BuildExpand(input, shape.dimensions()); mask = BuildExpand(mask, shape.dimensions()); } diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 25b0b4f4b852..9dddc9424e03 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -18,6 +18,7 @@ #include "torch_xla/csrc/ops/scalar.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_methods.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_graph_executor.h" @@ -315,8 +316,10 @@ XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base, base_dimensions.begin() + start_dim + indices.size(), base_dimensions.end()); - return GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + XLATensorPtr output, tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype())); + return output; } XLATensorPtr IndexByTensors(const XLATensorPtr& base, diff --git a/torch_xla/csrc/ops/triangular_solve.cpp b/torch_xla/csrc/ops/triangular_solve.cpp index c2dc3414de14..4bf1d77940c4 100644 --- a/torch_xla/csrc/ops/triangular_solve.cpp +++ b/torch_xla/csrc/ops/triangular_solve.cpp @@ -33,7 +33,8 @@ std::pair InferTriangularSolveShape( return std::pair(rhs_batch_shape, lhs_batch_shape); } // Obtain the promoted shapes and add back the trailing dimension. - xla::Shape rhs_batch_promoted_shape = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + xla::Shape rhs_batch_promoted_shape, XlaHelpers::GetPromotedShape(rhs_batch_shape, lhs_batch_shape)); xla::Shape lhs_batch_promoted_shape(rhs_batch_promoted_shape); rhs_batch_promoted_shape.add_dimensions(nrhs); diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp index 78056775f828..934de6db67db 100644 --- a/torch_xla/csrc/pooling.cpp +++ b/torch_xla/csrc/pooling.cpp @@ -49,7 +49,9 @@ xla::XlaComputation CreateGeComputation(xla::PrimitiveType type) { xla::XlaOp y = xla::Parameter(&reduction_builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); xla::Ge(x, y); - return GetValueOrThrow(reduction_builder.Build()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation ge_computation, + reduction_builder.Build()); + return ge_computation; } xla::TensorFormat MakeNCHWFormat(int64_t spatial_dim_count) { @@ -367,7 +369,8 @@ xla::XlaOp ComputeMaxPoolIndices( return results; }; - std::vector results = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::vector results, xla::WhileLoopHelper(cond_fn, body_fn, initial_values.values, "ComputeMaxPoolIndices", padded_input.builder())); diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index e0e9cbc07416..007032e794ba 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -60,7 +60,8 @@ xla::XlaComputation CreateAllComputation(xla::PrimitiveType type) { xla::XlaOp zero = xla::Zero(&builder, type); xla::XlaOp one = xla::One(&builder, type); xla::Select(xla::And(xla::Ne(x, zero), xla::Ne(y, zero)), one, zero); - return GetValueOrThrow(builder.Build()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation all_computation, builder.Build()); + return all_computation; } xla::XlaComputation CreateAnyComputation(xla::PrimitiveType type) { @@ -72,7 +73,8 @@ xla::XlaComputation CreateAnyComputation(xla::PrimitiveType type) { xla::XlaOp zero = xla::Zero(&builder, type); xla::XlaOp one = xla::One(&builder, type); xla::Select(xla::Or(xla::Ne(x, zero), xla::Ne(y, zero)), one, zero); - return GetValueOrThrow(builder.Build()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation any_computation, builder.Build()); + return any_computation; } xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, diff --git a/torch_xla/csrc/shape_helper.cpp b/torch_xla/csrc/shape_helper.cpp index 724ba4007154..ac393796e6ee 100644 --- a/torch_xla/csrc/shape_helper.cpp +++ b/torch_xla/csrc/shape_helper.cpp @@ -6,7 +6,8 @@ namespace torch_xla { const xla::Shape& ShapeHelper::ShapeOfXlaOp(xla::XlaOp op) { - return *GetValueOrThrow(GetShape(op)); + XLA_ASSIGN_OR_THROW(const xla::Shape* shape, GetShape(op)); + return *shape; } absl::StatusOr GetShape(xla::XlaOp op) { diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 2786ca1718bd..bfd67b59de24 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1479,9 +1479,11 @@ std::tuple cummax(const XLATensorPtr& input, at::Tensor val = at::empty(shape_, at::TensorOptions().dtype(input->dtype())); at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong)); - return std::make_tuple( - GetValueOrThrow(XLATensor::Create(val, input->GetDevice())), - GetValueOrThrow(XLATensor::Create(idx, input->GetDevice()))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_val, + XLATensor::Create(val, input->GetDevice())); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_idx, + XLATensor::Create(idx, input->GetDevice())); + return std::make_tuple(xla_val, xla_idx); } torch::lazy::NodePtr node = torch_xla::MakeNode(input->GetIrValue(), canonical_dim); @@ -2533,10 +2535,10 @@ std::tuple native_batch_norm( } } else { at::Tensor at_input = bridge::AtenFromXlaTensor(input); - mean = GetValueOrThrow( - bridge::GetXlaTensor(at::empty({0}, at_input.options()))); - variance_inverse = GetValueOrThrow( - bridge::GetXlaTensor(at::empty({0}, at_input.options()))); + XLA_ASSIGN_OR_THROW( + mean, bridge::GetXlaTensor(at::empty({0}, at_input.options()))); + XLA_ASSIGN_OR_THROW(variance_inverse, bridge::GetXlaTensor(at::empty( + {0}, at_input.options()))); } XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get(); diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index edb7d22297c4..84d788d624fb 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -8,6 +8,7 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_methods.h" namespace torch_xla { @@ -148,7 +149,8 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output, XLATensorPtr grad_scale = tensor_methods::get_dimensions_size( broadcasted_input, XlaHelpers::GetAllDimensions(broadcasted_input->shape())); - XLATensorPtr div_result = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + XLATensorPtr div_result, tensor_methods::div(elementwise_loss_backward, grad_scale)); return tensor_methods::mul(div_result, grad_output); } @@ -174,7 +176,8 @@ XLATensorPtr SoftplusBackward(const XLATensorPtr& grad_output, XLATensorPtr z = tensor_methods::exp(scaled_input); XLATensorPtr one_vec = tensor_methods::full_like(z, 1, z->GetDevice(), z->dtype()); - XLATensorPtr div = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + XLATensorPtr div, tensor_methods::div(z, tensor_methods::add(z, one_vec, 1))); return tensor_methods::where(tensor_methods::gt(scaled_input, threshold), @@ -207,24 +210,29 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output, int64_t numel = xla::ShapeUtil::ElementsIn(indices_shape_ref.get()); XLATensorPtr grad = tensor_methods::view(grad_output, {numel, grad_output->size(-1)}); - XLATensorPtr grad_weight = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + XLATensorPtr grad_weight, tensor_methods::full({num_weights, grad_output->size(-1)}, 0, grad_output->GetDevice(), grad_output->dtype())); XLATensorPtr indices_rank1 = tensor_methods::view(indices, {numel}); if (scale_grad_by_freq) { // Compute the histogram of index values. - XLATensorPtr counts = GetValueOrThrow(tensor_methods::full( - {num_weights}, 0, indices->GetDevice(), indices->dtype())); - XLATensorPtr ones = GetValueOrThrow(tensor_methods::full( - {numel}, 1, indices->GetDevice(), indices->dtype())); + XLA_ASSIGN_OR_THROW( + XLATensorPtr counts, + tensor_methods::full({num_weights}, 0, indices->GetDevice(), + indices->dtype())); + XLA_ASSIGN_OR_THROW(XLATensorPtr ones, + tensor_methods::full({numel}, 1, indices->GetDevice(), + indices->dtype())); tensor_methods::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0, /*values=*/ones, /*accumulate=*/true, /*result_permutation=*/{0}); XLATensorPtr grad_weights_scale = tensor_methods::index(counts, {indices_rank1}, 0); // Scale the value of the gradient by the histogram. - grad = GetValueOrThrow(tensor_methods::div( - grad, tensor_methods::unsqueeze(grad_weights_scale, 1))); + XLA_ASSIGN_OR_THROW( + grad, tensor_methods::div( + grad, tensor_methods::unsqueeze(grad_weights_scale, 1))); } // Don't accumulate gradients for indices which are equal with the given // padding_idx. diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index bf5f7966f8f0..688097e188f1 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -1056,7 +1056,7 @@ xla::PrimitiveType GetShapeDimensionType( std::shared_ptr get_data_handle( const at::Tensor& input) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); if (xtensor->CurrentDataHandle() != nullptr) { TF_VLOG(4) << "The xla tensor has a current data handle."; return std::dynamic_pointer_cast( From d55cc00177aebd65e3f36c875983aa4d7ea8d18a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 12:38:21 -0300 Subject: [PATCH 073/133] `init_python_bindings.cpp`: Use new macros for throwing exceptions. (#9595) Follow-up: #9588 and #9580 Target: `init_python_bindings.cpp` file In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `init_python_bindings.cpp` file, replacing every use of those, now deprecated, functions by the newly introduced macros. --- torch_xla/csrc/init_python_bindings.cpp | 336 +++++++++++++----------- 1 file changed, 183 insertions(+), 153 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index d605c07406ba..ffff87ee0a14 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -190,7 +190,17 @@ class PythonScope : public Scope { torch::PyWarningHandler handler; if constexpr (returns_status_type) { - return GetValueOrThrow(f(args...)); + if constexpr (std::is_void::value) { + // If the bound function returns `absl::Status`, check for errors + // and return void. + XLA_THROW_IF_ERROR(f(args...)); + return; + } else { + // If the bound function returns `absl::StatusOr`, check for + // errors and return `T`. + XLA_ASSIGN_OR_THROW(auto result, f(args...)); + return result; + } } else { return f(args...); } @@ -269,7 +279,8 @@ std::string GetTensorsDump( const std::vector& tensors, const std::function< std::string(absl::Span)>& coverter) { - auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(tensors)); std::vector nodes; std::transform( xtensors.begin(), xtensors.end(), std::back_inserter(nodes), @@ -318,7 +329,7 @@ static std::vector CollectXlaTensors( } bool IsNonDeviceDataIR(const at::Tensor& tensor) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return xtensor->CurrentIrValue() && !DeviceData::Cast(xtensor->CurrentIrValue().node.get()); } @@ -343,15 +354,14 @@ std::vector XlaCustomCall( for (auto& dtype : output_dtypes) { dtypes.push_back(reinterpret_cast(dtype.ptr())->scalar_type); } - + XLA_ASSIGN_OR_THROW(std::vector xla_inputs, + bridge::GetXlaTensors(inputs)); if (is_tpu) { return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call( - GetValueOrThrow(bridge::GetXlaTensors(inputs)), payload, output_shapes, - dtypes)); + xla_inputs, payload, output_shapes, dtypes)); } return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call( - GetValueOrThrow(bridge::GetXlaTensors(inputs)), payload, output_shapes, - dtypes)); + xla_inputs, payload, output_shapes, dtypes)); } std::vector> ExtractXlaDotGeneralDimVectors( @@ -423,13 +433,13 @@ void AllReduceInPlace(const std::string& reduce_type, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(tensors)); tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale, replica_groups, pin_layout); - std::vector new_xtensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); - OkOrThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors)); + XLA_ASSIGN_OR_THROW(std::vector new_xtensors, + bridge::GetXlaTensors(tensors)); + XLA_THROW_IF_ERROR(bridge::ReplaceXlaTensor(tensors, new_xtensors)); } at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, @@ -437,9 +447,9 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); auto result = tensor_methods::all_reduce( - GetValueOrThrow(bridge::GetXlaTensor(input)), GetReduceType(reduce_type), - scale, replica_groups, pin_layout); + xla_input, GetReduceType(reduce_type), scale, replica_groups, pin_layout); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -447,9 +457,11 @@ at::Tensor DynamicExpand(const at::Tensor& input, const std::vector& size, const at::Tensor& src_tensor, int src_dim, int target_dim) { + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_src_tensor, + bridge::GetXlaTensor(src_tensor)); XLATensorPtr result = tensor_methods::dynamic_expand( - GetValueOrThrow(bridge::GetXlaTensor(input)), size, - GetValueOrThrow(bridge::GetXlaTensor(src_tensor)), src_dim, target_dim); + xla_input, size, xla_src_tensor, src_dim, target_dim); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -457,17 +469,18 @@ at::Tensor DynamicView(const at::Tensor& input, const std::vector& size, const at::Tensor& src_tensor, int src_dim, int target_dim, float mul_scaler) { + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_src_tensor, + bridge::GetXlaTensor(src_tensor)); XLATensorPtr result = tensor_methods::dynamic_view( - GetValueOrThrow(bridge::GetXlaTensor(input)), size, - GetValueOrThrow(bridge::GetXlaTensor(src_tensor)), src_dim, target_dim, - mul_scaler); + xla_input, size, xla_src_tensor, src_dim, target_dim, mul_scaler); return bridge::AtenFromXlaTensor(std::move(result)); } at::Tensor CastInt4(const at::Tensor& weight, const std::vector& int4_weight_values) { - auto result = tensor_methods::cast_int4( - GetValueOrThrow(bridge::GetXlaTensor(weight)), int4_weight_values); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); + auto result = tensor_methods::cast_int4(xla_weight, int4_weight_values); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -476,9 +489,10 @@ at::Tensor QuantizeTensor(const at::Tensor& input, const std::vector& zero_point_list, int quant_min, int quant_max, const std::string& dtype, int axis) { - auto result = tensor_methods::quantize_tensor( - GetValueOrThrow(bridge::GetXlaTensor(input)), scale_list, zero_point_list, - quant_min, quant_max, dtype, axis); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + auto result = + tensor_methods::quantize_tensor(xla_input, scale_list, zero_point_list, + quant_min, quant_max, dtype, axis); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -487,9 +501,10 @@ at::Tensor DequantizeTensor(const at::Tensor& input, const std::vector& zero_point_list, int quant_min, int quant_max, const std::string& dtype, int axis) { - auto result = tensor_methods::dequantize_tensor( - GetValueOrThrow(bridge::GetXlaTensor(input)), scale_list, zero_point_list, - quant_min, quant_max, dtype, axis); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + auto result = + tensor_methods::dequantize_tensor(xla_input, scale_list, zero_point_list, + quant_min, quant_max, dtype, axis); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -503,10 +518,11 @@ std::pair> ReduceScatter( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr result; torch::lazy::Value new_token; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); std::tie(result, new_token) = tensor_methods::reduce_scatter( - GetValueOrThrow(bridge::GetXlaTensor(input)), *token, - GetReduceType(reduce_type), scale, scatter_dim, shard_count, - replica_groups, pin_layout, channel_id, use_global_device_ids); + xla_input, *token, GetReduceType(reduce_type), scale, scatter_dim, + shard_count, replica_groups, pin_layout, channel_id, + use_global_device_ids); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -518,12 +534,12 @@ std::shared_ptr ReduceScatterOut( int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out = GetValueOrThrow(bridge::GetXlaTensor(output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_output, bridge::GetXlaTensor(output)); torch::lazy::Value new_token; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); new_token = tensor_methods::reduce_scatter_out( - out, GetValueOrThrow(bridge::GetXlaTensor(input)), *token, - GetReduceType(reduce_type), scale, scatter_dim, shard_count, - replica_groups, pin_layout); + xla_output, xla_input, *token, GetReduceType(reduce_type), scale, + scatter_dim, shard_count, replica_groups, pin_layout); return std::make_shared(new_token); } @@ -534,8 +550,8 @@ ReduceScatterCoalesced(const std::string& reduce_type, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(inputs)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(inputs)); std::vector result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced( @@ -554,10 +570,10 @@ std::shared_ptr ReduceScatterCoalescedOut( const std::shared_ptr& token, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - std::vector xtensors_out = - GetValueOrThrow(bridge::GetXlaTensors(outputs)); - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(inputs)); + XLA_ASSIGN_OR_THROW(std::vector xtensors_out, + bridge::GetXlaTensors(outputs)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(inputs)); torch::lazy::Value new_token; new_token = tensor_methods::reduce_scatter_coalesced_out( xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale, @@ -571,9 +587,10 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count, std::optional channel_id = std::nullopt, std::optional use_global_device_ids = std::nullopt) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto result = tensor_methods::all_gather( - GetValueOrThrow(bridge::GetXlaTensor(input)), dim, shard_count, - replica_groups, pin_layout, channel_id, use_global_device_ids); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + auto result = + tensor_methods::all_gather(xla_input, dim, shard_count, replica_groups, + pin_layout, channel_id, use_global_device_ids); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -583,11 +600,12 @@ std::shared_ptr AllGatherOut( int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out = GetValueOrThrow(bridge::GetXlaTensor(output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_output, bridge::GetXlaTensor(output)); torch::lazy::Value new_token; - new_token = tensor_methods::all_gather_out( - out, GetValueOrThrow(bridge::GetXlaTensor(input)), *token, dim, - shard_count, replica_groups, pin_layout); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + new_token = + tensor_methods::all_gather_out(xla_output, xla_input, *token, dim, + shard_count, replica_groups, pin_layout); return std::make_shared(new_token); } @@ -597,8 +615,8 @@ AllGatherCoalesced(const std::vector& tensors, int64_t dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(tensors)); std::vector result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::all_gather_coalesced( @@ -615,10 +633,10 @@ std::shared_ptr AllGatherCoalescedOut( const std::shared_ptr& token, int64_t dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - std::vector xtensors_out = - GetValueOrThrow(bridge::GetXlaTensors(outputs)); - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(inputs)); + XLA_ASSIGN_OR_THROW(std::vector xtensors_out, + bridge::GetXlaTensors(outputs)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(inputs)); torch::lazy::Value new_token; new_token = tensor_methods::all_gather_coalesced_out( xtensors_out, xtensors, *token, dim, shard_count, replica_groups, @@ -633,9 +651,10 @@ std::pair> AllToAll( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLATensorPtr result; torch::lazy::Value new_token; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); std::tie(result, new_token) = tensor_methods::all_to_all( - GetValueOrThrow(bridge::GetXlaTensor(input)), *token, split_dimension, - concat_dimension, split_count, replica_groups, pin_layout); + xla_input, *token, split_dimension, concat_dimension, split_count, + replica_groups, pin_layout); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -646,9 +665,9 @@ std::pair> CollectivePermute( const std::vector>& source_target_pairs) { XLATensorPtr result; torch::lazy::Value new_token; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); std::tie(result, new_token) = tensor_methods::collective_permute( - GetValueOrThrow(bridge::GetXlaTensor(input)), *token, - source_target_pairs); + xla_input, *token, source_target_pairs); return std::pair>( bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)); @@ -664,8 +683,9 @@ std::pair> Send( int64_t channel_id) { XLATensorPtr result; torch::lazy::Value new_token; - std::tie(result, new_token) = tensor_methods::send( - GetValueOrThrow(bridge::GetXlaTensor(input)), *token, channel_id); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + std::tie(result, new_token) = + tensor_methods::send(xla_input, *token, channel_id); return {bridge::AtenFromXlaTensor(std::move(result)), std::make_shared(new_token)}; } @@ -673,7 +693,7 @@ std::pair> Send( std::pair> Recv( at::Tensor& output, const std::shared_ptr& token, int64_t channel_id) { - XLATensorPtr out = GetValueOrThrow(bridge::GetXlaTensor(output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr out, bridge::GetXlaTensor(output)); XLATensorPtr result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::recv(out, *token, channel_id); @@ -831,12 +851,12 @@ void ClearPendingIrs(const std::string& device_str) { } std::ptrdiff_t GetTensorViewAliasId(const at::Tensor& tensor) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return xtensor->GetViewAliasId(); } std::ptrdiff_t GetTensorId(const at::Tensor& tensor) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return xtensor->GetUniqueId(); } @@ -867,7 +887,7 @@ std::vector GetXlaTensorsFromAten( } at::Tensor GetXlaTensorDimensionSize(const at::Tensor& tensor, int64_t dim) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); return bridge::AtenFromXlaTensor( tensor_methods::get_dimensions_size(xtensor, {dim})); } @@ -912,8 +932,8 @@ py::object GetRevisions() { std::vector XlaUserComputation( const std::string& opname, const std::vector& inputs, runtime::ComputationClient::ComputationPtr computation) { - std::vector xinputs = - GetValueOrThrow(bridge::GetXlaTensors(inputs)); + XLA_ASSIGN_OR_THROW(std::vector xinputs, + bridge::GetXlaTensors(inputs)); std::vector xresults = tensor_methods::user_computation(opname, xinputs, std::move(computation)); std::vector results; @@ -927,8 +947,8 @@ std::vector XlaUserComputation( runtime::ComputationClient::ComputationPtr CreateComputation( const std::string& name, xla::XlaOp root) { - xla::XlaComputation computation = - GetValueOrThrow(root.builder()->Build(root)); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, + root.builder()->Build(root)); return std::make_shared( name, std::move(computation)); } @@ -1005,8 +1025,8 @@ void MapXlaEnvVarsToLazy() { } at::Tensor MarkTensor(const at::Tensor& input, const std::string& info) { - XLATensorPtr result = tensor_methods::mark_tensor( - GetValueOrThrow(bridge::GetXlaTensor(input)), info); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLATensorPtr result = tensor_methods::mark_tensor(xla_input, info); return bridge::AtenFromXlaTensor(std::move(result)); } @@ -1169,8 +1189,8 @@ class PyLoweringContext { // Builds a HLO graph given a set of output tensors. void Build(std::vector tensors) { // Get the backing XLA tensors from the output torch tensor handles - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(tensors)); // Get the lazy IR value from the output XLA tensors std::vector ir_values; @@ -1188,7 +1208,7 @@ class PyLoweringContext { ShardingUtil::SetHloSharding(&lowering_ctx); - computation = GetValueOrThrow(lowering_ctx.BuildXla()); + XLA_ASSIGN_OR_THROW(computation, lowering_ctx.BuildXla()); } // Builds a HLO graph given a set of output tensors, and add unused parameters @@ -1196,8 +1216,8 @@ class PyLoweringContext { void BuildForiLoop(std::vector tensors, std::vector additional_inputs_list = {}) { // Get the backing XLA tensors from the output torch tensor handles - std::vector xtensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(tensors)); // Get the lazy IR value from the output XLA tensors std::vector ir_values; @@ -1222,8 +1242,8 @@ class PyLoweringContext { local_builder->GetProgramShape()->parameters_size(); int64_t additional_inputs_list_size = additional_inputs_list.size(); for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(additional_inputs_list[i])); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(additional_inputs_list[i])); xla::Shape shape = xtensor->shape().get(); xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, "UnusedArgumentsPlaceholder"); @@ -1233,7 +1253,7 @@ class PyLoweringContext { ShardingUtil::SetHloSharding(&lowering_ctx); - computation = GetValueOrThrow(lowering_ctx.BuildXla()); + XLA_ASSIGN_OR_THROW(computation, lowering_ctx.BuildXla()); // wrap inputs of cond/body_computation if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) { @@ -1243,15 +1263,17 @@ class PyLoweringContext { if (UseVirtualDevice()) { param_shardings = XlaHelpers::ExtractInputShardings(computation); } - xla::ProgramShape program_shape = - GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); // TODO(@manfei): please confirm whether we check for more than two or use // default value true bool should_wrap_parameter = (program_shape.parameters_size() >= 2); if (should_wrap_parameter) { - computation = GetValueOrThrow(XlaHelpers::WrapXlaComputation( - computation, program_shape.parameters(), param_shardings, - /* buffer_donor_indices */ {})); + XLA_ASSIGN_OR_THROW( + computation, + XlaHelpers::WrapXlaComputation( + computation, program_shape.parameters(), param_shardings, + /* buffer_donor_indices */ {})); } } } @@ -1266,7 +1288,8 @@ class PyLoweringContext { lowering_ctx.GetParametersData(); // Fetch this parameter data - std::vector literals = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::vector literals, runtime::GetComputationClientOrDie()->TransferFromDevice( UnwrapXlaData(device_data))); @@ -1319,7 +1342,7 @@ class PyLoweringContext { // remain parameters. int64_t GetTensorParameterId(at::Tensor tensor) { // Convert tensor into the backing lazy node - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); torch::lazy::Value value = xtensor->GetIrValue(); const torch::lazy::Node* node = value.node.get(); if (node->op() != xla_device_data) { @@ -1346,8 +1369,10 @@ class PyLoweringContext { // Create a serialized HloModule protobuf from a lowered graph py::bytes GetHlo() { const xla::HloModuleProto& proto = computation.proto(); - return GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::string serialized_module, runtime::util::GetDeterministicSerializedModuleProto(proto)); + return serialized_module; } // Create human-readable HloModule protobuf text from a lowered graph @@ -1625,7 +1650,7 @@ void InitXlaModuleBindings(py::module m) { }) .def("_get_xla_tensor_shape_type", [](const at::Tensor& tensor) -> std::string { - auto xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor)); xla::Shape shape = xla_tensor->shape().get(); return xla::primitive_util::LowercasePrimitiveTypeName( shape.element_type()); @@ -1719,12 +1744,12 @@ void InitXlaModuleBindings(py::module m) { }) .def("_init_computation_client", []() { - GetValueOrThrow(runtime::GetComputationClient()); + XLA_THROW_IF_ERROR(runtime::GetComputationClient()); }) .def("_xla_get_device_hw_type", [](const at::Tensor& tensor) { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); XlaDeviceType xla_device_type = static_cast(xtensor->GetDevice().type()); return DeviceType(xla_device_type).toString(); @@ -1862,8 +1887,9 @@ void InitXlaModuleBindings(py::module m) { double scale, const py::list& groups) { std::vector> replica_groups = CreateReduceGroups(groups); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); auto result = tensor_methods::all_reduce( - GetValueOrThrow(bridge::GetXlaTensor(input)), + xla_input, GetReduceType(reduce_type), scale, std::move(replica_groups)); return bridge::AtenFromXlaTensor(std::move(result)); }) @@ -2076,8 +2102,9 @@ void InitXlaModuleBindings(py::module m) { const py::list& groups) { std::vector> replica_groups = CreateReduceGroups(groups); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); auto result = tensor_methods::reduce_scatter( - GetValueOrThrow(bridge::GetXlaTensor(input)), + xla_input, GetReduceType(reduce_type), scale, scatter_dim, shard_count, replica_groups); return bridge::AtenFromXlaTensor(std::move(result)); @@ -2443,8 +2470,8 @@ void InitXlaModuleBindings(py::module m) { [](const std::string& device) { return GetMemoryInfo(device); }) .def("_xla_set_mat_mul_precision", [](const std::string& mat_mul_precision) { - xla::PrecisionConfig::Precision precision = - GetValueOrThrow(xla::StringToPrecision(mat_mul_precision)); + XLA_ASSIGN_OR_THROW(xla::PrecisionConfig::Precision precision, + xla::StringToPrecision(mat_mul_precision)); XlaHelpers::set_mat_mul_precision(precision); }) .def("_xla_get_mat_mul_precision", []() { @@ -2493,7 +2520,7 @@ void InitXlaModuleBindings(py::module m) { std::string hlo_text; { NoGilSection nogil; - hlo_text = GetValueOrThrow(runtime::util::GetComputationHloText( + XLA_ASSIGN_OR_THROW(hlo_text, runtime::util::GetComputationHloText( computation->computation())); } return hlo_text; @@ -2517,16 +2544,16 @@ void InitXlaModuleBindings(py::module m) { bool maximize) { { NoGilSection nogil; - XLATensorPtr found_inf_xla = - GetValueOrThrow(bridge::GetXlaTensor(found_inf)); - XLATensorPtr step_xla = - GetValueOrThrow(bridge::GetXlaTensor(step)); - XLATensorPtr param_xla = - GetValueOrThrow(bridge::GetXlaTensor(param)); - XLATensorPtr d_p_xla = GetValueOrThrow(bridge::GetXlaTensor(d_p)); - XLATensorPtr buf_xla = GetValueOrThrow(bridge::GetXlaTensor(buf)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_found_inf, + bridge::GetXlaTensor(found_inf)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_step, + bridge::GetXlaTensor(step)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_param, + bridge::GetXlaTensor(param)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_d_p, bridge::GetXlaTensor(d_p)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_buf, bridge::GetXlaTensor(buf)); tensor_methods::sgd_optimizer_step_( - found_inf_xla, step_xla, param_xla, buf_xla, d_p_xla, + xla_found_inf, xla_step, xla_param, xla_buf, xla_d_p, weight_decay, momentum, lr, dampening, nesterov, maximize); } }) @@ -2538,23 +2565,23 @@ void InitXlaModuleBindings(py::module m) { bool use_adamw) { { NoGilSection nogil; - XLATensorPtr found_inf_xla = - GetValueOrThrow(bridge::GetXlaTensor(found_inf)); - XLATensorPtr step_xla = - GetValueOrThrow(bridge::GetXlaTensor(step)); - XLATensorPtr param_xla = - GetValueOrThrow(bridge::GetXlaTensor(param)); - XLATensorPtr grad_xla = - GetValueOrThrow(bridge::GetXlaTensor(grad)); - XLATensorPtr exp_avg_xla = - GetValueOrThrow(bridge::GetXlaTensor(exp_avg)); - XLATensorPtr exp_avg_sq_xla = - GetValueOrThrow(bridge::GetXlaTensor(exp_avg_sq)); - XLATensorPtr max_exp_avg_sq_xla = - GetValueOrThrow(bridge::GetXlaTensor(max_exp_avg_sq)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_found_inf, + bridge::GetXlaTensor(found_inf)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_step, + bridge::GetXlaTensor(step)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_param, + bridge::GetXlaTensor(param)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad, + bridge::GetXlaTensor(grad)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_exp_avg, + bridge::GetXlaTensor(exp_avg)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_exp_avg_sq, + bridge::GetXlaTensor(exp_avg_sq)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_max_exp_avg_sq, + bridge::GetXlaTensor(max_exp_avg_sq)); tensor_methods::adam_optimizer_step_( - found_inf_xla, step_xla, param_xla, grad_xla, exp_avg_xla, - exp_avg_sq_xla, max_exp_avg_sq_xla, beta1, beta2, lr, + xla_found_inf, xla_step, xla_param, xla_grad, xla_exp_avg, + xla_exp_avg_sq, xla_max_exp_avg_sq, beta1, beta2, lr, weight_decay, eps, amsgrad, maximize, use_adamw); } }) @@ -2564,7 +2591,7 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_annotate_custom_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding); }) .def("_mark_manual_sharding", @@ -2576,7 +2603,7 @@ void InitXlaModuleBindings(py::module m) { .def( "_spmd_full_to_shard_shape", [](const at::Tensor& input) -> at::Tensor { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); auto sharding_spec = xtensor->sharding_spec(); XLA_CHECK(sharding_spec != nullptr) << "Input tensor is not sharded"; @@ -2597,7 +2624,7 @@ void InitXlaModuleBindings(py::module m) { [](const at::Tensor& input, const xla::OpSharding& sharding, const std::vector& output_shape, const py::object& output_dtype) -> at::Tensor { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); auto sharding_spec = xtensor->sharding_spec(); XLA_CHECK(sharding_spec != nullptr && sharding_spec->sharding.type() == xla::OpSharding::MANUAL) @@ -2619,17 +2646,17 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_clear_sharding", [](const at::Tensor& input) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); xtensor->ClearShardingSpec(); }) .def("_get_xla_sharding_spec", [](const at::Tensor& input) -> std::string { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); return GetXLAShardingSpec(xtensor); }) .def("_get_xla_op_sharding", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); XLATensor::ShardingSpecPtr sharding_spec = xtensor ? xtensor->sharding_spec() : nullptr; if (sharding_spec != nullptr) { @@ -2646,14 +2673,14 @@ void InitXlaModuleBindings(py::module m) { std::vector sharding_specs; sharding_specs.reserve(tensors.size()); for (const at::Tensor& tensor : tensors) { - sharding_specs.push_back(GetXLAShardingSpec( - GetValueOrThrow(bridge::GetXlaTensor(tensor)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor)); + sharding_specs.push_back(GetXLAShardingSpec(xla_tensor)); } return sharding_specs; }) .def("_get_xla_sharding_type", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); auto sharding_spec = xtensor->sharding_spec(); if (sharding_spec != nullptr) { return ShardingUtil::GetShardingType(sharding_spec->sharding); @@ -2748,8 +2775,8 @@ void InitXlaModuleBindings(py::module m) { std::vector element_types; // Find all shard handles for transfer for (auto& tensor : input) { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->GetXlaData() != nullptr) << "Shard data is not available"; XLA_CHECK(xtensor->sharding_spec() != nullptr) @@ -2767,8 +2794,8 @@ void InitXlaModuleBindings(py::module m) { shard_handles[0]->shape().element_type())); } - std::vector cpu_shards = - GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types)); + XLA_ASSIGN_OR_THROW(std::vector cpu_shards, + XlaDataToTensors(WrapXlaData(handles), element_types)); // Populate the resulting vector of shards and device strings std::vector>> result; int shards_per_tensor = @@ -2802,8 +2829,8 @@ void InitXlaModuleBindings(py::module m) { -> std::vector>> { std::vector>> result; for (auto& tensor : input_tensors) { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Tensor is not sharded"; auto handle = @@ -2860,8 +2887,8 @@ void InitXlaModuleBindings(py::module m) { "_load_local_shards", [](const at::Tensor& tensor, std::vector& shards, std::vector& devices) { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Cannot load local shards into a non sharded tensor"; XLA_CHECK(devices.size() == @@ -2938,7 +2965,7 @@ void InitXlaModuleBindings(py::module m) { }) .def("_is_placecholder", [](at::Tensor& input) { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); return xtensor->CurrentDataHandle() && !xtensor->CurrentDataHandle()->HasValue(); }) @@ -3024,8 +3051,9 @@ void InitXlaModuleBindings(py::module m) { reinterpret_cast(dtype.ptr())->scalar_type); } + XLA_ASSIGN_OR_THROW(std::vector xla_inputs, bridge::GetXlaTensors(inputs)); auto xtensors = tensor_methods::custom_call( - GetValueOrThrow(bridge::GetXlaTensors(inputs)), target, + xla_inputs, target, output_shapes, dtypes, has_side_effect, backend_config, api_version, frontend_attributes); return bridge::AtenFromXlaTensors(std::move(xtensors)); @@ -3063,7 +3091,7 @@ void InitXlaModuleBindings(py::module m) { .def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, size_t max_call_stack_depth) -> bool { - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); std::shared_ptr user_meta = std::make_shared(op_name_prefix, max_call_stack_depth); @@ -3090,8 +3118,9 @@ void InitXlaModuleBindings(py::module m) { std::vector handles; handles.reserve(tensors.size()); for (auto& tensor : tensors) { - handles.push_back( - GetValueOrThrow(bridge::GetXlaTensor(tensor))->GetHandle()); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, + bridge::GetXlaTensor(tensor)); + handles.push_back(xla_tensor->GetHandle()); } return handles; }) @@ -3108,8 +3137,8 @@ void InitXlaModuleBindings(py::module m) { .def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) { TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(input)); xtensor->MarkDynamicDimension(dim); }) .def("_xla_dynamic_expand", @@ -3150,8 +3179,8 @@ void InitXlaModuleBindings(py::module m) { // Note that donated buffers can not be used after being donated. "_set_buffer_donation", [](at::Tensor& tensor, bool should_donate) -> bool { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); bool buffer_donation_updated = false; if (xtensor->CurrentDataHandle() != nullptr) { auto data = @@ -3174,8 +3203,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_get_buffer_donation", [](const at::Tensor& input) -> bool { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(input)); if (!xtensor) { return false; } else if (xtensor->CurrentDataHandle() != nullptr) { @@ -3196,8 +3225,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_on_ready_callback", [](const at::Tensor& tensor, const std::function& callback) { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor) << "The input is not an XLA tensor."; // Wait for placeholder `Data`s to be assigned XLAGraphExecutor::Get()->WaitDeviceOps({}); @@ -3221,7 +3250,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_unsafe_buffer_pointer", [](const at::Tensor& input) -> std::uintptr_t { - auto xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(input)); if (xtensor->CurrentDataHandle() != nullptr) { std::shared_ptr data = std::dynamic_pointer_cast( @@ -3398,8 +3428,8 @@ void InitXlaModuleBindings(py::module m) { }) .def("_unique_id_for_ir_and_data", [](const at::Tensor& tensor) -> std::string { - XLATensorPtr xtensor = - GetValueOrThrow(bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, + bridge::GetXlaTensor(tensor)); if (xtensor->CurrentIrValue()) { torch::lazy::Value value = xtensor->CurrentIrValue(); return std::to_string((uintptr_t)value.node.get()) + ", " + From 90be04af3987556fed90ac9fba9d4e5823e01b1e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 12:38:33 -0300 Subject: [PATCH 074/133] `aten_xla_type.cpp`: Use new macros for throwing exceptions. (#9596) Follow-up: #9588 and #9580 Target: `aten_xla_type.cpp` file In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `aten_xla_type.cpp` file, replacing every use of those, now deprecated, functions by the newly introduced macros. --- torch_xla/csrc/aten_xla_type.cpp | 1583 ++++++++++++++++-------------- 1 file changed, 834 insertions(+), 749 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 9606a989f831..86b4cb84707e 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -167,8 +167,8 @@ class OpConfig { // Transform the inputs into a list of XLATensorPtr. // For that, either get their corresponding XLATensorPtr, or use the found // XLA tensor's BackendDevice for creating a new one. - torch::lazy::BackendDevice device = - GetValueOrThrow(bridge::GetXlaTensor(*it))->GetDevice(); + XLA_ASSIGN_OR_THROW(XLATensorPtr tensor, bridge::GetXlaTensor(*it)); + torch::lazy::BackendDevice device = tensor->GetDevice(); XLAInputVector xla_inputs(inputs.size()); std::transform(inputs.begin(), inputs.end(), xla_inputs.begin(), [&](const at::Tensor& tensor) { @@ -347,7 +347,7 @@ std::pair GetBinaryOperands( XLATensorPtr other_tensor; auto self_xtensor_status = bridge::GetXlaTensor(self); if (!self_xtensor_status.ok()) { - other_tensor = GetValueOrThrow(bridge::GetXlaTensor(other)); + XLA_ASSIGN_OR_THROW(other_tensor, bridge::GetXlaTensor(other)); self_tensor = bridge::GetOrCreateXlaTensor(self, other_tensor->GetDevice()); } else { self_tensor = std::move(self_xtensor_status).value(); @@ -397,8 +397,8 @@ template at::Tensor DoBinaryOp(const at::Tensor& self, const at::Scalar& other, const B& bin_op) { at::ScalarType dtype = at::result_type(self, other); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr result = bin_op(self_tensor, other, dtype); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLATensorPtr result = bin_op(xla_self, other, dtype); return bridge::AtenFromXlaTensor(result); } @@ -406,8 +406,8 @@ template at::Tensor DoBinaryOp(const at::Scalar& self, const at::Tensor& other, const B& bin_op) { at::ScalarType dtype = at::result_type(self, other); - XLATensorPtr other_tensor = GetValueOrThrow(bridge::GetXlaTensor(other)); - XLATensorPtr result = bin_op(self, other_tensor, dtype); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); + XLATensorPtr result = bin_op(self, xla_other, dtype); return bridge::AtenFromXlaTensor(result); } @@ -424,8 +424,8 @@ at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self, template at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self, const at::Scalar& other, const B& bin_op) { - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr result = bin_op(self_tensor, other); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLATensorPtr result = bin_op(xla_self, other); return bridge::AtenFromXlaTensor(result); } @@ -436,8 +436,8 @@ void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other, XLA_CHECK(at::canCast(/*from=*/dtype, /*to=*/out.scalar_type())); std::pair operands = GetBinaryOperands(self, UnwrapNumber(other, dtype)); - XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); - bin_op_out(operands.first, operands.second, out_tensor); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_out, bridge::GetXlaTensor(out)); + bin_op_out(operands.first, operands.second, xla_out); } } // namespace @@ -445,8 +445,8 @@ void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other, at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::__ilshift__(self_tensor, other); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::__ilshift__(xla_self, other); return self; } @@ -454,9 +454,9 @@ at::Tensor& XLANativeFunctions::__ilshift__(at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::__ilshift__(self_tensor, - GetValueOrThrow(bridge::GetXlaTensor(other))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); + tensor_methods::__ilshift__(xla_self, xla_other); return self; } @@ -464,8 +464,8 @@ at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::__irshift__(self_tensor, other); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::__irshift__(xla_self, other); return self; } @@ -473,9 +473,9 @@ at::Tensor& XLANativeFunctions::__irshift__(at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckBinaryOpTypePromotion(self, self, other); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::__irshift__(self_tensor, - GetValueOrThrow(bridge::GetXlaTensor(other))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); + tensor_methods::__irshift__(xla_self, xla_other); return self; } @@ -530,8 +530,9 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d( } auto common_device = torch_xla::bridge::GetXlaDevice(self); XLA_CHECK(common_device); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); torch::lazy::NodePtr node = torch_xla::MakeNode( - GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue(), + xla_self->GetIrValue(), std::vector(output_size.begin(), output_size.end())); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -552,9 +553,11 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool3d_backward( } auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self); XLA_CHECK(common_device); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); torch::lazy::NodePtr node = torch_xla::MakeNode( - GetValueOrThrow(bridge::GetXlaTensor(grad_output))->GetIrValue(), - GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue()); + xla_grad_output->GetIrValue(), xla_self->GetIrValue()); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -569,8 +572,9 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool2d( return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(_adaptive_avg_pool2d)>::call(self, output_size); } - return bridge::AtenFromXlaTensor(tensor_methods::_adaptive_avg_pool2d( - GetValueOrThrow(bridge::GetXlaTensor(self)), output_size_list)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::_adaptive_avg_pool2d(xla_self, output_size_list)); } at::Tensor XLANativeFunctions::_adaptive_avg_pool2d_backward( @@ -585,10 +589,11 @@ at::Tensor XLANativeFunctions::_adaptive_avg_pool2d_backward( &xla_fallback, ATEN_OP(_adaptive_avg_pool2d_backward)>::call(grad_output, self); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::_adaptive_avg_pool2d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)))); + tensor_methods::_adaptive_avg_pool2d_backward(xla_grad_output, xla_self)); } std::tuple XLANativeFunctions::adaptive_max_pool2d( @@ -600,9 +605,9 @@ std::tuple XLANativeFunctions::adaptive_max_pool2d( return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(adaptive_max_pool2d)>::call(self, output_size); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); std::tuple res = - tensor_methods::adaptive_max_pool2d( - GetValueOrThrow(bridge::GetXlaTensor(self)), output_size_list); + tensor_methods::adaptive_max_pool2d(xla_self, output_size_list); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)), bridge::AtenFromXlaTensor(std::get<1>(res))); } @@ -621,19 +626,24 @@ at::Tensor XLANativeFunctions::adaptive_max_pool2d_backward( self, indices); } - return bridge::AtenFromXlaTensor(tensor_methods::adaptive_max_pool2d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::adaptive_max_pool2d_backward(xla_grad_output, xla_self)); } void XLANativeFunctions::_amp_foreach_non_finite_check_and_unscale_( at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr found_inf_tensor = - GetValueOrThrow(bridge::GetXlaTensor(found_inf)); + XLA_ASSIGN_OR_THROW(std::vector xla_self, + bridge::GetXlaTensors(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_found_inf, + bridge::GetXlaTensor(found_inf)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_inv_scale, + bridge::GetXlaTensor(inv_scale)); tensor_methods::_amp_foreach_non_finite_check_and_unscale_( - GetValueOrThrow(bridge::GetXlaTensors(self)), found_inf_tensor, - GetValueOrThrow(bridge::GetXlaTensor(inv_scale))); + xla_self, xla_found_inf, xla_inv_scale); } at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, @@ -643,14 +653,15 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale, double scale_backoff_factor, int64_t growth_interval) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr growth_tracker_tensor = - GetValueOrThrow(bridge::GetXlaTensor(growth_tracker)); - XLATensorPtr current_scale_tensor = - GetValueOrThrow(bridge::GetXlaTensor(current_scale)); - tensor_methods::_amp_update_scale_( - growth_tracker_tensor, current_scale_tensor, - GetValueOrThrow(bridge::GetXlaTensor(found_inf)), scale_growth_factor, - scale_backoff_factor, growth_interval); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_current_scale, + bridge::GetXlaTensor(current_scale)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_growth_tracker, + bridge::GetXlaTensor(growth_tracker)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_found_inf, + bridge::GetXlaTensor(found_inf)); + tensor_methods::_amp_update_scale_(xla_growth_tracker, xla_current_scale, + xla_found_inf, scale_growth_factor, + scale_backoff_factor, growth_interval); return current_scale; } @@ -674,7 +685,7 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self, } else { auto dst_tensor = std::move(dst_tensor_status).value(); tensor_methods::copy_(dst_tensor, self_tensor_status.value()); - OkOrThrow(bridge::ReplaceXlaTensor(dst, dst_tensor)); + XLA_THROW_IF_ERROR(bridge::ReplaceXlaTensor(dst, dst_tensor)); } return dst; } @@ -741,8 +752,8 @@ at::Tensor XLANativeFunctions::_to_copy( if (device && device->type() != c10::kXLA) { XLA_CHECK(device->type() == c10::kCPU) << "only cpu device is supported in _to_copy."; - auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - auto eager_tensor = self_tensor->ToTensor(/*detached=*/true); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto eager_tensor = xla_self->ToTensor(/*detached=*/true); // Use the eager .to on the eager tensor. return eager_tensor.to(options, non_blocking, /*copy=*/true); @@ -773,8 +784,8 @@ std::tuple XLANativeFunctions::_linalg_eigh( ATEN_OP(_linalg_eigh)>::call(self, uplo, compute_v); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - auto outputs = tensor_methods::eigh(self_tensor, uplo); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto outputs = tensor_methods::eigh(xla_self, uplo); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -782,8 +793,8 @@ std::tuple XLANativeFunctions::_linalg_eigh( std::tuple XLANativeFunctions::_linalg_slogdet(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - auto outputs = tensor_methods::slogdet(self_tensor); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto outputs = tensor_methods::slogdet(xla_self); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs)), bridge::AtenFromXlaTensor(XLATensorPtr()), @@ -798,18 +809,20 @@ at::Tensor XLANativeFunctions::_log_softmax(const at::Tensor& self, int64_t dim, std::vector shapes{ torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; - return bridge::AtenFromXlaTensor( - tensor_methods::log_softmax(GetValueOrThrow(bridge::GetXlaTensor(self)), - dim, std::nullopt, std::move(shapes))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::log_softmax( + xla_self, dim, std::nullopt, std::move(shapes))); } at::Tensor XLANativeFunctions::_log_softmax_backward_data( const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, at::ScalarType /* input_dtype */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::log_softmax_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(output)), dim)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_output, bridge::GetXlaTensor(output)); + return bridge::AtenFromXlaTensor( + tensor_methods::log_softmax_backward(xla_grad_output, xla_output, dim)); } std::tuple XLANativeFunctions::_pack_padded_sequence( @@ -823,17 +836,20 @@ std::tuple XLANativeFunctions::_pack_padded_sequence( at::Tensor XLANativeFunctions::_softmax(const at::Tensor& self, int64_t dim, bool /* half_to_float */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::softmax( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, std::nullopt)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::softmax(xla_self, dim, std::nullopt)); } at::Tensor XLANativeFunctions::_softmax_backward_data( const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, at::ScalarType input_dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::softmax_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(output)), dim)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_output, bridge::GetXlaTensor(output)); + return bridge::AtenFromXlaTensor( + tensor_methods::softmax_backward(xla_grad_output, xla_output, dim)); } at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self, @@ -882,16 +898,17 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self, return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(addmm)>::call( self, mat1, mat2, beta, alpha); } - return bridge::AtenFromXlaTensor(tensor_methods::addmm( - GetValueOrThrow(bridge::GetXlaTensor(mat1)), - /*weight=*/GetValueOrThrow(bridge::GetXlaTensor(mat2)), - /*bias=*/GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat1, bridge::GetXlaTensor(mat1)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::addmm(xla_mat1, /*weight=*/xla_mat2, /*bias=*/xla_self)); } at::Tensor XLANativeFunctions::alias(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::alias(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::alias(xla_self)); } at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) { @@ -904,8 +921,8 @@ at::Tensor& XLANativeFunctions::arange_out(const at::Scalar& start, const at::Scalar& step, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); - tensor_methods::arange_out(out_tensor, start, end, step, out.scalar_type()); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_out, bridge::GetXlaTensor(out)); + tensor_methods::arange_out(xla_out, start, end, step, out.scalar_type()); return out; } @@ -964,10 +981,9 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path( } } } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor)); return bridge::AtenFromXlaTensor(tensor_methods::squeeze( - tensor_methods::slice(GetValueOrThrow(bridge::GetXlaTensor(tensor)), - skip_dim, 0, 1, 1), - skip_dim)); + tensor_methods::slice(xla_tensor, skip_dim, 0, 1, 1), skip_dim)); } // now tensor_dim.size() == stride.size() long reduce_size_location = -1; @@ -998,9 +1014,9 @@ static at::Tensor as_strided_eliminate_one_dim_fast_path( // stride. K = 1; } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor)); return bridge::AtenFromXlaTensor(tensor_methods::slice( - GetValueOrThrow(bridge::GetXlaTensor(tensor)), reduce_size_location, 0, - size[reduce_size_location] * K, K)); + xla_tensor, reduce_size_location, 0, size[reduce_size_location] * K, K)); } at::Tensor XLANativeFunctions::as_strided_copy( @@ -1013,7 +1029,8 @@ at::Tensor XLANativeFunctions::as_strided_copy( // Retrieve the base tensor, if there's one. // This function actually operates on the tensor's storage. Since XLA does not // expose the actual storage, we use the originally allocated tensor. - const at::Tensor& base = GetValueOrThrow(bridge::GetXlaTensor(self))->Base(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + const at::Tensor& base = xla_self->Base(); at::Tensor tensor = base.defined() ? base : self; // Fast path: using slice to replace as_strided to avoid the index copy. @@ -1029,11 +1046,11 @@ at::Tensor XLANativeFunctions::as_strided_copy( // Sets the base tensor as tensor. // Even though this function copies (without aliasing) tensor, it's still // treated as a view function in the functionalization layer. + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor)); return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( - tensor_methods::as_strided( - GetValueOrThrow(bridge::GetXlaTensor(tensor)), - XlaHelpers::I64List(size), XlaHelpers::I64List(stride), - XlaHelpers::I64Optional(storage_offset)), + tensor_methods::as_strided(xla_tensor, XlaHelpers::I64List(size), + XlaHelpers::I64List(stride), + XlaHelpers::I64Optional(storage_offset)), tensor)); } @@ -1127,21 +1144,22 @@ at::Tensor XLANativeFunctions::as_strided_scatter( at::IntArrayRef size, at::IntArrayRef stride, std::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto base_ = GetValueOrThrow(bridge::GetXlaTensor(base)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_base, bridge::GetXlaTensor(base)); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); - if (!AsStrided::StrideIsSupported(base_->shape(), xsize, xstride, + if (!AsStrided::StrideIsSupported(xla_base->shape(), xsize, xstride, storage_offset.value_or(0))) { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(as_strided_scatter)>::call(base, mutated_view, size, stride, storage_offset); } - auto mutated_view_ = GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mutated_view, + bridge::GetXlaTensor(mutated_view)); return bridge::AtenFromXlaTensor( - base_->CreateFrom(torch_xla::MakeNode( - base_->GetIrValue(), mutated_view_->GetIrValue(), - torch::lazy::ToVector(base_->shape().get().dimensions()), + xla_base->CreateFrom(torch_xla::MakeNode( + xla_base->GetIrValue(), xla_mutated_view->GetIrValue(), + torch::lazy::ToVector(xla_base->shape().get().dimensions()), xstride, storage_offset.value_or(0)))); } @@ -1150,9 +1168,10 @@ at::Tensor XLANativeFunctions::atan2(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); auto common_device = torch_xla::bridge::GetXlaDevice(self, other); XLA_CHECK(common_device); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); torch::lazy::NodePtr node = torch_xla::MakeNode( - GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue(), - GetValueOrThrow(bridge::GetXlaTensor(other))->GetIrValue()); + xla_self->GetIrValue(), xla_other->GetIrValue()); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -1163,11 +1182,11 @@ at::Tensor XLANativeFunctions::avg_pool2d( at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, std::optional divisor_override) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, - XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), ceil_mode, count_include_pad, - divisor_override)); + xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, + count_include_pad, divisor_override)); } at::Tensor XLANativeFunctions::avg_pool2d_backward( @@ -1184,12 +1203,13 @@ at::Tensor XLANativeFunctions::avg_pool2d_backward( count_include_pad, divisor_override); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, - count_include_pad)); + xla_grad_output, xla_self, /*spatial_dim_count=*/2, + XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), ceil_mode, count_include_pad)); } at::Tensor XLANativeFunctions::avg_pool3d( @@ -1197,11 +1217,11 @@ at::Tensor XLANativeFunctions::avg_pool3d( at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, std::optional divisor_override) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, - XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), ceil_mode, count_include_pad, - divisor_override)); + xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, + count_include_pad, divisor_override)); } at::Tensor XLANativeFunctions::avg_pool3d_backward( @@ -1218,12 +1238,13 @@ at::Tensor XLANativeFunctions::avg_pool3d_backward( count_include_pad, divisor_override); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode, - count_include_pad)); + xla_grad_output, xla_self, /*spatial_dim_count=*/3, + XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), ceil_mode, count_include_pad)); } at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, @@ -1232,11 +1253,11 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, const at::Scalar& beta, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - - return bridge::AtenFromXlaTensor(tensor_methods::baddbmm( - GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(batch1)), - GetValueOrThrow(bridge::GetXlaTensor(batch2)), beta, alpha)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch1, bridge::GetXlaTensor(batch1)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch2, bridge::GetXlaTensor(batch2)); + return bridge::AtenFromXlaTensor( + tensor_methods::baddbmm(xla_self, xla_batch1, xla_batch2, beta, alpha)); } at::Tensor XLANativeFunctions::bernoulli( @@ -1247,8 +1268,8 @@ at::Tensor XLANativeFunctions::bernoulli( ATEN_OP(bernoulli)>::call(self, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::bernoulli(self_tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::bernoulli(xla_self)); } at::Tensor XLANativeFunctions::bernoulli( @@ -1258,8 +1279,8 @@ at::Tensor XLANativeFunctions::bernoulli( return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(bernoulli, p)>::call(self, p, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::bernoulli(self_tensor, p)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::bernoulli(xla_self, p)); } at::Tensor& XLANativeFunctions::bernoulli_( @@ -1270,9 +1291,9 @@ at::Tensor& XLANativeFunctions::bernoulli_( return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(bernoulli_, Tensor)>::call(self, p, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::bernoulli_(self_tensor, - GetValueOrThrow(bridge::GetXlaTensor(p))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_p, bridge::GetXlaTensor(p)); + tensor_methods::bernoulli_(xla_self, xla_p); return self; } @@ -1316,16 +1337,18 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self, at::Tensor XLANativeFunctions::bmm(const at::Tensor& self, const at::Tensor& mat2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::bmm(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(mat2)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2)); + return bridge::AtenFromXlaTensor(tensor_methods::bmm(xla_self, xla_mat2)); } at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors)); - auto output = GetValueOrThrow( + XLA_ASSIGN_OR_THROW(std::vector xtensors, + bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW( + XLATensorPtr output, tensor_methods::cat(xtensors, dim, at::native::result_type(tensors))); return bridge::AtenFromXlaTensor(std::move(output)); } @@ -1333,15 +1356,15 @@ at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, at::Tensor XLANativeFunctions::celu(const at::Tensor& self, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::celu(GetValueOrThrow(bridge::GetXlaTensor(self)), alpha)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::celu(xla_self, alpha)); } at::Tensor& XLANativeFunctions::celu_(at::Tensor& self, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::celu_(self_tensor, alpha); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::celu_(xla_self, alpha); return self; } @@ -1349,43 +1372,45 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self, const std::optional& min, const std::optional& max) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::clamp( - GetValueOrThrow(bridge::GetXlaTensor(self)), min, max)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::clamp(xla_self, min, max)); } at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self, const at::Scalar& max) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::clamp( - GetValueOrThrow(bridge::GetXlaTensor(self)), std::nullopt, max)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::clamp(xla_self, std::nullopt, max)); } at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, const at::Scalar& min) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::clamp( - GetValueOrThrow(bridge::GetXlaTensor(self)), min, std::nullopt)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::clamp(xla_self, min, std::nullopt)); } at::Tensor XLANativeFunctions::clone( const at::Tensor& self, std::optional /* memory_format */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); if (self.is_conj()) { // Materialize the conjugate if necessary. - tensor = tensor_methods::conj(tensor); + xla_self = tensor_methods::conj(xla_self); } - return bridge::AtenFromXlaTensor(tensor_methods::clone(tensor)); + return bridge::AtenFromXlaTensor(tensor_methods::clone(xla_self)); } at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self, at::IntArrayRef pad, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::constant_pad_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(pad), - value)); + xla_self, XlaHelpers::I64List(pad), value)); } // This functions covers the whole convolution lowering. @@ -1395,20 +1420,18 @@ at::Tensor XLANativeFunctions::convolution_overrideable( at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); if (IsDefined(bias)) { + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_bias, bridge::GetXlaTensor(*bias)); return bridge::AtenFromXlaTensor(tensor_methods::convolution_overrideable( - GetValueOrThrow(bridge::GetXlaTensor(input)), - GetValueOrThrow(bridge::GetXlaTensor(weight)), - GetValueOrThrow(bridge::GetXlaTensor(*bias)), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), - XlaHelpers::I64List(dilation), transposed, + xla_input, xla_weight, xla_bias, XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, XlaHelpers::I64List(output_padding), groups)); } else { return bridge::AtenFromXlaTensor(tensor_methods::convolution_overrideable( - GetValueOrThrow(bridge::GetXlaTensor(input)), - GetValueOrThrow(bridge::GetXlaTensor(weight)), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), - XlaHelpers::I64List(dilation), transposed, + xla_input, xla_weight, XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, XlaHelpers::I64List(output_padding), groups)); } } @@ -1421,12 +1444,13 @@ XLANativeFunctions::convolution_backward_overrideable( at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, std::array output_mask) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); auto gradients = tensor_methods::convolution_backward_overrideable( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(input)), - GetValueOrThrow(bridge::GetXlaTensor(weight)), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), - XlaHelpers::I64List(dilation), transposed, + xla_grad_output, xla_input, xla_weight, XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, XlaHelpers::I64List(output_padding), groups); return std::make_tuple( output_mask[0] ? bridge::AtenFromXlaTensor(std::get<0>(gradients)) @@ -1454,18 +1478,18 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self, const at::Tensor& other, std::optional dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); return bridge::AtenFromXlaTensor( - tensor_methods::cross(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(other)), - XlaHelpers::I64Optional(dim))); + tensor_methods::cross(xla_self, xla_other, XlaHelpers::I64Optional(dim))); } std::tuple XLANativeFunctions::cummax( const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); std::tuple res = - tensor_methods::cummax(self_tensor, dim); + tensor_methods::cummax(xla_self, dim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)), bridge::AtenFromXlaTensor(std::get<1>(res))); } @@ -1473,64 +1497,64 @@ std::tuple XLANativeFunctions::cummax( at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); std::optional promoted_dtype = - PromoteIntegralType(self_tensor->dtype(), dtype); - if (IsOperationOnType(promoted_dtype, self_tensor->dtype(), + PromoteIntegralType(xla_self->dtype(), dtype); + if (IsOperationOnType(promoted_dtype, xla_self->dtype(), at::ScalarType::Long)) { // XLA reduce-window does not support S64 mode. return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(cumprod)>::call( self, dim, dtype); } return bridge::AtenFromXlaTensor( - tensor_methods::cumprod(self_tensor, dim, promoted_dtype)); + tensor_methods::cumprod(xla_self, dim, promoted_dtype)); } at::Tensor XLANativeFunctions::cumsum(const at::Tensor& self, int64_t dim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::cumsum(self_tensor, dim, dtype)); + tensor_methods::cumsum(xla_self, dim, dtype)); } // TODO(alanwaketan): Let's rewrite a without reusing other native functions. at::Tensor XLANativeFunctions::detach_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(GetValueOrThrow(bridge::GetXlaTensor(self))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(xla_self); } at::Tensor XLANativeFunctions::diag(const at::Tensor& self, int64_t diagonal) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::diag( - GetValueOrThrow(bridge::GetXlaTensor(self)), diagonal)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::diag(xla_self, diagonal)); } at::Tensor XLANativeFunctions::diagonal_copy(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::diagonal( - GetValueOrThrow(bridge::GetXlaTensor(self)), offset, dim1, dim2)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::diagonal(xla_self, offset, dim1, dim2)); } at::Tensor XLANativeFunctions::diagonal_scatter(const at::Tensor& base, const at::Tensor& mutated_view, int64_t offset, int64_t dim1, int64_t dim2) { - auto base_ = GetValueOrThrow(bridge::GetXlaTensor(base)); - auto mutated_view_ = GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); - int64_t base_rank = GetValueOrThrow(bridge::GetXlaTensor(base)) - ->shape() - .get() - .dimensions_size(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_base, bridge::GetXlaTensor(base)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mutated_view, + bridge::GetXlaTensor(mutated_view)); + int64_t base_rank = xla_base->shape().get().dimensions_size(); int64_t canonical_dim1 = torch::lazy::GetCanonicalDimensionIndex(dim1, base_rank); int64_t canonical_dim2 = torch::lazy::GetCanonicalDimensionIndex(dim2, base_rank); return bridge::AtenFromXlaTensor( - base_->CreateFrom(torch_xla::MakeNode( - base_->GetIrValue(), mutated_view_->GetIrValue(), offset, + xla_base->CreateFrom(torch_xla::MakeNode( + xla_base->GetIrValue(), xla_mutated_view->GetIrValue(), offset, canonical_dim1, canonical_dim2))); } @@ -1546,16 +1570,17 @@ at::Tensor XLANativeFunctions::div( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); at::ScalarType dtype = at::result_type(self, other); auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype)); - auto output = GetValueOrThrow(tensor_methods::div( - operands.first, operands.second, rounding_mode, dtype)); + XLA_ASSIGN_OR_THROW(XLATensorPtr output, + tensor_methods::div(operands.first, operands.second, + rounding_mode, dtype)); return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::div(const at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::div(GetValueOrThrow(bridge::GetXlaTensor(self)), other)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::div(xla_self, other)); } at::Tensor XLANativeFunctions::dot(const at::Tensor& self, @@ -1577,9 +1602,10 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self, return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(dot)>::call( self, tensor); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, bridge::GetXlaTensor(tensor)); return bridge::AtenFromXlaTensor( - tensor_methods::matmul(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(tensor)))); + tensor_methods::matmul(xla_self, xla_tensor)); } at::Tensor XLANativeFunctions::einsum(std::string_view equation, @@ -1624,18 +1650,23 @@ at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output, XLA_CHECK(!self || alpha.to() >= 0.0) << "In-place elu backward calculation is triggered with a negative slope " "which is not supported."; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self_or_result, + bridge::GetXlaTensor(self_or_result)); return bridge::AtenFromXlaTensor(tensor_methods::elu_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), alpha, scale, - input_scale, GetValueOrThrow(bridge::GetXlaTensor(self_or_result)))); + xla_grad_output, alpha, scale, input_scale, xla_self_or_result)); } at::Tensor XLANativeFunctions::embedding_dense_backward( const at::Tensor& grad_output, const at::Tensor& indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_indices, bridge::GetXlaTensor(indices)); return bridge::AtenFromXlaTensor(tensor_methods::embedding_dense_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(indices)), num_weights, padding_idx, + xla_grad_output, xla_indices, num_weights, padding_idx, scale_grad_by_freq)); } @@ -1655,17 +1686,23 @@ XLANativeFunctions::_embedding_bag_forward_only( include_last_offset, padding_idx); } - auto indices_tensor = GetValueOrThrow(bridge::GetXlaTensor(indices)); - auto sample_weights = - per_sample_weights.has_value() && per_sample_weights.value().defined() - ? GetValueOrThrow(bridge::GetXlaTensor(per_sample_weights.value())) - : tensor_methods::full_like(indices_tensor, 1.0, - *torch_xla::bridge::GetXlaDevice(weight), - at::ScalarType::Float); - auto result = tensor_methods::embedding_bag( - GetValueOrThrow(bridge::GetXlaTensor(weight)), indices_tensor, - GetValueOrThrow(bridge::GetXlaTensor(offsets)), mode, sample_weights, - include_last_offset); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_indices, bridge::GetXlaTensor(indices)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_offsets, bridge::GetXlaTensor(offsets)); + + XLATensorPtr sample_weights; + if (per_sample_weights.has_value() && per_sample_weights.value().defined()) { + XLA_ASSIGN_OR_THROW(sample_weights, + bridge::GetXlaTensor(per_sample_weights.value())); + } else { + sample_weights = tensor_methods::full_like( + xla_indices, 1.0, *torch_xla::bridge::GetXlaDevice(weight), + at::ScalarType::Float); + } + + auto result = + tensor_methods::embedding_bag(xla_weight, xla_indices, xla_offsets, mode, + sample_weights, include_last_offset); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(result)), bridge::AtenFromXlaTensor(std::get<1>(result)), bridge::AtenFromXlaTensor(std::get<2>(result)), @@ -1712,7 +1749,8 @@ at::Tensor XLANativeFunctions::empty_symint( // does not actually end up doing any memory initialization, we use that and // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus // s_copy_(). - XLATensorPtr xla_tensor = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + XLATensorPtr xla_tensor, all_dims_static ? tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0, GetXlaDeviceOrCurrent(device), @@ -1751,15 +1789,15 @@ at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, bool implicit) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::optional size = c10::asIntArrayRefSlowOpt(sym_size); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); if (size.has_value()) { - return bridge::AtenFromXlaTensor( - tensor_methods::expand(GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(*size))); + return bridge::AtenFromXlaTensor(tensor_methods::expand( + xla_self, torch::lazy::ToVector(*size))); } else { // at least one of the dimension is symbolic, use the sym_int version of the // node - return bridge::AtenFromXlaTensor(tensor_methods::expand_symint( - GetValueOrThrow(bridge::GetXlaTensor(self)), sym_size)); + return bridge::AtenFromXlaTensor( + tensor_methods::expand_symint(xla_self, sym_size)); } } @@ -1773,30 +1811,30 @@ at::Tensor& XLANativeFunctions::exponential_( generator); } XLA_CHECK_GE(lambd, 0.0); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::exponential_(self_tensor, lambd); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::exponential_(xla_self, lambd); return self; } at::Tensor& XLANativeFunctions::eye_out(int64_t n, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); - tensor_methods::eye_out(out_tensor, n, n); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_out, bridge::GetXlaTensor(out)); + tensor_methods::eye_out(xla_out, n, n); return out; } at::Tensor& XLANativeFunctions::eye_out(int64_t n, int64_t m, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); - tensor_methods::eye_out(out_tensor, n, m); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_out, bridge::GetXlaTensor(out)); + tensor_methods::eye_out(xla_out, n, m); return out; } at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::fill_(self_tensor, value); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::fill_(xla_self, value); return self; } @@ -1812,10 +1850,11 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self, at::Tensor XLANativeFunctions::flip(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto xself = GetValueOrThrow(bridge::GetXlaTensor(self)); - auto output = - GetValueOrThrow(tensor_methods::flip(xself, XlaHelpers::I64List(dims))); - return bridge::AtenFromXlaTensor(std::move(output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW( + XLATensorPtr xla_output, + tensor_methods::flip(xla_self, XlaHelpers::I64List(dims))); + return bridge::AtenFromXlaTensor(std::move(xla_output)); } at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self, @@ -1866,25 +1905,29 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size, } else { intend_dtype = fill_value.type(); } - return bridge::AtenFromXlaTensor(GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + XLATensorPtr output, tensor_methods::full(absl::Span(size), fill_value, - GetXlaDeviceOrCurrent(device), intend_dtype))); + GetXlaDeviceOrCurrent(device), intend_dtype)); + return bridge::AtenFromXlaTensor(output); } at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim, const at::Tensor& index, bool /* sparse_grad */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(GetValueOrThrow( - tensor_methods::gather(GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - GetValueOrThrow(bridge::GetXlaTensor(index))))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr output, + tensor_methods::gather(xla_self, dim, xla_index)); + return bridge::AtenFromXlaTensor(output); } at::Tensor XLANativeFunctions::gelu(const at::Tensor& self, std::string_view approximate) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::gelu( - GetValueOrThrow(bridge::GetXlaTensor(self)), approximate)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::gelu(xla_self, approximate)); } at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad, @@ -1892,18 +1935,21 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad, std::string_view approximate) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); at::ScalarType result_type = at::result_type(grad, self); - return bridge::AtenFromXlaTensor(tensor_methods::gelu_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad.to(result_type))), - GetValueOrThrow(bridge::GetXlaTensor(self.to(result_type))), - approximate)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad, + bridge::GetXlaTensor(grad.to(result_type))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, + bridge::GetXlaTensor(self.to(result_type))); + return bridge::AtenFromXlaTensor( + tensor_methods::gelu_backward(xla_grad, xla_self, approximate)); } at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::clamp( - GetValueOrThrow(bridge::GetXlaTensor(self)), min_val, max_val)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::clamp(xla_self, min_val, max_val)); } at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, @@ -1911,9 +1957,11 @@ at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, const at::Scalar& min_val, const at::Scalar& max_val) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::hardtanh_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), min_val, max_val)); + xla_grad_output, xla_self, min_val, max_val)); } at::Tensor XLANativeFunctions::index( @@ -1950,29 +1998,31 @@ at::Tensor XLANativeFunctions::index_add(const at::Tensor& self, int64_t dim, const at::Tensor& source, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::index_add( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(source)), alpha)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_source, bridge::GetXlaTensor(source)); + return bridge::AtenFromXlaTensor( + tensor_methods::index_add(xla_self, dim, xla_index, xla_source, alpha)); } at::Tensor XLANativeFunctions::index_copy(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::index_copy( - self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(source)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_source, bridge::GetXlaTensor(source)); + return bridge::AtenFromXlaTensor( + tensor_methods::index_copy(xla_self, dim, xla_index, xla_source)); } at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::index_fill_( - self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), value); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + tensor_methods::index_fill_(xla_self, dim, xla_index, value); return self; } @@ -1980,10 +2030,10 @@ at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::index_fill_(self_tensor, dim, - GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(value))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_value, bridge::GetXlaTensor(value)); + tensor_methods::index_fill_(xla_self, dim, xla_index, xla_value); return self; } @@ -2011,9 +2061,9 @@ at::Tensor& XLANativeFunctions::index_put_( device = bridge::GetXlaDevice(canonical_index_info.indices); } XLA_CHECK(device.has_value()); - XLATensorPtr self_tensor = bridge::GetOrCreateXlaTensor(self, *device); + XLATensorPtr xla_self = bridge::GetOrCreateXlaTensor(self, *device); tensor_methods::index_put_( - self_tensor, + xla_self, bridge::GetOrCreateXlaTensor(canonical_index_info.base, *device), bridge::GetOrCreateXlaTensors(canonical_index_info.indices, *device), canonical_index_info.start_dim, @@ -2025,9 +2075,10 @@ at::Tensor& XLANativeFunctions::index_put_( at::Tensor XLANativeFunctions::index_select(const at::Tensor& self, int64_t dim, const at::Tensor& index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::index_select( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - GetValueOrThrow(bridge::GetXlaTensor(index)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + return bridge::AtenFromXlaTensor( + tensor_methods::index_select(xla_self, dim, xla_index)); } at::Tensor XLANativeFunctions::kl_div(const at::Tensor& self, @@ -2040,8 +2091,8 @@ at::Tensor XLANativeFunctions::kl_div(const at::Tensor& self, std::tuple XLANativeFunctions::kthvalue( const at::Tensor& self, int64_t k, int64_t dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = tensor_methods::kthvalue( - GetValueOrThrow(bridge::GetXlaTensor(self)), k, dim, keepdim); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto results = tensor_methods::kthvalue(xla_self, k, dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -2056,9 +2107,11 @@ at::Tensor XLANativeFunctions::leaky_relu_backward( auto node_negative_slope = torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen( negative_slope, *common_device); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); torch::lazy::NodePtr node = torch_xla::MakeNode( - GetValueOrThrow(bridge::GetXlaTensor(grad_output))->GetIrValue(), - GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue(), + xla_grad_output->GetIrValue(), xla_self->GetIrValue(), node_negative_slope, self_is_result); return torch_xla::bridge::AtenFromXlaTensor( torch_xla::XLATensor::Create(std::move(node), *common_device)); @@ -2074,10 +2127,11 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, XLA_CHECK_EQ(self.dtype(), weight.dtype()) << "expected dtype " << self.dtype() << " for `weight` but got dtype " << weight.dtype(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_end, bridge::GetXlaTensor(end)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); return bridge::AtenFromXlaTensor( - tensor_methods::lerp(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(end)), - GetValueOrThrow(bridge::GetXlaTensor(weight)))); + tensor_methods::lerp(xla_self, xla_end, xla_weight)); } at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, @@ -2087,9 +2141,10 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, XLA_CHECK_EQ(self.dtype(), end.dtype()) << "expected dtype " << self.dtype() << " for `end` but got dtype " << end.dtype(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_end, bridge::GetXlaTensor(end)); return bridge::AtenFromXlaTensor( - tensor_methods::lerp(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(end)), weight)); + tensor_methods::lerp(xla_self, xla_end, weight)); } at::Tensor XLANativeFunctions::lift(const at::Tensor& tensor) { @@ -2118,8 +2173,9 @@ std::tuple XLANativeFunctions::linalg_inv_ex( } auto common_device = torch_xla::bridge::GetXlaDevice(self); TORCH_INTERNAL_ASSERT(common_device); - torch::lazy::NodePtr node = torch_xla::MakeNode( - GetValueOrThrow(bridge::GetXlaTensor(self))->GetIrValue()); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + torch::lazy::NodePtr node = + torch_xla::MakeNode(xla_self->GetIrValue()); auto result = torch_xla::XLATensor::Create(std::move(node), *common_device); auto info = tensor_methods::full_like(result, 0, result->GetDevice(), at::ScalarType::Int); @@ -2148,68 +2204,68 @@ at::Tensor XLANativeFunctions::linspace(const at::Scalar& start, at::Tensor XLANativeFunctions::log(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::log(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::log(xla_self)); } at::Tensor XLANativeFunctions::logit(const at::Tensor& self, std::optional eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::logit(GetValueOrThrow(bridge::GetXlaTensor(self)), eps)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::logit(xla_self, eps)); } at::Tensor XLANativeFunctions::log10(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::log_base(GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::OpKind(at::aten::log10), 10.0)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::log_base( + xla_self, torch::lazy::OpKind(at::aten::log10), 10.0)); } at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::log1p(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::log1p(xla_self)); } at::Tensor XLANativeFunctions::log2(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::log_base(GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::OpKind(at::aten::log2), 2.0)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::log_base( + xla_self, torch::lazy::OpKind(at::aten::log2), 2.0)); } at::Tensor XLANativeFunctions::logsumexp(const at::Tensor& self, at::IntArrayRef dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::logsumexp(GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(dim), + tensor_methods::logsumexp(xla_self, torch::lazy::ToVector(dim), /*keep_reduced_dimensions=*/keepdim)); } at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::xlogy(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(other)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); + return bridge::AtenFromXlaTensor(tensor_methods::xlogy(xla_self, xla_other)); } at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self, const at::Tensor& mask, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::masked_scatter( - self_tensor, GetValueOrThrow(bridge::GetXlaTensor(mask)), - GetValueOrThrow(bridge::GetXlaTensor(source)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mask, bridge::GetXlaTensor(mask)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_source, bridge::GetXlaTensor(source)); + return bridge::AtenFromXlaTensor( + tensor_methods::masked_scatter(xla_self, xla_mask, xla_source)); } at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, const at::Tensor& mask) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); // Initially make XLA handled masked_select() handling experimental, and // opt-in. if (!DebugUtil::ExperimentEnabled("masked_select")) { @@ -2217,21 +2273,23 @@ at::Tensor XLANativeFunctions::masked_select(const at::Tensor& self, ATEN_OP(masked_select)>::call(self, mask); } - return bridge::AtenFromXlaTensor(tensor_methods::masked_select( - self_tensor, GetValueOrThrow(bridge::GetXlaTensor(mask)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mask, bridge::GetXlaTensor(mask)); + return bridge::AtenFromXlaTensor( + tensor_methods::masked_select(xla_self, xla_mask)); } at::Tensor XLANativeFunctions::max(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::max(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::max(xla_self)); } std::tuple XLANativeFunctions::max( const at::Tensor& self, int64_t dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto outputs = tensor_methods::max( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, keepdim); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto outputs = tensor_methods::max(xla_self, dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -2240,12 +2298,11 @@ std::tuple XLANativeFunctions::max_out( const at::Tensor& self, int64_t dim, bool keepdim, at::Tensor& max, at::Tensor& max_values) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr max_tensor = GetValueOrThrow(bridge::GetXlaTensor(max)); - XLATensorPtr max_values_tensor = - GetValueOrThrow(bridge::GetXlaTensor(max_values)); - tensor_methods::max_out(max_tensor, max_values_tensor, - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - keepdim); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_max, bridge::GetXlaTensor(max)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_max_values, + bridge::GetXlaTensor(max_values)); + tensor_methods::max_out(xla_max, xla_max_values, xla_self, dim, keepdim); return std::forward_as_tuple(max, max_values); } @@ -2270,10 +2327,10 @@ std::tuple XLANativeFunctions::max_pool2d_with_indices( dilation, ceil_mode); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto outputs = tensor_methods::max_pool_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, - XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), ceil_mode); + xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -2293,11 +2350,13 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward( padding, dilation, ceil_mode, indices); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); + xla_grad_output, xla_self, /*spatial_dim_count=*/2, + XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), ceil_mode)); } at::Tensor XLANativeFunctions::max_pool3d( @@ -2323,11 +2382,13 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward( padding, dilation, ceil_mode, indices); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); + xla_grad_output, xla_self, /*spatial_dim_count=*/3, + XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), ceil_mode)); } std::tuple XLANativeFunctions::max_pool3d_with_indices( @@ -2343,10 +2404,10 @@ std::tuple XLANativeFunctions::max_pool3d_with_indices( dilation, ceil_mode); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto outputs = tensor_methods::max_pool_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, - XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), ceil_mode); + xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -2355,10 +2416,10 @@ at::Tensor XLANativeFunctions::max_unpool2d(const at::Tensor& self, const at::Tensor& indices, at::IntArrayRef output_size) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::max_unpool(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(indices)), - torch::lazy::ToVector(output_size))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_indices, bridge::GetXlaTensor(indices)); + return bridge::AtenFromXlaTensor(tensor_methods::max_unpool( + xla_self, xla_indices, torch::lazy::ToVector(output_size))); } at::Tensor XLANativeFunctions::max_unpool3d(const at::Tensor& self, @@ -2367,19 +2428,19 @@ at::Tensor XLANativeFunctions::max_unpool3d(const at::Tensor& self, at::IntArrayRef stride, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::max_unpool(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(indices)), - torch::lazy::ToVector(output_size))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_indices, bridge::GetXlaTensor(indices)); + return bridge::AtenFromXlaTensor(tensor_methods::max_unpool( + xla_self, xla_indices, torch::lazy::ToVector(output_size))); } at::Tensor XLANativeFunctions::mean(const at::Tensor& self, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::mean( - self_tensor, - torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), + xla_self, + torch::lazy::Iota(xla_self->shape().get().dimensions_size()), /*keep_reduced_dimensions=*/false, dtype)); } @@ -2387,64 +2448,64 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::mean( - self_tensor, - dim ? torch::lazy::ToVector(*dim) - : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), - keepdim, dtype)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::mean(xla_self, + dim ? torch::lazy::ToVector(*dim) + : torch::lazy::Iota( + xla_self->shape().get().dimensions_size()), + keepdim, dtype)); } at::Tensor XLANativeFunctions::min(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::min(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::min(xla_self)); } std::tuple XLANativeFunctions::min( const at::Tensor& self, int64_t dim, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto outputs = tensor_methods::min( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, keepdim); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto outputs = tensor_methods::min(xla_self, dim, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } at::Tensor XLANativeFunctions::mish(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::mish(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::mish(xla_self)); } std::tuple XLANativeFunctions::min_out( const at::Tensor& self, int64_t dim, bool keepdim, at::Tensor& min, at::Tensor& min_indices) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr min_tensor = GetValueOrThrow(bridge::GetXlaTensor(min)); - XLATensorPtr min_indices_tensor = - GetValueOrThrow(bridge::GetXlaTensor(min_indices)); - tensor_methods::min_out(min_tensor, min_indices_tensor, - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - keepdim); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_min, bridge::GetXlaTensor(min)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_min_indices, + bridge::GetXlaTensor(min_indices)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::min_out(xla_min, xla_min_indices, xla_self, dim, keepdim); return std::forward_as_tuple(min, min_indices); } at::Tensor XLANativeFunctions::mm(const at::Tensor& self, const at::Tensor& mat2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::mm( - /*input=*/GetValueOrThrow(bridge::GetXlaTensor(self)), - /*weight=*/GetValueOrThrow(bridge::GetXlaTensor(mat2)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2)); + return bridge::AtenFromXlaTensor(tensor_methods::mm(xla_self, xla_mat2)); } at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self, const at::Tensor& target, int64_t reduction) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::mse_loss( - GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(target)), reduction)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); + return bridge::AtenFromXlaTensor( + tensor_methods::mse_loss(xla_self, xla_target, reduction)); } at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, @@ -2452,10 +2513,12 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output, const at::Tensor& target, int64_t reduction) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); return bridge::AtenFromXlaTensor(tensor_methods::mse_loss_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(target)), reduction)); + xla_grad_output, xla_self, xla_target, reduction)); } at::Tensor XLANativeFunctions::mul(const at::Tensor& self, @@ -2499,26 +2562,26 @@ at::Tensor XLANativeFunctions::multinomial( replacement, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::multinomial(self_tensor, num_samples, replacement)); + tensor_methods::multinomial(xla_self, num_samples, replacement)); } at::Tensor XLANativeFunctions::mv(const at::Tensor& self, const at::Tensor& vec) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::mv(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(vec)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_vec, bridge::GetXlaTensor(vec)); + return bridge::AtenFromXlaTensor(tensor_methods::mv(xla_self, xla_vec)); } at::Tensor& XLANativeFunctions::mv_out(const at::Tensor& self, const at::Tensor& vec, at::Tensor& out) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr out_tensor = GetValueOrThrow(bridge::GetXlaTensor(out)); - tensor_methods::mv_out(out_tensor, - GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(vec))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_out, bridge::GetXlaTensor(out)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_vec, bridge::GetXlaTensor(vec)); + tensor_methods::mv_out(xla_out, xla_self, xla_vec); return out; } @@ -2531,8 +2594,8 @@ at::Tensor XLANativeFunctions::nan_to_num(const at::Tensor& self, if (!at::native::is_floating_point(self)) { return torch::lazy::CopyTensor(self); } - XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + const torch::lazy::BackendDevice& device = xla_self->GetDevice(); auto element_type = MakeXlaPrimitiveType(self.scalar_type(), &device); XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type); at::Scalar nan_replacement = nan.has_value() ? *nan : 0.0; @@ -2547,7 +2610,7 @@ at::Tensor XLANativeFunctions::nan_to_num(const at::Tensor& self, << min_max.min.toDouble() << ", " << min_max.max.toDouble() << "]."; } return bridge::AtenFromXlaTensor(tensor_methods::nan_to_num( - input_tensor, nan_replacement, posinf_replacement, neginf_replacement)); + xla_self, nan_replacement, posinf_replacement, neginf_replacement)); } std::tuple @@ -2558,17 +2621,16 @@ XLANativeFunctions::native_batch_norm( const std::optional& running_var, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); - const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - XLATensorPtr running_mean_tensor = + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + const torch::lazy::BackendDevice& device = xla_input->GetDevice(); + XLATensorPtr xla_running_mean = bridge::GetOrCreateXlaTensor(running_mean, device); - XLATensorPtr running_var_tensor = + XLATensorPtr xla_running_var = bridge::GetOrCreateXlaTensor(running_var, device); auto outputs = tensor_methods::native_batch_norm( - GetValueOrThrow(bridge::GetXlaTensor(input)), - bridge::GetOrCreateXlaTensor(weight, device), - bridge::GetOrCreateXlaTensor(bias, device), running_mean_tensor, - running_var_tensor, training, momentum, eps); + xla_input, bridge::GetOrCreateXlaTensor(weight, device), + bridge::GetOrCreateXlaTensor(bias, device), xla_running_mean, + xla_running_var, training, momentum, eps); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs)), bridge::AtenFromXlaTensor(std::get<2>(outputs))); @@ -2580,17 +2642,16 @@ XLANativeFunctions::_native_batch_norm_legit( const std::optional& bias, at::Tensor& running_mean, at::Tensor& running_var, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); - const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - XLATensorPtr running_mean_tensor = - GetValueOrThrow(bridge::GetXlaTensor(running_mean)); - XLATensorPtr running_var_tensor = - GetValueOrThrow(bridge::GetXlaTensor(running_var)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_running_mean, + bridge::GetXlaTensor(running_mean)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_running_var, + bridge::GetXlaTensor(running_var)); + const torch::lazy::BackendDevice& device = xla_input->GetDevice(); auto outputs = tensor_methods::native_batch_norm( - GetValueOrThrow(bridge::GetXlaTensor(input)), - bridge::GetOrCreateXlaTensor(weight, device), - bridge::GetOrCreateXlaTensor(bias, device), running_mean_tensor, - running_var_tensor, training, momentum, eps); + xla_input, bridge::GetOrCreateXlaTensor(weight, device), + bridge::GetOrCreateXlaTensor(bias, device), xla_running_mean, + xla_running_var, training, momentum, eps); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs)), bridge::AtenFromXlaTensor(std::get<2>(outputs))); @@ -2602,15 +2663,14 @@ XLANativeFunctions::_native_batch_norm_legit( const std::optional& bias, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); - const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - XLATensorPtr null_running_mean_tensor = XLATensorPtr(); - XLATensorPtr null_running_var_tensor = XLATensorPtr(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + XLATensorPtr xla_null_running_mean = XLATensorPtr(); + XLATensorPtr xla_null_running_var = XLATensorPtr(); + const torch::lazy::BackendDevice& device = xla_input->GetDevice(); auto outputs = tensor_methods::native_batch_norm( - GetValueOrThrow(bridge::GetXlaTensor(input)), - bridge::GetOrCreateXlaTensor(weight, device), - bridge::GetOrCreateXlaTensor(bias, device), null_running_mean_tensor, - null_running_var_tensor, training, momentum, eps); + xla_input, bridge::GetOrCreateXlaTensor(weight, device), + bridge::GetOrCreateXlaTensor(bias, device), xla_null_running_mean, + xla_null_running_var, training, momentum, eps); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs)), bridge::AtenFromXlaTensor(std::get<2>(outputs))); @@ -2626,13 +2686,12 @@ XLANativeFunctions::native_batch_norm_backward( const std::optional& save_invstd, bool train, double eps, std::array output_mask) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_out_tensor = - GetValueOrThrow(bridge::GetXlaTensor(grad_out)); - const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_out, + bridge::GetXlaTensor(grad_out)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); + const torch::lazy::BackendDevice& device = xla_grad_out->GetDevice(); auto gradients = tensor_methods::native_batch_norm_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_out)), - GetValueOrThrow(bridge::GetXlaTensor(input)), - bridge::GetOrCreateXlaTensor(weight, device), + xla_grad_out, xla_input, bridge::GetOrCreateXlaTensor(weight, device), bridge::GetOrCreateXlaTensor(save_mean, device), bridge::GetOrCreateXlaTensor(save_invstd, device), train, eps); at::Tensor undefined; @@ -2648,8 +2707,8 @@ XLANativeFunctions::native_batch_norm_backward( std::tuple XLANativeFunctions::native_dropout( const at::Tensor& self, double p, std::optional train) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - auto results = tensor_methods::native_dropout(self_tensor, p, train); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto results = tensor_methods::native_dropout(xla_self, p, train); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -2660,8 +2719,8 @@ at::Tensor XLANativeFunctions::neg(const at::Tensor& self) { << "Negation, the `-` operator, on a bool tensor is not supported. If " "you are trying to invert a mask, use the `~` or `logical_not()` " "operator instead."; - return bridge::AtenFromXlaTensor( - tensor_methods::neg(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::neg(xla_self)); } at::Tensor XLANativeFunctions::nll_loss2d_backward( @@ -2669,18 +2728,19 @@ at::Tensor XLANativeFunctions::nll_loss2d_backward( const at::Tensor& target, const std::optional& weight, int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr weight_tensor = - bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()); - XLATensorPtr total_weight_tensor; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); + const torch::lazy::BackendDevice& device = xla_grad_output->GetDevice(); + XLATensorPtr xla_weight = bridge::GetOrCreateXlaTensor(weight, device); + XLATensorPtr xla_total_weight; if (IsDefined(weight)) { - total_weight_tensor = - bridge::GetOrCreateXlaTensor(total_weight, self_tensor->GetDevice()); + xla_total_weight = bridge::GetOrCreateXlaTensor(total_weight, device); } return bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), self_tensor, - GetValueOrThrow(bridge::GetXlaTensor(target)), weight_tensor, reduction, - ignore_index, total_weight_tensor)); + xla_grad_output, xla_self, xla_target, xla_weight, reduction, + ignore_index, xla_total_weight)); } std::tuple XLANativeFunctions::nll_loss2d_forward( @@ -2688,15 +2748,16 @@ std::tuple XLANativeFunctions::nll_loss2d_forward( const std::optional& weight, int64_t reduction, int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full( - {}, 1, self_tensor->GetDevice(), self_tensor->dtype())); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); + const torch::lazy::BackendDevice& device = xla_self->GetDevice(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_total_weight, + tensor_methods::full({}, 1, device, xla_self->dtype())); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d( - self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), - bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()), + xla_self, xla_target, bridge::GetOrCreateXlaTensor(weight, device), reduction, ignore_index)), - bridge::AtenFromXlaTensor(total_weight)); + bridge::AtenFromXlaTensor(xla_total_weight)); } at::Tensor XLANativeFunctions::nll_loss_backward( @@ -2704,18 +2765,19 @@ at::Tensor XLANativeFunctions::nll_loss_backward( const at::Tensor& target, const std::optional& weight, int64_t reduction, int64_t ignore_index, const at::Tensor& total_weight) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr weight_tensor = - bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()); - XLATensorPtr total_weight_tensor; + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); + const torch::lazy::BackendDevice& device = xla_grad_output->GetDevice(); + XLATensorPtr xla_weight = bridge::GetOrCreateXlaTensor(weight, device); + XLATensorPtr xla_total_weight; if (IsDefined(weight)) { - total_weight_tensor = - bridge::GetOrCreateXlaTensor(total_weight, self_tensor->GetDevice()); + xla_total_weight = bridge::GetOrCreateXlaTensor(total_weight, device); } return bridge::AtenFromXlaTensor(tensor_methods::nll_loss_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), self_tensor, - GetValueOrThrow(bridge::GetXlaTensor(target)), weight_tensor, reduction, - ignore_index, total_weight_tensor)); + xla_grad_output, xla_self, xla_target, xla_weight, reduction, + ignore_index, xla_total_weight)); } std::tuple XLANativeFunctions::nll_loss_forward( @@ -2723,26 +2785,27 @@ std::tuple XLANativeFunctions::nll_loss_forward( const std::optional& weight, int64_t reduction, int64_t ignore_index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full( - {}, 1, self_tensor->GetDevice(), self_tensor->dtype())); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); + const torch::lazy::BackendDevice& device = xla_self->GetDevice(); + XLA_ASSIGN_OR_THROW(XLATensorPtr total_weight, + tensor_methods::full({}, 1, device, xla_self->dtype())); return std::make_tuple( bridge::AtenFromXlaTensor(tensor_methods::nll_loss( - self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)), - bridge::GetOrCreateXlaTensor(weight, self_tensor->GetDevice()), + xla_self, xla_target, bridge::GetOrCreateXlaTensor(weight, device), reduction, ignore_index)), bridge::AtenFromXlaTensor(total_weight)); } at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); // Initially make XLA handled nonzero() handling experimental, and opt-in. if (!DebugUtil::ExperimentEnabled("nonzero")) { return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(nonzero)>::call( self); } - return bridge::AtenFromXlaTensor(tensor_methods::nonzero(self_tensor)); + return bridge::AtenFromXlaTensor(tensor_methods::nonzero(xla_self)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2755,9 +2818,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(norm, ScalarOpt_dtype)>::call(self, p, dtype); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::norm(GetValueOrThrow(bridge::GetXlaTensor(self)), p, - dtype, {}, /*keepdim=*/false)); + tensor_methods::norm(xla_self, p, dtype, {}, /*keepdim=*/false)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2769,9 +2832,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, return at::native::call_fallback_fn<&xla_fallback, ATEN_OP2(norm, Scalar)>::call(self, p); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::norm(GetValueOrThrow(bridge::GetXlaTensor(self)), p, - std::nullopt, {}, /*keepdim=*/false)); + tensor_methods::norm(xla_self, p, std::nullopt, {}, /*keepdim=*/false)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2787,8 +2850,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, keepdim, dtype); } - return bridge::AtenFromXlaTensor(tensor_methods::norm( - GetValueOrThrow(bridge::GetXlaTensor(self)), p, dtype, dim, keepdim)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::norm(xla_self, p, dtype, dim, keepdim)); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, @@ -2802,9 +2866,9 @@ at::Tensor XLANativeFunctions::norm(const at::Tensor& self, &xla_fallback, ATEN_OP2(norm, ScalarOpt_dim)>::call(self, p, dim, keepdim); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::norm(GetValueOrThrow(bridge::GetXlaTensor(self)), p, - std::nullopt, dim, keepdim)); + tensor_methods::norm(xla_self, p, std::nullopt, dim, keepdim)); } at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, @@ -2815,8 +2879,8 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, double std, &xla_fallback, ATEN_OP2(normal, Tensor_float)>::call(mean, std, generator); } - return bridge::AtenFromXlaTensor( - tensor_methods::normal(GetValueOrThrow(bridge::GetXlaTensor(mean)), std)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mean, bridge::GetXlaTensor(mean)); + return bridge::AtenFromXlaTensor(tensor_methods::normal(xla_mean, std)); } at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, @@ -2827,8 +2891,8 @@ at::Tensor XLANativeFunctions::normal(double mean, const at::Tensor& std, &xla_fallback, ATEN_OP2(normal, float_Tensor)>::call(mean, std, generator); } - return bridge::AtenFromXlaTensor( - tensor_methods::normal(mean, GetValueOrThrow(bridge::GetXlaTensor(std)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_std, bridge::GetXlaTensor(std)); + return bridge::AtenFromXlaTensor(tensor_methods::normal(mean, xla_std)); } at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, @@ -2840,9 +2904,9 @@ at::Tensor XLANativeFunctions::normal(const at::Tensor& mean, &xla_fallback, ATEN_OP2(normal, Tensor_Tensor)>::call(mean, std, generator); } - return bridge::AtenFromXlaTensor( - tensor_methods::normal(GetValueOrThrow(bridge::GetXlaTensor(mean)), - GetValueOrThrow(bridge::GetXlaTensor(std)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mean, bridge::GetXlaTensor(mean)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_std, bridge::GetXlaTensor(std)); + return bridge::AtenFromXlaTensor(tensor_methods::normal(xla_mean, xla_std)); } at::Tensor& XLANativeFunctions::normal_( @@ -2853,16 +2917,17 @@ at::Tensor& XLANativeFunctions::normal_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(normal_)>::call( self, mean, std, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::normal_(self_tensor, mean, std); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::normal_(xla_self, mean, std); return self; } at::Tensor XLANativeFunctions::permute_copy(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::permute( - GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::permute(xla_self, XlaHelpers::I64List(dims))); } at::Tensor XLANativeFunctions::pow(const at::Tensor& self, @@ -2908,11 +2973,10 @@ at::Tensor XLANativeFunctions::_prelu_kernel(const at::Tensor& self, << weight_num << " and channel size = " << channel_size; } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr weight_tensor = GetValueOrThrow(bridge::GetXlaTensor(weight)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); - return bridge::AtenFromXlaTensor( - tensor_methods::prelu(self_tensor, weight_tensor)); + return bridge::AtenFromXlaTensor(tensor_methods::prelu(xla_self, xla_weight)); } std::tuple XLANativeFunctions::_prelu_kernel_backward( @@ -2920,13 +2984,13 @@ std::tuple XLANativeFunctions::_prelu_kernel_backward( const at::Tensor& weight) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_output_tensor = - GetValueOrThrow(bridge::GetXlaTensor(grad_output)); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLATensorPtr weight_tensor = GetValueOrThrow(bridge::GetXlaTensor(weight)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); - auto outputs = tensor_methods::prelu_backward(grad_output_tensor, self_tensor, - weight_tensor); + auto outputs = + tensor_methods::prelu_backward(xla_grad_output, xla_self, xla_weight); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)), bridge::AtenFromXlaTensor(std::get<1>(outputs))); } @@ -2934,10 +2998,10 @@ std::tuple XLANativeFunctions::_prelu_kernel_backward( at::Tensor XLANativeFunctions::prod(const at::Tensor& self, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::prod( - self_tensor, - torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), + xla_self, + torch::lazy::Iota(xla_self->shape().get().dimensions_size()), /*keep_reduced_dimensions=*/false, PromoteIntegralType(self.scalar_type(), dtype))); } @@ -2946,9 +3010,10 @@ at::Tensor XLANativeFunctions::prod(const at::Tensor& self, int64_t dim, bool keepdim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::prod( - GetValueOrThrow(bridge::GetXlaTensor(self)), {dim}, keepdim, - PromoteIntegralType(self.scalar_type(), dtype))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::prod(xla_self, {dim}, keepdim, + PromoteIntegralType(self.scalar_type(), dtype))); } void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, @@ -2959,8 +3024,8 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, // for in-place ops we have in hands. // 1) Aid XLA's InputOutputAlias. - auto input_tensor = GetValueOrThrow(bridge::GetXlaTensor(input)); - auto output_tensor = GetValueOrThrow(bridge::GetXlaTensor(output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr input_tensor, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr output_tensor, bridge::GetXlaTensor(output)); if (input_tensor->CurrentDataHandle() != nullptr || (input_tensor->CurrentIrValue().node != nullptr && torch_xla::DeviceData::Cast( @@ -3007,18 +3072,18 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, at::Tensor& XLANativeFunctions::put_(at::Tensor& self, const at::Tensor& index, const at::Tensor& source, bool accumulate) { - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::put_( - self_tensor, GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(source)), accumulate); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_source, bridge::GetXlaTensor(source)); + tensor_methods::put_(xla_self, xla_index, xla_source, accumulate); return self; } std::tuple XLANativeFunctions::qr( const at::Tensor& self, bool some) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = - tensor_methods::qr(GetValueOrThrow(bridge::GetXlaTensor(self)), some); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto results = tensor_methods::qr(xla_self, some); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -3033,16 +3098,17 @@ at::Tensor& XLANativeFunctions::random_( &xla_fallback, ATEN_OP2(random_, from)>::call(self, from, to, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - at::ScalarType dtype = self_tensor->dtype(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + at::ScalarType dtype = xla_self->dtype(); // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType(dtype) + inc; - OkOrThrow(CheckValueWithinTypeRange("random_", "from", dtype, from)); - OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to_val - 1)); - OkOrThrow(tensor_methods::random_(self_tensor, from, to_val)); + XLA_THROW_IF_ERROR(CheckValueWithinTypeRange("random_", "from", dtype, from)); + XLA_THROW_IF_ERROR( + CheckValueWithinTypeRange("random_", "to", dtype, to_val - 1)); + XLA_THROW_IF_ERROR(tensor_methods::random_(xla_self, from, to_val)); return self; } @@ -3055,12 +3121,11 @@ at::Tensor& XLANativeFunctions::random_( ATEN_OP2(random_, to)>::call(self, to, generator); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + at::ScalarType dtype = xla_self->dtype(); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - at::ScalarType dtype = self_tensor->dtype(); - - OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to - 1)); - OkOrThrow(tensor_methods::random_(self_tensor, 0, to)); + XLA_THROW_IF_ERROR(CheckValueWithinTypeRange("random_", "to", dtype, to - 1)); + XLA_THROW_IF_ERROR(tensor_methods::random_(xla_self, 0, to)); return self; } @@ -3072,14 +3137,14 @@ at::Tensor& XLANativeFunctions::random_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(random_)>::call( self, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - at::ScalarType dtype = self_tensor->dtype(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + at::ScalarType dtype = xla_self->dtype(); // Prevent "to_val" from overflowing with at::ScalarType::Long. int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1; int64_t to_val = GetIntegerUpperLimitForType(dtype) + inc; - OkOrThrow(tensor_methods::random_(self_tensor, 0, to_val)); + XLA_THROW_IF_ERROR(tensor_methods::random_(xla_self, 0, to_val)); return self; } @@ -3110,132 +3175,139 @@ at::Tensor XLANativeFunctions::randperm(int64_t n, at::Tensor XLANativeFunctions::reflection_pad1d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad1d( - GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(padding))); + xla_self, torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad1d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad1d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(padding))); + xla_grad_output, xla_self, torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad2d( - GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(padding))); + xla_self, torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad2d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad2d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(padding))); + xla_grad_output, xla_self, torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad3d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad3d( - GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(padding))); + xla_self, torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::reflection_pad3d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::reflection_pad3d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(padding))); + xla_grad_output, xla_self, torch::lazy::ToVector(padding))); } at::Tensor XLANativeFunctions::remainder(const at::Tensor& self, const at::Tensor& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(other)); return bridge::AtenFromXlaTensor( - tensor_methods::remainder(GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(other)))); + tensor_methods::remainder(xla_self, xla_other)); } at::Tensor XLANativeFunctions::remainder(const at::Tensor& self, const at::Scalar& other) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::remainder( - GetValueOrThrow(bridge::GetXlaTensor(self)), other)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::remainder(xla_self, other)); } at::Tensor XLANativeFunctions::replication_pad1d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad1d( - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(padding))); + xla_self, XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad1d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad1d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(padding))); + xla_grad_output, xla_self, XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad2d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad2d( - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(padding))); + xla_self, XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad2d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad2d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(padding))); + xla_grad_output, xla_self, XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad3d(const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad3d( - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(padding))); + xla_self, XlaHelpers::I64List(padding))); } at::Tensor XLANativeFunctions::replication_pad3d_backward( const at::Tensor& grad_output, const at::Tensor& self, at::IntArrayRef padding) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::replication_pad3d_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(padding))); + xla_grad_output, xla_self, XlaHelpers::I64List(padding))); } const at::Tensor& XLANativeFunctions::resize_( const at::Tensor& self, at::IntArrayRef size, std::optional /* memory_format */) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::resize_(self_tensor, XlaHelpers::I64List(size)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::resize_(xla_self, XlaHelpers::I64List(size)); return self; } @@ -3243,9 +3315,9 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::roll( - GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(shifts), - XlaHelpers::I64List(dims))); + xla_self, XlaHelpers::I64List(shifts), XlaHelpers::I64List(dims))); } at::Tensor XLANativeFunctions::rrelu_with_noise( @@ -3261,10 +3333,10 @@ at::Tensor XLANativeFunctions::rrelu_with_noise( upper, training, generator); } - XLATensorPtr noise_tensor = GetValueOrThrow(bridge::GetXlaTensor(noise)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_noise, bridge::GetXlaTensor(noise)); return bridge::AtenFromXlaTensor(tensor_methods::rrelu_with_noise( - GetValueOrThrow(bridge::GetXlaTensor(self)), noise_tensor, lower, upper, - training)); + xla_self, xla_noise, lower, upper, training)); } at::Tensor XLANativeFunctions::rrelu_with_noise_backward( @@ -3274,11 +3346,12 @@ at::Tensor XLANativeFunctions::rrelu_with_noise_backward( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); double negative_slope = (lower.to() + upper.to()) / 2; XLA_CHECK(!self_is_result || negative_slope > 0.0); - XLATensorPtr noise_tensor = GetValueOrThrow(bridge::GetXlaTensor(noise)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_noise, bridge::GetXlaTensor(noise)); return bridge::AtenFromXlaTensor(tensor_methods::rrelu_with_noise_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), noise_tensor, lower, upper, - training)); + xla_grad_output, xla_self, xla_noise, lower, upper, training)); } at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, @@ -3308,15 +3381,15 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src, std::optional reduce) { - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_src, bridge::GetXlaTensor(src)); if (!reduce.has_value()) { - return bridge::AtenFromXlaTensor(tensor_methods::scatter( - self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(src)))); + return bridge::AtenFromXlaTensor( + tensor_methods::scatter(xla_self, dim, xla_index, xla_src)); } else if (*reduce == "add") { - return bridge::AtenFromXlaTensor(tensor_methods::scatter_add( - self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(src)))); + return bridge::AtenFromXlaTensor( + tensor_methods::scatter_add(xla_self, dim, xla_index, xla_src)); } else { // TODO: implement scatter_mul return at::native::call_fallback_fn< @@ -3330,13 +3403,14 @@ at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, const at::Scalar& value, std::optional reduce) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); if (!reduce.has_value()) { - return bridge::AtenFromXlaTensor(tensor_methods::scatter( - self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), value)); + return bridge::AtenFromXlaTensor( + tensor_methods::scatter(xla_self, dim, xla_index, value)); } else if (*reduce == "add") { - return bridge::AtenFromXlaTensor(tensor_methods::scatter_add( - self_tensor, dim, GetValueOrThrow(bridge::GetXlaTensor(index)), value)); + return bridge::AtenFromXlaTensor( + tensor_methods::scatter_add(xla_self, dim, xla_index, value)); } else { // TODO: implement scatter_mul return at::native::call_fallback_fn< @@ -3388,13 +3462,14 @@ at::Tensor XLANativeFunctions::scatter_reduce( const at::Tensor& self, int64_t dim, const at::Tensor& index, const at::Tensor& src, std::string_view reduce, bool include_self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_index, bridge::GetXlaTensor(index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_src, bridge::GetXlaTensor(src)); if ((reduce == "sum" || reduce == "prod" || reduce == "amin" || reduce == "amax") && include_self) { return bridge::AtenFromXlaTensor(tensor_methods::scatter_reduce( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - GetValueOrThrow(bridge::GetXlaTensor(index)), - GetValueOrThrow(bridge::GetXlaTensor(src)), reduce, include_self)); + xla_self, dim, xla_index, xla_src, reduce, include_self)); } else { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP2(scatter_reduce, two)>::call(self, dim, index, @@ -3406,62 +3481,63 @@ at::Tensor XLANativeFunctions::scatter_reduce( at::Tensor XLANativeFunctions::select_copy(const at::Tensor& self, int64_t dim, int64_t index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::select( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim, index)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::select(xla_self, dim, index)); } at::Tensor XLANativeFunctions::select_scatter(const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, int64_t index) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto base_tensor = GetValueOrThrow(bridge::GetXlaTensor(base)); - auto base_tensor_shape = base_tensor->shape(); - auto mutated_view_tensor = - GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); - auto mutated_view_tensor_shape = mutated_view_tensor->shape(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_base, bridge::GetXlaTensor(base)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mutated_view, + bridge::GetXlaTensor(mutated_view)); + xla::Shape xla_base_shape = xla_base->shape(); + xla::Shape xla_mutated_view_shape = xla_mutated_view->shape(); auto common_device = torch_xla::bridge::GetXlaDevice(base); dim = torch::lazy::GetCanonicalDimensionIndex( - dim, base_tensor_shape.get().dimensions_size()); - xla::Shape narrow_shape = base_tensor_shape; + dim, xla_base_shape.dimensions_size()); + xla::Shape narrow_shape = xla_base_shape; narrow_shape.set_dimensions(dim, 1); - torch::lazy::NodePtr mutated_view_tensor_reshaped_node = - torch_xla::MakeNode( - mutated_view_tensor->GetIrValue(), - torch::lazy::ToVector(narrow_shape.dimensions())); + torch::lazy::NodePtr mutated_view_reshaped_node = torch_xla::MakeNode( + xla_mutated_view->GetIrValue(), + torch::lazy::ToVector(narrow_shape.dimensions())); - std::vector indices(base_tensor_shape.get().dimensions_size(), 0); + std::vector indices(xla_base_shape.dimensions_size(), 0); indices[dim] = torch::lazy::GetCanonicalPosition( - runtime::util::ToVector(base_tensor_shape.get().dimensions()), - dim, index); + runtime::util::ToVector(xla_base_shape.dimensions()), dim, + index); return bridge::AtenFromXlaTensor( - base_tensor->CreateFrom(torch_xla::MakeNode( - base_tensor->GetIrValue(), mutated_view_tensor_reshaped_node, - indices))); + xla_base->CreateFrom(torch_xla::MakeNode( + xla_base->GetIrValue(), mutated_view_reshaped_node, indices))); } // TODO(JackCaoG): Remove after elu being codegened at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::selu_(self_tensor); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::selu_(xla_self); return self; } at::Tensor& XLANativeFunctions::set_(at::Tensor& self, const at::Tensor& source) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr source_tensor = GetValueOrThrow(bridge::GetXlaTensor(source)); - OkOrThrow(bridge::ReplaceXlaTensor(self, source_tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_source, bridge::GetXlaTensor(source)); + XLA_THROW_IF_ERROR(bridge::ReplaceXlaTensor(self, xla_source)); return self; } at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::sigmoid_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(output)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_output, bridge::GetXlaTensor(output)); + return bridge::AtenFromXlaTensor( + tensor_methods::sigmoid_backward(xla_grad_output, xla_output)); } at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, @@ -3469,24 +3545,24 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, std::optional end, int64_t step) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( - tensor_methods::slice(GetValueOrThrow(bridge::GetXlaTensor(self)), dim, - start_val, end_val, step), - self)); + tensor_methods::slice(xla_self, dim, start_val, end_val, step), self)); } at::Tensor XLANativeFunctions::slice_scatter( const at::Tensor& base, const at::Tensor& mutated_view, int64_t dim, std::optional start, std::optional end, int64_t step) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto base_ = GetValueOrThrow(bridge::GetXlaTensor(base)); - auto mutated_view_ = GetValueOrThrow(bridge::GetXlaTensor(mutated_view)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_base, bridge::GetXlaTensor(base)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mutated_view, + bridge::GetXlaTensor(mutated_view)); int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; - auto input_shape = base_->shape(); + auto input_shape = xla_base->shape(); dim = torch::lazy::GetCanonicalDimensionIndex( dim, input_shape.get().dimensions_size()); start_val = torch::lazy::GetCanonicalPosition( @@ -3502,36 +3578,40 @@ at::Tensor XLANativeFunctions::slice_scatter( step = std::min(step, end_val - start_val); return bridge::AtenFromXlaTensor( - base_->CreateFrom(torch_xla::MakeNode( - base_->GetIrValue(), mutated_view_->GetIrValue(), dim, start_val, - end_val, step))); + xla_base->CreateFrom(torch_xla::MakeNode( + xla_base->GetIrValue(), xla_mutated_view->GetIrValue(), dim, + start_val, end_val, step))); } at::Tensor XLANativeFunctions::smooth_l1_loss(const at::Tensor& self, const at::Tensor& target, int64_t reduction, double beta) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::smooth_l1_loss( - GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(target)), reduction, beta)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); + return bridge::AtenFromXlaTensor( + tensor_methods::smooth_l1_loss(xla_self, xla_target, reduction, beta)); } at::Tensor XLANativeFunctions::smooth_l1_loss_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target, int64_t reduction, double beta) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_target, bridge::GetXlaTensor(target)); return bridge::AtenFromXlaTensor(tensor_methods::smooth_l1_loss_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - GetValueOrThrow(bridge::GetXlaTensor(target)), reduction, beta)); + xla_grad_output, xla_self, xla_target, reduction, beta)); } at::Tensor XLANativeFunctions::softplus(const at::Tensor& self, const at::Scalar& beta, const at::Scalar& threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::softplus( - GetValueOrThrow(bridge::GetXlaTensor(self)), beta, threshold)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::softplus(xla_self, beta, threshold)); } at::Tensor XLANativeFunctions::softplus_backward(const at::Tensor& grad_output, @@ -3539,17 +3619,19 @@ at::Tensor XLANativeFunctions::softplus_backward(const at::Tensor& grad_output, const at::Scalar& beta, const at::Scalar& threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::softplus_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), beta, threshold)); + xla_grad_output, xla_self, beta, threshold)); } std::tuple XLANativeFunctions::sort( const at::Tensor& self, int64_t dim, bool descending) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = tensor_methods::topk( - GetValueOrThrow(bridge::GetXlaTensor(self)), self.size(dim), dim, - descending, /*sorted=*/true, /*stable=*/false); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto results = tensor_methods::topk(xla_self, self.size(dim), dim, descending, + /*sorted=*/true, /*stable=*/false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -3558,10 +3640,9 @@ std::tuple XLANativeFunctions::sort( const at::Tensor& self, std::optional stable, int64_t dim, bool descending) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto results = tensor_methods::topk( - GetValueOrThrow(bridge::GetXlaTensor(self)), self.size(dim), dim, - descending, - /*sorted=*/false, + xla_self, self.size(dim), dim, descending, /*sorted=*/false, /*stable=*/stable.has_value() ? stable.value() : false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); @@ -3571,39 +3652,39 @@ std::vector XLANativeFunctions::split_copy(const at::Tensor& self, int64_t split_size, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto xla_tensors = tensor_methods::split( - GetValueOrThrow(bridge::GetXlaTensor(self)), split_size, dim); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto xla_tensors = tensor_methods::split(xla_self, split_size, dim); return bridge::AtenFromXlaTensors(xla_tensors); } std::vector XLANativeFunctions::split_with_sizes_copy( const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto xla_tensors = tensor_methods::split_with_sizes( - GetValueOrThrow(bridge::GetXlaTensor(self)), - XlaHelpers::I64List(split_sizes), dim); + xla_self, XlaHelpers::I64List(split_sizes), dim); return bridge::AtenFromXlaTensors(xla_tensors); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::squeeze(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::squeeze(xla_self)); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::squeeze( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::squeeze(xla_self, dim)); } at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, at::IntArrayRef dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::squeeze(GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(dim))); + tensor_methods::squeeze(xla_self, torch::lazy::ToVector(dim))); } at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { @@ -3612,16 +3693,17 @@ at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { std::vector c_tensors(tensors.size()); std::transform(tensors.begin(), tensors.end(), c_tensors.begin(), [=](const at::Tensor& t) { return t.to(result_type); }); - return bridge::AtenFromXlaTensor(tensor_methods::stack( - GetValueOrThrow(bridge::GetXlaTensors(c_tensors)), dim)); + XLA_ASSIGN_OR_THROW(std::vector xla_c_tensors, + bridge::GetXlaTensors(c_tensors)); + return bridge::AtenFromXlaTensor(tensor_methods::stack(xla_c_tensors, dim)); } at::Tensor XLANativeFunctions::std(const at::Tensor& self, bool unbiased) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::std( - self_tensor, - torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), + xla_self, + torch::lazy::Iota(xla_self->shape().get().dimensions_size()), /*keep_reduced_dimensions=*/false, /*correction=*/unbiased ? 1.0 : 0.0)); } @@ -3629,13 +3711,13 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::std( - self_tensor, - dim ? torch::lazy::ToVector(*dim) - : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), - keepdim, /*correction=*/unbiased ? 1.0 : 0.0)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::std(xla_self, + dim ? torch::lazy::ToVector(*dim) + : torch::lazy::Iota( + xla_self->shape().get().dimensions_size()), + keepdim, /*correction=*/unbiased ? 1.0 : 0.0)); } at::Tensor XLANativeFunctions::std(const at::Tensor& self, @@ -3643,25 +3725,25 @@ at::Tensor XLANativeFunctions::std(const at::Tensor& self, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::std( - self_tensor, - dim ? torch::lazy::ToVector(*dim) - : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), - keepdim, correction ? correction->toDouble() : 1.0)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::std(xla_self, + dim ? torch::lazy::ToVector(*dim) + : torch::lazy::Iota( + xla_self->shape().get().dimensions_size()), + keepdim, correction ? correction->toDouble() : 1.0)); } std::tuple XLANativeFunctions::std_mean( const at::Tensor& self, at::OptionalIntArrayRef dim, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto results = tensor_methods::std_mean( - self_tensor, + xla_self, dim ? torch::lazy::ToVector(*dim) : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), + xla_self->shape().get().dimensions_size()), correction ? correction->toDouble() : 1.0, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); @@ -3702,10 +3784,10 @@ at::Tensor XLANativeFunctions::sub(const at::Tensor& self, at::Tensor XLANativeFunctions::sum(const at::Tensor& self, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::sum( - self_tensor, - torch::lazy::Iota(self_tensor->shape().get().dimensions_size()), + xla_self, + torch::lazy::Iota(xla_self->shape().get().dimensions_size()), /*keep_reduced_dimensions=*/false, dtype)); } @@ -3713,20 +3795,20 @@ at::Tensor XLANativeFunctions::sum(const at::Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional dtype) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::sum( - self_tensor, - dim ? torch::lazy::ToVector(*dim) - : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), - keepdim, dtype)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::sum(xla_self, + dim ? torch::lazy::ToVector(*dim) + : torch::lazy::Iota( + xla_self->shape().get().dimensions_size()), + keepdim, dtype)); } std::tuple XLANativeFunctions::svd( const at::Tensor& self, bool some, bool compute_uv) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - auto results = tensor_methods::svd( - GetValueOrThrow(bridge::GetXlaTensor(self)), some, compute_uv); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto results = tensor_methods::svd(xla_self, some, compute_uv); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results)), bridge::AtenFromXlaTensor(std::get<2>(results))); @@ -3734,57 +3816,62 @@ std::tuple XLANativeFunctions::svd( at::Tensor XLANativeFunctions::t_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::transpose( - GetValueOrThrow(bridge::GetXlaTensor(self)), 0, 1)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::transpose(xla_self, 0, 1)); } at::Tensor XLANativeFunctions::tanh_backward(const at::Tensor& grad_output, const at::Tensor& output) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::tanh_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(output)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_output, bridge::GetXlaTensor(output)); + return bridge::AtenFromXlaTensor( + tensor_methods::tanh_backward(xla_grad_output, xla_output)); } at::Tensor XLANativeFunctions::threshold(const at::Tensor& self, const at::Scalar& threshold, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::threshold(GetValueOrThrow(bridge::GetXlaTensor(self)), - threshold.to(), value.to())); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::threshold( + xla_self, threshold.to(), value.to())); } at::Tensor XLANativeFunctions::threshold_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::threshold_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), threshold.to())); + xla_grad_output, xla_self, threshold.to())); } std::tuple XLANativeFunctions::topk( const at::Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto results = - tensor_methods::topk(GetValueOrThrow(bridge::GetXlaTensor(self)), k, dim, - largest, sorted, /*stable=*/false); + tensor_methods::topk(xla_self, k, dim, largest, sorted, /*stable=*/false); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::trace(GetValueOrThrow(bridge::GetXlaTensor(self)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self)); } at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, int64_t dim0, int64_t dim1) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::transpose( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim0, dim1)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::transpose(xla_self, dim0, dim1)); } std::tuple XLANativeFunctions::triangular_solve( @@ -3793,10 +3880,10 @@ std::tuple XLANativeFunctions::triangular_solve( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); // Currently, ATen doesn't have a left_side option. Once this // is added, this API will have to be changed. + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_b, bridge::GetXlaTensor(b)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_A, bridge::GetXlaTensor(A)); auto results = tensor_methods::triangular_solve( - GetValueOrThrow(bridge::GetXlaTensor(b)), - GetValueOrThrow(bridge::GetXlaTensor(A)), /*left_side=*/true, upper, - transpose, unitriangular); + xla_b, xla_A, /*left_side=*/true, upper, transpose, unitriangular); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); } @@ -3804,8 +3891,8 @@ std::tuple XLANativeFunctions::triangular_solve( std::vector XLANativeFunctions::unbind_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensors( - tensor_methods::unbind(GetValueOrThrow(bridge::GetXlaTensor(self)), dim)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensors(tensor_methods::unbind(xla_self, dim)); } at::Tensor& XLANativeFunctions::uniform_( @@ -3816,25 +3903,24 @@ at::Tensor& XLANativeFunctions::uniform_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call( self, from, to, generator); } - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::uniform_(self_tensor, from, to); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::uniform_(xla_self, from, to); return self; } at::Tensor XLANativeFunctions::unsqueeze_copy(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::unsqueeze( - GetValueOrThrow(bridge::GetXlaTensor(self)), dim)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::unsqueeze(xla_self, dim)); } at::Tensor XLANativeFunctions::upsample_bilinear2d( const at::Tensor& self, at::IntArrayRef output_size, bool align_corners, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - absl::Span input_dims = - self_tensor->shape().get().dimensions(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + absl::Span input_dims = xla_self->shape().get().dimensions(); std::vector scaled_output_size = torch::lazy::ToVector(output_size); if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) { @@ -3847,7 +3933,7 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d( } } return bridge::AtenFromXlaTensor(tensor_methods::upsample_bilinear2d( - self_tensor, scaled_output_size, align_corners)); + xla_self, scaled_output_size, align_corners)); } at::Tensor XLANativeFunctions::upsample_bilinear2d_backward( @@ -3855,12 +3941,12 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d_backward( at::IntArrayRef input_size, bool align_corners, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_output_tensor = - GetValueOrThrow(bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); // Only the XLA TPU backend for now implements the CustomCall required by // our XLA lowering. XlaDeviceType hw_type = - static_cast(grad_output_tensor->GetDevice().type()); + static_cast(xla_grad_output->GetDevice().type()); if (!CheckTpuDevice(hw_type)) { return at::native::call_fallback_fn< &xla_fallback, @@ -3880,7 +3966,7 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d_backward( } } return bridge::AtenFromXlaTensor(tensor_methods::upsample_bilinear2d_backward( - grad_output_tensor, torch::lazy::ToVector(scaled_output_size), + xla_grad_output, torch::lazy::ToVector(scaled_output_size), torch::lazy::ToVector(input_size), align_corners)); } @@ -3888,9 +3974,8 @@ at::Tensor XLANativeFunctions::upsample_nearest2d( const at::Tensor& self, at::IntArrayRef output_size, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - absl::Span input_dims = - self_tensor->shape().get().dimensions(); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + absl::Span input_dims = xla_self->shape().get().dimensions(); std::vector scaled_output_size = torch::lazy::ToVector(output_size); if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) { @@ -3903,7 +3988,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d( } } return bridge::AtenFromXlaTensor( - tensor_methods::upsample_nearest2d(self_tensor, scaled_output_size)); + tensor_methods::upsample_nearest2d(xla_self, scaled_output_size)); } at::Tensor XLANativeFunctions::upsample_nearest2d_backward( @@ -3911,12 +3996,12 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward( at::IntArrayRef input_size, std::optional scales_h, std::optional scales_w) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr grad_output_tensor = - GetValueOrThrow(bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); // Only the XLA TPU backend for now implements the CustomCall required by // our XLA lowering. XlaDeviceType hw_type = - static_cast(grad_output_tensor->GetDevice().type()); + static_cast(xla_grad_output->GetDevice().type()); if (!CheckTpuDevice(hw_type) && !CheckNeuronDevice(hw_type)) { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(upsample_nearest2d_backward)>::call(grad_output, @@ -3937,7 +4022,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward( } } return bridge::AtenFromXlaTensor(tensor_methods::upsample_nearest2d_backward( - grad_output_tensor, torch::lazy::ToVector(scaled_output_size), + xla_grad_output, torch::lazy::ToVector(scaled_output_size), torch::lazy::ToVector(input_size))); } @@ -3946,15 +4031,12 @@ at::Tensor XLANativeFunctions::var(const at::Tensor& self, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::var(self_tensor, + tensor_methods::var(xla_self, dim ? XlaHelpers::I64List(*dim) : torch::lazy::Iota( - GetValueOrThrow(bridge::GetXlaTensor(self)) - ->shape() - .get() - .dimensions_size()), + xla_self->shape().get().dimensions_size()), correction ? correction->toDouble() : 1.0, keepdim)); } @@ -3962,12 +4044,12 @@ std::tuple XLANativeFunctions::var_mean( const at::Tensor& self, at::OptionalIntArrayRef dim, const std::optional& correction, bool keepdim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto results = tensor_methods::var_mean( - self_tensor, + xla_self, dim ? torch::lazy::ToVector(*dim) : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), + xla_self->shape().get().dimensions_size()), correction ? correction->toDouble() : 1.0, keepdim); return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)), bridge::AtenFromXlaTensor(std::get<1>(results))); @@ -3983,9 +4065,9 @@ at::Tensor XLANativeFunctions::view_as_complex_copy(const at::Tensor& self) { "tensors, but got a tensor of scalar type: " << self.scalar_type(); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor( - tensor_methods::view_as_complex_copy(self_tensor)); + tensor_methods::view_as_complex_copy(xla_self)); } at::Tensor XLANativeFunctions::view_as_real_copy(const at::Tensor& self) { @@ -3995,9 +4077,8 @@ at::Tensor XLANativeFunctions::view_as_real_copy(const at::Tensor& self) { "tensors, but got a tensor of scalar type: " << self.scalar_type(); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor( - tensor_methods::view_as_real_copy(self_tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor(tensor_methods::view_as_real_copy(xla_self)); } at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, @@ -4005,7 +4086,7 @@ at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::optional int_shape = c10::asIntArrayRefSlowOpt(shape); bool input_shape_static = int_shape.has_value(); - XLATensorPtr xla_input = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(self)); bool input_has_dyn_shape = xla_input->shape().get().is_dynamic(); XLA_CHECK(!(input_has_dyn_shape && input_shape_static)) @@ -4023,16 +4104,18 @@ at::Tensor XLANativeFunctions::where(const at::Tensor& condition, c10::MaybeOwned b_condition, b_self, b_other; std::tie(b_condition, b_self, b_other) = xla_expand_outplace(condition, self, other, "where"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_condition, + bridge::GetXlaTensor(*b_condition)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(*b_self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_other, bridge::GetXlaTensor(*b_other)); return bridge::AtenFromXlaTensor( - tensor_methods::where(GetValueOrThrow(bridge::GetXlaTensor(*b_condition)), - GetValueOrThrow(bridge::GetXlaTensor(*b_self)), - GetValueOrThrow(bridge::GetXlaTensor(*b_other)))); + tensor_methods::where(xla_condition, xla_self, xla_other)); } at::Tensor& XLANativeFunctions::zero_(at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - tensor_methods::zero_(self_tensor); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + tensor_methods::zero_(xla_self); return self; } @@ -4044,18 +4127,18 @@ std::tuple XLANativeFunctions::_linalg_svd( // As per https://pytorch.org/docs/stable/generated/torch.svd.html, // The second boolean argument is exactly opposite between // torch::svd and torch::_linalg_svd, hence the negation of full_matrices. - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - auto results = tensor_methods::svd(self_tensor, !full_matrices, compute_uv); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + auto results = tensor_methods::svd(xla_self, !full_matrices, compute_uv); auto u = std::get<0>(results); auto s = std::get<1>(results); auto vh = tensor_methods::transpose(std::get<2>(results), 0, 1); if (!compute_uv) { // When compute_uv is false, torch::_linalg_svd returns an empty tensor for // u and vh. - u = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(), - self_tensor->dtype())); - vh = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(), - self_tensor->dtype())); + XLA_ASSIGN_OR_THROW(u, tensor_methods::full({0}, 0, xla_self->GetDevice(), + xla_self->dtype())); + XLA_ASSIGN_OR_THROW(vh, tensor_methods::full({0}, 0, xla_self->GetDevice(), + xla_self->dtype())); } return std::make_tuple(bridge::AtenFromXlaTensor(u), bridge::AtenFromXlaTensor(s), @@ -4065,8 +4148,8 @@ std::tuple XLANativeFunctions::_linalg_svd( at::Scalar XLANativeFunctions::_local_scalar_dense(const at::Tensor& self) { if (DebugUtil::ExperimentEnabled("early_sync")) { // sync tensors in order to save computation when step is marked later. - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); - XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&self_tensor->GetDevice(), + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&xla_self->GetDevice(), /*devices=*/{}, /*wait=*/true); TORCH_LAZY_COUNTER("EarlySyncLiveTensorsCount", 1); @@ -4106,23 +4189,21 @@ at::Tensor XLANativeFunctions::_cdist_forward( // (compute_mode is 0 or 1) is achieved through composite ops from // native pytorch. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_x1, bridge::GetXlaTensor(x1)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_x2, bridge::GetXlaTensor(x2)); XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0"; - return bridge::AtenFromXlaTensor(tensor_methods::cdist_forward( - GetValueOrThrow(bridge::GetXlaTensor(x1)), - GetValueOrThrow(bridge::GetXlaTensor(x2)), p)); + return bridge::AtenFromXlaTensor( + tensor_methods::cdist_forward(xla_x1, xla_x2, p)); } at::Tensor XLANativeFunctions::_pdist_forward(const at::Tensor& self, double p) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0"; - XLA_CHECK(GetValueOrThrow(bridge::GetXlaTensor(self)) - ->shape() - .get() - .dimensions_size() == 2) + XLA_CHECK(xla_self->shape().get().dimensions_size() == 2) << "pdist only support 2d dimension"; - return bridge::AtenFromXlaTensor(tensor_methods::pdist_forward( - GetValueOrThrow(bridge::GetXlaTensor(self)), p)); + return bridge::AtenFromXlaTensor(tensor_methods::pdist_forward(xla_self, p)); } // All of the below ops correspond to CompositeExplicitAutograd kernels from @@ -4206,24 +4287,24 @@ XLANativeFunctions::convolution_backward( at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, std::optional dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); std::vector dims; if (dim) { dims = torch::lazy::GetCanonicalDimensionIndices( - {dim.value()}, xla_tensor->shape().get().dimensions_size()); + {dim.value()}, xla_self->shape().get().dimensions_size()); } return bridge::AtenFromXlaTensor( - tensor_methods::count_nonzero(xla_tensor, dims)); + tensor_methods::count_nonzero(xla_self, dims)); } at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, at::IntArrayRef dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); std::vector canonical_dims = torch::lazy::GetCanonicalDimensionIndices( - dim, xla_tensor->shape().get().dimensions_size()); + dim, xla_self->shape().get().dimensions_size()); std::unordered_set dims_set; for (int dim : canonical_dims) { XLA_CHECK(dims_set.find(dim) == dims_set.end()) @@ -4232,7 +4313,7 @@ at::Tensor XLANativeFunctions::count_nonzero(const at::Tensor& self, } return bridge::AtenFromXlaTensor( - tensor_methods::count_nonzero(xla_tensor, XlaHelpers::I64List(dim))); + tensor_methods::count_nonzero(xla_self, XlaHelpers::I64List(dim))); } at::Tensor XLANativeFunctions::diag_embed(const at::Tensor& self, @@ -4257,9 +4338,10 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight, // TODO: We need to make use of the TPU embedding core here eventually. TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::embedding( - GetValueOrThrow(bridge::GetXlaTensor(weight)), - GetValueOrThrow(bridge::GetXlaTensor(indices)))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_weight, bridge::GetXlaTensor(weight)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_indices, bridge::GetXlaTensor(indices)); + return bridge::AtenFromXlaTensor( + tensor_methods::embedding(xla_weight, xla_indices)); } at::Tensor XLANativeFunctions::_euclidean_dist(const at::Tensor& x1, @@ -4296,8 +4378,9 @@ at::Tensor XLANativeFunctions::narrow_copy_symint(const at::Tensor& self, at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self, int64_t upscale_factor) { - return bridge::AtenFromXlaTensor(tensor_methods::pixel_shuffle( - GetValueOrThrow(bridge::GetXlaTensor(self)), upscale_factor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::pixel_shuffle(xla_self, upscale_factor)); } at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self, @@ -4382,12 +4465,12 @@ at::Tensor XLANativeFunctions::linalg_vector_norm( TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_CHECK(at::isFloatingType(self.scalar_type())) << "Input must be a floating type"; - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); return bridge::AtenFromXlaTensor(tensor_methods::linalg_vector_norm( - self_tensor, ord, + xla_self, ord, dim ? torch::lazy::ToVector(*dim) : torch::lazy::Iota( - self_tensor->shape().get().dimensions_size()), + xla_self->shape().get().dimensions_size()), keepdim, dtype)); } @@ -4428,34 +4511,34 @@ at::Tensor XLANativeFunctions::as_strided( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, std::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); - if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, + if (!AsStrided::StrideIsSupported(xla_self->shape(), xsize, xstride, storage_offset.value_or(0))) { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(as_strided)>::call(self, size, stride, storage_offset); } - return bridge::AtenFromXlaTensor(tensor_methods::as_strided( - self_tensor, std::move(xsize), std::move(xstride), - XlaHelpers::I64Optional(storage_offset))); + return bridge::AtenFromXlaTensor( + tensor_methods::as_strided(xla_self, std::move(xsize), std::move(xstride), + XlaHelpers::I64Optional(storage_offset))); } const at::Tensor& XLANativeFunctions::as_strided_( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, std::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); - if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, + if (!AsStrided::StrideIsSupported(xla_self->shape(), xsize, xstride, storage_offset.value_or(0))) { return at::native::call_fallback_fn< &xla_fallback, ATEN_OP(as_strided_)>::call(self, size, stride, storage_offset); } - tensor_methods::as_strided_(self_tensor, std::move(xsize), std::move(xstride), + tensor_methods::as_strided_(xla_self, std::move(xsize), std::move(xstride), XlaHelpers::I64Optional(storage_offset)); return self; } @@ -4463,8 +4546,9 @@ const at::Tensor& XLANativeFunctions::as_strided_( at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::diagonal( - GetValueOrThrow(bridge::GetXlaTensor(self)), offset, dim1, dim2)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::diagonal(xla_self, offset, dim1, dim2)); } at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, @@ -4472,15 +4556,15 @@ at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, bool implicit) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::optional size = c10::asIntArrayRefSlowOpt(sym_size); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); if (size.has_value()) { - return bridge::AtenFromXlaTensor( - tensor_methods::expand(GetValueOrThrow(bridge::GetXlaTensor(self)), - torch::lazy::ToVector(*size))); + return bridge::AtenFromXlaTensor(tensor_methods::expand( + xla_self, torch::lazy::ToVector(*size))); } else { // at least one of the dimension is symbolic, use the sym_int version of the // node - return bridge::AtenFromXlaTensor(tensor_methods::expand_symint( - GetValueOrThrow(bridge::GetXlaTensor(self)), sym_size)); + return bridge::AtenFromXlaTensor( + tensor_methods::expand_symint(xla_self, sym_size)); } } @@ -4491,8 +4575,9 @@ at::Tensor XLANativeFunctions::view_symint(const at::Tensor& self, // support dynamic shape. auto size = C10_AS_INTARRAYREF_SLOW(sym_size); TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::view( - GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(size))); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); + return bridge::AtenFromXlaTensor( + tensor_methods::view(xla_self, XlaHelpers::I64List(size))); } } // namespace torch_xla From 1bc7737381e49ceea01629ae97d110c31441327c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 12:38:52 -0300 Subject: [PATCH 075/133] Remove CUDA plugin. (#9597) This PR removes the CUDA plugin in `plugins/cuda` directory, as well as explicit mentions and usages. This is in line with the CUDA deprecation that started on release 2.8. There are still a few mentions to `torch_xla_cuda_plugin` at `.github/ci.md` markdown file. But, I will leave those for a future PR, since we would need to re-write those parts. --- .github/scripts/run_tests.sh | 9 +-- infra/ansible/Dockerfile | 1 - infra/ansible/e2e_tests.Dockerfile | 1 - infra/ansible/playbook.yaml | 13 +--- .../roles/build_plugin/tasks/main.yaml | 32 ---------- plugins/cuda/README.md | 39 ------------ plugins/cuda/pyproject.toml | 18 ------ plugins/cuda/setup.py | 17 ----- .../cuda/torch_xla_cuda_plugin/__init__.py | 62 ------------------- 9 files changed, 2 insertions(+), 190 deletions(-) delete mode 100644 infra/ansible/roles/build_plugin/tasks/main.yaml delete mode 100644 plugins/cuda/README.md delete mode 100644 plugins/cuda/pyproject.toml delete mode 100644 plugins/cuda/setup.py delete mode 100644 plugins/cuda/torch_xla_cuda_plugin/__init__.py diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index ccdc0b5e3d70..7ae422c47953 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -30,14 +30,7 @@ function run_torch_xla_cpp_tests() { TORCH_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch').get_filename()))") export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${TORCH_DIR}/lib - if [ -x "$(command -v nvidia-smi)" ]; then - CUDA_PLUGIN_DIR=$(python -c "import pkgutil; import os; print(os.path.dirname(pkgutil.get_loader('torch_xla_cuda_plugin').get_filename()))") - export PJRT_LIBRARY_PATH=$CUDA_PLUGIN_DIR/lib/pjrt_c_api_gpu_plugin.so - export PJRT_DEVICE=LIBRARY - export PJRT_DYNAMIC_PLUGINS=1 - else - export PJRT_DEVICE=CPU - fi + export PJRT_DEVICE=CPU export XLA_EXPERIMENTAL="nonzero:masked_select:nms" test_names=("test_aten_xla_tensor_1" diff --git a/infra/ansible/Dockerfile b/infra/ansible/Dockerfile index 278b290fec43..3875442e3747 100644 --- a/infra/ansible/Dockerfile +++ b/infra/ansible/Dockerfile @@ -10,7 +10,6 @@ COPY . /ansible ARG ansible_vars # HACK: install build dependencies only, but skip build step RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}" --tags "bazel,configure_env,install_deps" -RUN ansible-playbook -vvv playbook.yaml -e "stage=build_plugin" -e "${ansible_vars}" RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}" --skip-tags=fetch_srcs,install_deps FROM python:${python_version}-${debian_version} AS release diff --git a/infra/ansible/e2e_tests.Dockerfile b/infra/ansible/e2e_tests.Dockerfile index be8c0d11f144..2a097e803f0c 100644 --- a/infra/ansible/e2e_tests.Dockerfile +++ b/infra/ansible/e2e_tests.Dockerfile @@ -10,7 +10,6 @@ COPY . /ansible # Build PyTorch and PyTorch/XLA wheels. ARG ansible_vars RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}" -RUN ansible-playbook -vvv playbook.yaml -e "stage=build_plugin" -e "${ansible_vars}" --skip-tags=fetch_srcs,install_deps FROM python:${python_version}-${debian_version} diff --git a/infra/ansible/playbook.yaml b/infra/ansible/playbook.yaml index 3f3c736d133a..7626714e8d18 100644 --- a/infra/ansible/playbook.yaml +++ b/infra/ansible/playbook.yaml @@ -16,7 +16,7 @@ "Pass the required variable with: --e \"{{ item.name }}=\"" loop: - name: stage - pattern: ^(build|build_plugin|release)$ + pattern: ^(build|release)$ - name: arch pattern: ^(aarch64|amd64)$ - name: accelerator @@ -88,17 +88,6 @@ when: stage == "build" tags: build_srcs - - role: build_plugin - vars: - src_root: "/src" - env_vars: "{{ - build_env.common | default({}, true) | - combine(build_env[arch] | default({}, true)) | - combine(build_env[accelerator] | default({}, true)) - }}" - when: stage == "build_plugin" - tags: build_plugin - - role: configure_env vars: env_vars: "{{ diff --git a/infra/ansible/roles/build_plugin/tasks/main.yaml b/infra/ansible/roles/build_plugin/tasks/main.yaml deleted file mode 100644 index 142d29c3718f..000000000000 --- a/infra/ansible/roles/build_plugin/tasks/main.yaml +++ /dev/null @@ -1,32 +0,0 @@ -- name: Create /dist directory for exported wheels - ansible.builtin.file: - path: /dist - state: directory - mode: '0755' - -- name: Build PyTorch/XLA CUDA Plugin - ansible.builtin.command: - cmd: pip wheel -w /dist plugins/cuda -v - chdir: "{{ (src_root, 'pytorch/xla') | path_join }}" - environment: "{{ env_vars }}" - when: accelerator == "cuda" - -- name: Find CUDA plugin wheel pytorch/xla/dist - ansible.builtin.find: - path: "/dist" - pattern: "torch_xla_cuda_plugin*.whl" - when: accelerator == "cuda" - register: plugin_wheels - -- name: Install CUDA plugin wheels - ansible.builtin.pip: - name: "{{ plugin_wheels.files | map(attribute='path') }}" - state: "forcereinstall" - when: accelerator == "cuda" - -# TODO: Pass libtpu to next release stage somehow. This only runs during build -- name: Install libtpu - ansible.builtin.pip: - name: torch_xla[tpu] - extra_args: -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html - when: accelerator == "tpuvm" diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md deleted file mode 100644 index 45a002e06f6c..000000000000 --- a/plugins/cuda/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# CUDA PJRT plugin (experimental) - -This directory contains an experimental implementation of the PJRT GPU client as -a plugin. The actual implementation of the PJRT C API lives in the main OpenXLA -repository (see `bazel build` command below). - -## Building - -See our [contributing guide](../../CONTRIBUTING.md) for build environment setup -steps. - -```bash -# Build wheel -pip wheel plugins/cuda -v -# Or install directly -pip install plugins/cuda -v -``` - -## Usage - -```python -import os - -# Log device type -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' -os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' - -from torch_xla.experimental import plugins -import torch_xla_cuda_plugin -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -# Use dynamic plugin instead of built-in CUDA support -plugins.use_dynamic_plugins() -plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin()) -xr.set_device_type('CUDA') - -print(torch_xla.device()) -``` diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml deleted file mode 100644 index d44a2ea3bd53..000000000000 --- a/plugins/cuda/pyproject.toml +++ /dev/null @@ -1,18 +0,0 @@ -[build-system] -requires = ["setuptools", "numpy"] -build-backend = "setuptools.build_meta" - -[project] -name = "torch_xla_cuda_plugin" -authors = [ - {name = "PyTorch/XLA Dev Team", email = "pytorch-xla@googlegroups.com"}, -] -description = "PyTorch/XLA CUDA Plugin" -requires-python = ">=3.8" -dynamic = ["version"] - -[tool.setuptools.package-data] -torch_xla_cuda_plugin = ["lib/*.so"] - -[project.entry-points."torch_xla.plugins"] -cuda = "torch_xla_cuda_plugin:CudaPlugin" diff --git a/plugins/cuda/setup.py b/plugins/cuda/setup.py deleted file mode 100644 index 2652880c6fd7..000000000000 --- a/plugins/cuda/setup.py +++ /dev/null @@ -1,17 +0,0 @@ -import datetime -import os -import sys - -# add `build_util` to import path -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) - -import build_util -import setuptools - -build_util.bazel_build('@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so', - 'torch_xla_cuda_plugin/lib', ['--config=cuda']) - -setuptools.setup( - # TODO: Use a common version file - version=os.getenv('TORCH_XLA_VERSION', - f'2.8.0.dev{datetime.date.today().strftime("%Y%m%d")}')) diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py deleted file mode 100644 index e6863ff711a1..000000000000 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -from torch_xla.experimental import plugins -import torch_xla.utils.utils as xu - - -class CudaPlugin(plugins.DevicePlugin): - - def _get_process_rank(self) -> int: - local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int, - xu.getenv_as("LOCAL_RANK", int, 0)) - global_process_rank = xu.getenv_as("RANK", int, local_process_rank) - - return local_process_rank, global_process_rank - - def _get_world_size(self) -> int: - local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int, - xu.getenv_as("LOCAL_WORLD_SIZE", int, 1)) - global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size) - - return local_world_size, global_world_size - - def library_path(self) -> str: - return os.path.join( - os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so') - - def physical_chip_count(self) -> int: - # TODO: default to actual device count - return xu.getenv_as('GPU_NUM_DEVICES', int, 1) - - def configure_single_process(self): - pass - - def client_create_options(self) -> dict: - local_process_rank, global_process_rank = self._get_process_rank() - local_world_size, global_world_size = self._get_world_size() - - # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67 - options = { - "platform_name": - "gpu", - # TODO(wcromar): make this configurable - "allocator": - "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, - False) else "default", - "memory_fraction": - xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None), - "preallocate": - xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None), - # Use all devices by default and when using SPMD - "visible_devices": [local_process_rank] - if local_world_size > 1 else None, - "node_id": - global_process_rank, - "num_nodes": - global_world_size, - } - - return {k: v for k, v in options.items() if v is not None} - - def requires_xla_coordinator(self) -> bool: - _, global_world_size = self._get_world_size() - return global_world_size > 1 From d4cf42a8e432633279808d0eca9f133ce0f938d4 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 14:32:58 -0300 Subject: [PATCH 076/133] Remove triton. (#9601) This PR removes all triton code (e.g. test and implementation) from the Pytorch/XLA codebase. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - (`docs/source/features/triton.md`) Remove the documentation - (`test/test_triton.py`) Remove the triton tests - (`torch_xla/experimental/triton.py`) Remove the triton implementation --- README.md | 2 +- docs/source/features/triton.md | 75 ------- docs/source/index.rst | 1 - test/test_triton.py | 336 ------------------------------- torch_xla/experimental/triton.py | 232 --------------------- 5 files changed, 1 insertion(+), 645 deletions(-) delete mode 100644 docs/source/features/triton.md delete mode 100644 test/test_triton.py delete mode 100644 torch_xla/experimental/triton.py diff --git a/README.md b/README.md index 11d5a5712560..989858ef16fd 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ Our github contains many useful docs on working with different aspects of PyTorc - [docs/source/learn](https://github.com/pytorch/xla/tree/master/docs/source/learn): docs for learning concepts associated with XLA, troubleshooting, pjrt, eager mode, and dynamic shape. - [docs/source/accelerators](https://github.com/pytorch/xla/tree/master/docs/source/accelerators): references to `GPU` and `TPU` accelerator documents. - [docs/source/perf](https://github.com/pytorch/xla/tree/master/docs/source/perf): documentation about performance specific aspects of PyTorch/XLA such as: `AMP`, `DDP`, `Dynamo`, Fori loop, `FSDP`, quantization, recompilation, and `SPMD` -- [docs/source/features](https://github.com/pytorch/xla/tree/master/docs/source/features): documentation on distributed torch, pallas, scan, stable hlo, and triton. +- [docs/source/features](https://github.com/pytorch/xla/tree/master/docs/source/features): documentation on distributed torch, pallas, scan, and stable hlo. - [docs/source/contribute](https://github.com/pytorch/xla/tree/master/docs/source/contribute): documents on setting up PyTorch for development, and guides for lowering operations. - PJRT plugins: - [CPU](https://github.com/pytorch/xla/blob/master/plugins/cpu/README.md) diff --git a/docs/source/features/triton.md b/docs/source/features/triton.md deleted file mode 100644 index 33bf1a4d5861..000000000000 --- a/docs/source/features/triton.md +++ /dev/null @@ -1,75 +0,0 @@ -# Custom GPU Kernels via Triton - -PyTorch/XLA now supports [Triton](https://openai.com/research/triton) -kernels, enabling high-performance deep learning model execution on -GPUs. Triton, a specialized language and compiler for GPU programming, -empowers developers to write custom kernels that leverage the full -potential of GPUs for various operations in deep learning models. - -Given a Triton kernel defined as follows: - -``` python3 -@triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. -): - # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28 - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) -``` - -We can run make this kernel a part of the PyTorch/XLA execution graph as -follows: - -``` python3 -import torch - -import torch_xla.experimental.triton as xla_triton -import torch_xla - -import triton -import triton.language as tl - -size = 16 -x = torch.arange(size, dtype=torch.int64).to('xla') -y = torch.arange(size, dtype=torch.int64).to('xla') -output = torch.empty_like(x) -block_size = 8 -grid = (triton.cdiv(size, block_size),) - -# triton_call takes the same arguments as the triton.jit function, in addition -# to the kernel itself and the grid that is used to execute the kernel. -# All the tl.constexpr terms are passed as kwargs at the end. -payload = xla_triton.triton_call( - x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size) - -# To make the triton kernel, a part of the PyTorch/XLA graph, we create a -# custom call node with the expected inputs, payload from triton_call, -# the output shapes and output dtypes. The payload already contains information -# regarding how the GPU buffers will be loaded when this node is executed. -output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload, - [output.shape], [torch.int64]) -``` - -For more complex kernels, you can also refer to the Triton Flash -Attention kernel test in PyTorch/XLA. - -## Dependencies - -The Triton integration depends on the `triton` package to function. This -code is tested with `triton==2.3.0`. To install: - -``` bash -pip install --no-deps triton==2.3.0 -``` diff --git a/docs/source/index.rst b/docs/source/index.rst index f3271936ef5d..acc92900fe43 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -90,7 +90,6 @@ Tutorials :caption: Training on GPU accelerators/gpu - features/triton perf/spmd_gpu .. toctree:: diff --git a/test/test_triton.py b/test/test_triton.py deleted file mode 100644 index f69def68c86d..000000000000 --- a/test/test_triton.py +++ /dev/null @@ -1,336 +0,0 @@ -import logging -import torch -from torch import nn as nn -import unittest - -import torch_xla.experimental.triton as xla_triton -import torch_xla -from torch_xla import runtime as xr -from torch_xla.test.test_utils import skipIfCUDA - -import triton -import triton.language as tl - - -@triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. -): - # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28 - # There are multiple 'programs' processing different data. We identify which program - # we are here: - pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. - # This program will process inputs that are offset from the initial data. - # For instance, if you had a vector of length 256 and block_size of 64, the programs - # would each access the elements [0:64, 64:128, 128:192, 192:256]. - # Note that offsets is a list of pointers: - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Create a mask to guard memory operations against out-of-bounds accesses. - mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case the input is not a - # multiple of the block size. - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - # Write x + y back to DRAM. - tl.store(output_ptr + offsets, output, mask=mask) - - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, # - K_block_ptr, - V_block_ptr, # - start_m, - qk_scale, # - BLOCK_M: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, - offs_m: tl.constexpr, - offs_n: tl.constexpr, # - N_CTX: tl.constexpr, - fp8_v: tl.constexpr): - # range of values handled by this stage - if STAGE == 1: - lo, hi = 0, start_m * BLOCK_M - elif STAGE == 2: - lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M - lo = tl.multiple_of(lo, BLOCK_M) - # causal = False - else: - lo, hi = 0, N_CTX - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - qk = tl.dot(q, k) - if STAGE == 2: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij - # -- update output accumulator -- - acc = acc * alpha[:, None] - # update acc - v = tl.load(V_block_ptr) - if fp8_v: - p = p.to(tl.float8e5) - else: - p = p.to(tl.float16) - acc = tl.dot(p, v, acc) - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - return acc, l_i, m_i - - -@triton.jit -def _attn_fwd( - Q, - K, - V, - sm_scale, - M, - Out, # - stride_qz, - stride_qh, - stride_qm, - stride_qk, # - stride_kz, - stride_kh, - stride_kn, - stride_kk, # - stride_vz, - stride_vh, - stride_vk, - stride_vn, # - stride_oz, - stride_oh, - stride_om, - stride_on, # - Z, - H, - N_CTX, # - BLOCK_M: tl.constexpr, # - BLOCK_N: tl.constexpr, # - HEAD_DIM: tl.constexpr, # - STAGE: tl.constexpr # -): - tl.static_assert(BLOCK_N <= HEAD_DIM) - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh - - # block pointers - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=(1, 0), - ) - K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(HEAD_DIM, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(HEAD_DIM, BLOCK_N), - order=(0, 1), - ) - O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, HEAD_DIM), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), - order=(1, 0), - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) - # stage 1: off-band - # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE - # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE - if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, # - start_m, - qk_scale, # - BLOCK_M, - HEAD_DIM, - BLOCK_N, # - 4 - STAGE, - offs_m, - offs_n, - N_CTX, - V.dtype.element_ty == tl.float8e5 # - ) - # stage 2: on-band - if STAGE & 2: - # barrier makes it easier for compielr to schedule the - # two loops independently - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, # - start_m, - qk_scale, # - BLOCK_M, - HEAD_DIM, - BLOCK_N, # - 2, - offs_m, - offs_n, - N_CTX, - V.dtype.element_ty == tl.float8e5 # - ) - # epilogue - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -class TritonTest(unittest.TestCase): - - @unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.") - def test_gpu_custom_call_triton_add(self): - size = 16 - - x = torch.arange(size, dtype=torch.int64).to('xla') - y = torch.arange(size, dtype=torch.int64).to('xla') - output = torch.empty_like(x) - block_size = 8 - grid = (triton.cdiv(size, block_size),) - payload = xla_triton.triton_call( - x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size) - output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload, - [output.shape], [torch.int64]) - output_torch = x + y - self.assertTrue(torch.allclose(output[0].cpu(), output_torch.cpu())) - - @unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.") - def test_gpu_custom_call_triton_flash_attention(self): - torch.manual_seed(20) - Z, H, N_CTX, HEAD_DIM = (1, 2, 1024, 64) - causal = False - stage = 3 if causal else 1 - dtype = torch.float16 - q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device='xla') - k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device='xla') - v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device='xla') - sm_scale = 0.5 - # reference implementation - triangle = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if causal: - p[:, :, triangle == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - - ref_out = torch.matmul(p, v) - # triton implementation - - o = torch.empty_like(q) - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), - device=q.device, - dtype=torch.float32) - BLOCK_N = 32 - BLOCK_M = 64 - grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * - q.shape[1], 1) - payload = xla_triton.triton_call( - q, - k, - v, - sm_scale, - M, - o, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), # - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), # - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # - q.shape[0], - q.shape[1], - q.shape[2], - kernel=_attn_fwd, - grid=grid, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - HEAD_DIM=HEAD_DIM, - STAGE=stage) - - output = torch_xla._XLAC._xla_gpu_custom_call([q, k, v, M], payload, - [o.shape], [torch.float16]) - # compare - assert torch.allclose(ref_out, output[0], atol=1e-2, rtol=0) - - -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - torch.set_default_dtype(torch.float32) - torch.manual_seed(42) - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/triton.py b/torch_xla/experimental/triton.py deleted file mode 100644 index 40410d6c0199..000000000000 --- a/torch_xla/experimental/triton.py +++ /dev/null @@ -1,232 +0,0 @@ -"""Module for calling Triton kernels from Pytorch/XLA. - -Reference: https://github.com/jax-ml/jax-triton/blob/main/jax_triton/triton_lib.py - -""" - -from __future__ import annotations - -import os -from typing import Any, Callable, Dict, Tuple, Union -import zlib -import torch - -import numpy as np -import triton -import triton.language as tl -from jax._src.lib import gpu_triton as lib_triton -import torch_xla - -# Register target corresponding to gpu custom call using the -# implementation provided by jaxlib. -torch_xla._XLAC._xla_register_custom_call_target( - 'triton_kernel_call', lib_triton._cuda_triton.get_custom_call(), 'CUDA') - -Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] -GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]] - -NUM_WARPS = 4 -NUM_STAGES = 3 -NUM_CTAS = 1 - - -def normalize_grid(grid: GridOrLambda, metaparams) -> Tuple[int, int, int]: - if callable(grid): - grid = grid(metaparams) - if isinstance(grid, int): - grid = (grid,) - elif len(grid) > 3: - raise ValueError("`grid` should have three or fewer dimensions.") - return tuple(grid) + (1,) * (3 - len(grid)) - - -_TORCH_TO_TRITON_TYPE_MAP = { - torch.bfloat16: - "bf16", - torch.float64: - "fp64", - torch.float32: - "fp32", - torch.float16: - "fp16", - # Triton has 'fp8' as well which Jax doesn't support yet. - torch.int64: - "i64", - torch.int32: - "i32", - torch.int16: - "i16", - torch.int8: - "i8", - torch.uint64: - "u64", - torch.uint32: - "u32", - torch.uint16: - "u16", - torch.uint8: - "u8", - # Triton defines a 'B' type, which is an alias for both i1 and bool. - torch.bool: - "B", -} - - -def get_triton_type(obj: Any) -> str: - if torch.is_tensor(obj): - return f"*{_TORCH_TO_TRITON_TYPE_MAP[obj.dtype]}" - if isinstance(obj, tl.constexpr): - obj = obj.value - if isinstance(obj, int): - if -(2**31) <= obj < 2**31: - return "i32" - elif 2**31 <= obj < 2**32: - return "u32" - elif -(2**63) <= obj < 2**63: - return "i64" - elif 2**63 <= obj < 2**64: - return "u64" - else: - raise ValueError(f"integer overflow representing {obj}") - if isinstance(obj, float): - return "fp64" - if isinstance(obj, np.float32): - return "fp32" - if isinstance(obj, bool): - return "B" - if isinstance(obj, str): - return "str" - raise NotImplementedError( - f"could not compute type name for {obj}: {type(obj)}") - - -def get_or_create_triton_kernel( - fn, - compiled_kernel, - args, - dump, -) -> Tuple[lib_triton.TritonKernel, Any]: - # Extract the compilation parameters and compiled ptx from the - # compiled triton kernel. - ttir = compiled_kernel.asm['ttir'] - ptx = compiled_kernel.asm['ptx'] - if (dump): - print(ptx) - - shared_mem_bytes = compiled_kernel.metadata.shared - kernel_name = compiled_kernel.metadata.name - cluster_dims = compiled_kernel.metadata.cluster_dims - compute_capability = lib_triton.get_compute_capability(0) - kernel = lib_triton.TritonKernel( - kernel_name, - NUM_WARPS, - shared_mem_bytes, - ptx, - ttir, - compute_capability, - *cluster_dims, - ) - - return kernel - - -# Taken from: https://github.com/triton-lang/triton/blob/da40a1e984bf57c4708daf603eb427442025f99b/python/triton/runtime/jit.py#L187-L198 -# Newer triton versions removed this function. -def _spec_and_divisible_by_16(fn, i, arg): - if i in fn.do_not_specialize: - return False - - if hasattr(arg, "data_ptr"): - return arg.data_ptr() % 16 == 0 - if isinstance(arg, int): - return arg % 16 == 0 - - return arg is None - - -# Taken from: https://github.com/triton-lang/triton/blob/da40a1e984bf57c4708daf603eb427442025f99b/python/triton/runtime/jit.py#L187-L198 -# Newer triton versions removed this function. -def _spec_and_equals_1(fn, i, arg): - if i in fn.do_not_specialize: - return False - return not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 - - -def triton_kernel_call_lowering( - array_args, - fn, - compiled_kernel, - scalar_args, - grid, - debug, - **metaparams, -): - args = list(array_args) - arg_dtypes = list(map(get_triton_type, array_args)) - for idx, dtype, v in scalar_args: - args.insert(idx, v) - arg_dtypes.insert(idx, dtype) - - if not isinstance(fn, triton.JITFunction): - raise ValueError("`kernel` must be a Triton `JITFunction`.") - - #TODO: Add support for autotuner and heuristic functions. - config = triton.Config( - {}, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - num_ctas=NUM_CTAS, - ) - config_metaparams = {**metaparams, **config.kwargs} - config_grid = normalize_grid(grid, config_metaparams) - - kernel = get_or_create_triton_kernel( - fn, - compiled_kernel, - args, - dump=debug, - ) - - kernel_params = [] - for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)): - if isinstance(arg, torch.Tensor): - kernel_params.append( - lib_triton.create_array_parameter( - 0, - 16 if _spec_and_divisible_by_16(fn, i, arg) else 0, - )) - elif not _spec_and_equals_1(fn, i, arg): - kernel_params.append(lib_triton.create_scalar_parameter(arg, dtype)) - - kernel_call = lib_triton.TritonKernelCall( - kernel, - config_grid[0], - config_grid[1], - config_grid[2], - kernel_params, - ) - - call_proto = kernel_call.to_proto("triton_kernel", b"") - return zlib.compress(call_proto) - - -def triton_call( - *args: Union[torch.Tensor, bool, int, float, np.float32], - kernel: triton.JITFunction, - grid: GridOrLambda, - debug: bool = False, - **metaparams: Any, -) -> Any: - array_args = [] - scalar_args = [] - for i, arg in enumerate(args): - if isinstance(arg, (bool, int, float)): - scalar_args.append((i, get_triton_type(arg), arg)) - elif isinstance(arg, np.float32): - scalar_args.append((i, get_triton_type(arg), float(arg))) - else: - array_args.append(arg) - - compiled_kernel = kernel.run(*args, grid=grid, warmup=True, **metaparams) - return triton_kernel_call_lowering(array_args, kernel, compiled_kernel, - scalar_args, grid, debug, **metaparams) From f5a22187d9535fe62c8f7103e8402a8df92c908a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 15:21:31 -0300 Subject: [PATCH 077/133] `torch_xla`: Use new macros for throwing exceptions (part 1). (#9593) Follow-up: #9588 and #9580 Target: `torch_xla/csrc` directory In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `torch_xla/csrc` directory, replacing every use of those, now deprecated, functions by the newly introduced macros. _Note: since there were lots of files in `torch_xla/csrc` that needed update, they were split in multiple parts._ --- torch_xla/csrc/aten_xla_bridge.cpp | 18 ++++++++++++------ torch_xla/csrc/debug_util.cpp | 6 +++--- torch_xla/csrc/dl_convertor.cpp | 11 ++++++----- torch_xla/csrc/helpers.cpp | 11 +++++++---- torch_xla/csrc/lowering_context.cpp | 2 +- torch_xla/csrc/tensor.cpp | 8 +++++--- torch_xla/csrc/xla_backend_impl.cpp | 7 +++++-- torch_xla/csrc/xla_manual_registration.cpp | 4 ++-- torch_xla/csrc/xla_op_builder.cpp | 12 ++++++------ torch_xla/csrc/xla_sharding_util.cpp | 2 +- 10 files changed, 48 insertions(+), 33 deletions(-) diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 8bc0cc32a615..8af6f5816756 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -170,7 +170,9 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber( (tensor.dim() == 0 && tensor.numel() == 1)) { return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device); } else { - return GetValueOrThrow(torch_xla::bridge::GetXlaTensor(tensor)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, + torch_xla::bridge::GetXlaTensor(tensor)); + return xla_tensor; } } @@ -186,9 +188,13 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor, } auto xtensor = GetXlaTensor(tensor); - return xtensor.ok() - ? xtensor.value() - : GetValueOrThrow(XLATensor::Create(inner_tensor, device)); + if (xtensor.ok()) { + return xtensor.value(); + } + + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, + XLATensor::Create(inner_tensor, device)); + return xla_tensor; } XLATensorPtr GetOrCreateXlaTensor(const std::optional& tensor, @@ -479,8 +485,8 @@ at::Tensor CreateXlaTensor( at::Tensor tensor, const std::optional& device) { if (tensor.defined() && device) { - XLATensorPtr xla_tensor = - GetValueOrThrow(XLATensor::Create(std::move(tensor), *device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor, + XLATensor::Create(std::move(tensor), *device)); tensor = AtenFromXlaTensor(xla_tensor); } return tensor; diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 812f7122efa3..81c76e386602 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -450,9 +450,9 @@ void DebugUtil::post_compilation_analysis( // This can be used to verify the hash of the underlying computation proto. // Note that for UserComputation computations, the protobuf is factored in // the graph hash. - std::string serialized_computation = - GetValueOrThrow(runtime::util::GetDeterministicSerializedModuleProto( - computation->computation().proto())); + XLA_ASSIGN_OR_THROW(std::string serialized_computation, + runtime::util::GetDeterministicSerializedModuleProto( + computation->computation().proto())); ss << "\n" << "Computation hash: " << torch::lazy::HashToString(torch::lazy::Hash(serialized_computation)) diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index 638bcdbff67b..c4f8fc38efa2 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -141,10 +141,10 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { DLTensor& dt = pack->tensor.dl_tensor; { // AcquireExternalReference may block - pack->external_reference = - GetValueOrThrow(pjrt_buffer->AcquireExternalReference()); + XLA_ASSIGN_OR_THROW(pack->external_reference, + pjrt_buffer->AcquireExternalReference()); xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture(); - OkOrThrow(future.Await()); + XLA_THROW_IF_ERROR(future.Await()); } pack->buffer_reference = pjrt_buffer; @@ -329,8 +329,9 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; } - std::unique_ptr pjrt_buffer = - GetValueOrThrow(device->client()->CreateViewOfDeviceBuffer( + XLA_ASSIGN_OR_THROW( + std::unique_ptr pjrt_buffer, + device->client()->CreateViewOfDeviceBuffer( static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset, shape, *device->default_memory_space(), on_delete_callback)); diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 05d2a307becf..4808bf9db8a1 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -41,7 +41,8 @@ xla::XlaComputation CreateComputation( xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); xla::XlaOp y = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); - return GetValueOrThrow(builder.Build(op(x, y))); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, builder.Build(op(x, y))); + return computation; } xla::XlaComputation CreateMinMaxComputation(const std::string& name, @@ -66,7 +67,8 @@ xla::XlaComputation CreateMinMaxComputation(const std::string& name, xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index); arg_max = xla::Select(eq, tie_id, arg_max); xla::Tuple(&builder, {max, arg_max}); - return GetValueOrThrow(builder.Build()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation min_max_computation, builder.Build()); + return min_max_computation; } } // namespace @@ -697,7 +699,8 @@ std::vector XlaHelpers::getBroadcastDimensions(xla::XlaOp op1, xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, const xla::Shape& shape2) { if (!shape1.is_dynamic() && !shape2.is_dynamic()) { - auto promoted_shape = GetValueOrThrow(GetPromotedShape(shape1, shape2)); + XLA_ASSIGN_OR_THROW(xla::Shape promoted_shape, + GetPromotedShape(shape1, shape2)); return xla::ShapeUtil::MakeShape( PromoteType(shape1.element_type(), shape2.element_type()), promoted_shape.dimensions()); @@ -776,7 +779,7 @@ std::pair XlaHelpers::PromoteShapes(xla::XlaOp op1, const xla::Shape& shape1 = ShapeHelper::ShapeOfXlaOp(op1); const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp(op2); - xla::Shape shape = GetValueOrThrow(GetPromotedShape(shape1, shape2)); + XLA_ASSIGN_OR_THROW(xla::Shape shape, GetPromotedShape(shape1, shape2)); if (shape1.is_unbounded_dynamic() || shape2.is_unbounded_dynamic()) { return ImplicitBroadcastWithUnboundedDynamicShapes(op1, op2, shape); } diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 65f438643f17..4e9c2fd013ba 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -325,7 +325,7 @@ void LoweringContext::AddParameter(const torch::lazy::Output& output, } torch::lazy::ComputationPtr LoweringContext::Build() { - xla::XlaComputation xla_computation = GetValueOrThrow(BuildXla()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation, BuildXla()); return std::make_shared( builder_.name(), std::move(xla_computation), device_); } diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 106b2603e843..d80b27b16e14 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -518,8 +518,8 @@ at::Tensor XLATensor::ToTensor(bool detached) { XLAGraphExecutor::Get()->DeviceBarrier(GetDevice()); // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR // XlaNode is available on the tensor. - std::vector tensors = - GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()})); + XLA_ASSIGN_OR_THROW(std::vector tensors, + XlaDataToTensors({GetXlaData()}, {dtype()})); tensor = std::move(tensors.front()); if (!detached) { SetTensorData(tensor); @@ -627,7 +627,9 @@ std::vector XLATensor::MakeOutputTensors( XLATensorPtr XLATensor::CopyTensorToDevice( const torch::lazy::BackendDevice& device) { // TODO: This can be optimized via proper XRT/XLA computation. - return GetValueOrThrow(Create(ToTensor(/*detached=*/true), device)); + XLA_ASSIGN_OR_THROW(XLATensorPtr result, + Create(ToTensor(/*detached=*/true), device)); + return result; } torch::lazy::Value XLATensor::MaybeCastIrValue( diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp index 78f8548ff178..39e488307619 100644 --- a/torch_xla/csrc/xla_backend_impl.cpp +++ b/torch_xla/csrc/xla_backend_impl.cpp @@ -94,7 +94,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { const torch::lazy::BackendDataPtr data, std::optional logical_scalar_type) const override { // TODO(JackCaoG): handle the logical_scalar_type == nullptr case - return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0]; + XLA_ASSIGN_OR_THROW(std::vector tensors, + XlaDataToTensors({data}, {*logical_scalar_type})); + return tensors[0]; } std::unique_ptr CreateLoweringContext( @@ -163,7 +165,8 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const torch::lazy::BackendDevice& device) const override { - std::vector results = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + std::vector results, runtime::GetComputationClientOrDie()->ExecuteComputation( *std::dynamic_pointer_cast( computation), diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp index f439d4634bdd..f5e5f74f3805 100644 --- a/torch_xla/csrc/xla_manual_registration.cpp +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -38,8 +38,8 @@ at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, XLA_CHECK_EQ(boxes.size(0), scores.size(0)) << "nms(): boxes and scores should have the same size for dimension 0."; - XLATensorPtr xla_boxes = GetValueOrThrow(bridge::GetXlaTensor(boxes)); - XLATensorPtr xla_scores = GetValueOrThrow(bridge::GetXlaTensor(scores)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_boxes, bridge::GetXlaTensor(boxes)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_scores, bridge::GetXlaTensor(scores)); return bridge::AtenFromXlaTensor( tensor_methods::nms(xla_boxes, xla_scores, iou_threshold), /*skip_functionalization=*/true); diff --git a/torch_xla/csrc/xla_op_builder.cpp b/torch_xla/csrc/xla_op_builder.cpp index 30a4da0e68fe..3403f4af42b9 100644 --- a/torch_xla/csrc/xla_op_builder.cpp +++ b/torch_xla/csrc/xla_op_builder.cpp @@ -690,8 +690,8 @@ xla::XlaOp ConcatInDim(const BuilderPtr& builder, xla::XlaOp Convert(const BuilderPtr& builder, const std::vector& operands, py::dict args) { std::string type = args["to_type"].cast(); - xla::PrimitiveType xla_type = - GetValueOrThrow(xla::primitive_util::StringToPrimitiveType(type)); + XLA_ASSIGN_OR_THROW(xla::PrimitiveType xla_type, + xla::primitive_util::StringToPrimitiveType(type)); return MaybeConvertTo(operands.at(0)->op, xla_type); } @@ -717,8 +717,8 @@ xla::XlaOp SetDimensionSize(const BuilderPtr& builder, xla::XlaOp BitcastConvert(const BuilderPtr& builder, const std::vector& operands, py::dict args) { std::string type = args["to_type"].cast(); - xla::PrimitiveType xla_type = - GetValueOrThrow(xla::primitive_util::StringToPrimitiveType(type)); + XLA_ASSIGN_OR_THROW(xla::PrimitiveType xla_type, + xla::primitive_util::StringToPrimitiveType(type)); return xla::BitcastConvertType(operands.at(0)->op, xla_type); } @@ -873,8 +873,8 @@ xla::Shape PyShapeToShape(py::object shape) { std::string type = py_shape["type"].cast(); std::vector dimensions = GetTupleVector(py_shape["sizes"].cast()); - xla::PrimitiveType xla_type = - GetValueOrThrow(xla::primitive_util::StringToPrimitiveType(type)); + XLA_ASSIGN_OR_THROW(xla::PrimitiveType xla_type, + xla::primitive_util::StringToPrimitiveType(type)); if (py_shape.contains("dynamic_dimensions")) { std::vector dynamic_dimensions = GetTupleVector(py_shape["dynamic_dimensions"]); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index abc18206d420..55c6ebf186f8 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -767,7 +767,7 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input, << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN) << "Can't explicilty annotate with UNKNOWN sharding type."; - XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); // For Non DeviceData IR values, we directly attach the sharding spec to the // xtensor. From e7b11599d0793f0a3d02010b48c3b9bafefbe821 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 28 Aug 2025 15:22:00 -0300 Subject: [PATCH 078/133] `torch_xla`: Use new macros for throwing exceptions (part 2). (#9594) Follow-up: #9588 and #9580 Target: `torch_xla/csrc` directory In summary, this PR: - Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (that throws an exception without source location information of the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`. - Corresponds to the fine-grained set of PRs that came from breaking down PR #9580 - Focuses on the `torch_xla/csrc` directory, replacing every use of those, now deprecated, functions by the newly introduced macros. _Note: since there were lots of files in `torch_xla/csrc` that needed update, they were split in multiple parts._ --- torch_xla/csrc/aten_autograd_ops.cpp | 47 ++++++----- torch_xla/csrc/ir_dump_util.cpp | 24 ++++-- torch_xla/csrc/xla_graph_executor.cpp | 63 +++++++------- torch_xla/csrc/xla_lower_util.cpp | 113 ++++++++++++++------------ 4 files changed, 132 insertions(+), 115 deletions(-) diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp index c8fe95d536db..b86a40552707 100644 --- a/torch_xla/csrc/aten_autograd_ops.cpp +++ b/torch_xla/csrc/aten_autograd_ops.cpp @@ -34,8 +34,8 @@ torch::Tensor EinsumAutogradFunction::forward( } ctx->save_for_backward(vars); - std::vector xla_tensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(std::vector xla_tensors, + bridge::GetXlaTensors(tensors)); XLATensorPtr output = tensor_methods::einsum(eq_str, xla_tensors); return bridge::AtenFromXlaTensor(output); } @@ -45,13 +45,12 @@ torch::autograd::variable_list EinsumAutogradFunction::backward( torch::autograd::variable_list grad_output) { std::string equation = ctx->saved_data["equation"].toString()->string(); torch::autograd::variable_list tensors = ctx->get_saved_variables(); - std::vector xla_tensors = - GetValueOrThrow(bridge::GetXlaTensors(tensors)); - + XLA_ASSIGN_OR_THROW(std::vector xla_tensors, + bridge::GetXlaTensors(tensors)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0, + bridge::GetXlaTensor(grad_output[0])); std::tuple outputs = - tensor_methods::einsum_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), xla_tensors, - equation); + tensor_methods::einsum_backward(xla_grad_output_0, xla_tensors, equation); // For both einsum and max pool, we use "undef" as a placeholder for the // non-tensor grad inputs, in this case the equation string. @@ -193,10 +192,10 @@ torch::Tensor MaxPool3dAutogradFunction::forward( return std::get<0>(results); } ctx->save_for_backward({self}); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto outputs = tensor_methods::max_pool_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3, - XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), ceil_mode); + xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return bridge::AtenFromXlaTensor(std::get<0>(outputs)); } @@ -221,11 +220,13 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward( padding, dilation, ceil_mode, indices); } + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0, + bridge::GetXlaTensor(grad_output[0])); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), - GetValueOrThrow(bridge::GetXlaTensor(self)), - /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); + xla_grad_output_0, xla_self, /*spatial_dim_count=*/3, + XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), ceil_mode)); torch::Tensor undef; torch::autograd::variable_list grad_inputs = {grad, undef, undef, @@ -238,10 +239,10 @@ torch::Tensor max_pool2d_forward(torch::Tensor self, torch::IntArrayRef stride, torch::IntArrayRef padding, torch::IntArrayRef dilation, bool ceil_mode) { + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto outputs = tensor_methods::max_pool_nd( - GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2, - XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding), ceil_mode); + xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode); return bridge::AtenFromXlaTensor(std::get<0>(outputs)); } @@ -249,11 +250,13 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self, torch::IntArrayRef kernel_size, torch::IntArrayRef stride, torch::IntArrayRef padding, bool ceil_mode) { + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output, + bridge::GetXlaTensor(grad_output)); + XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward( - GetValueOrThrow(bridge::GetXlaTensor(grad_output)), - GetValueOrThrow(bridge::GetXlaTensor(self)), - /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode)); + xla_grad_output, xla_self, /*spatial_dim_count=*/2, + XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), ceil_mode)); return grad; } diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index e59269e3cc9f..3453bd642c38 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -264,14 +264,15 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, // Annotate HLO sharding selectively in the compuation. // This is no-op if an instruction doesn't have any sharding annotation. auto is_sharded = ShardingUtil::SetHloSharding(&lowering_ctx); - xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla()); static bool dump_post_optimizations = runtime::sys_util::GetEnvBool("XLA_DUMP_POST_OPTIMIZATIONS", false); if (dump_post_optimizations) { + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); xla::Shape shape = MakeShapeWithDeviceLayout( - GetValueOrThrow(computation.GetProgramShape()).result(), - static_cast(device.type())); + program_shape.result(), static_cast(device.type())); std::vector instances; instances.push_back( {std::move(computation), device.toString(), @@ -286,12 +287,17 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, } switch (mode) { - case EmitMode::kHloReadable: - return GetValueOrThrow(runtime::util::GetComputationHloText(computation)); - case EmitMode::kHloProto: - return GetValueOrThrow( - runtime::util::GetDeterministicSerializedModuleProto( - computation.proto())); + case EmitMode::kHloReadable: { + XLA_ASSIGN_OR_THROW(std::string hlo_text, + runtime::util::GetComputationHloText(computation)); + return hlo_text; + } + case EmitMode::kHloProto: { + XLA_ASSIGN_OR_THROW(std::string serialized_proto, + runtime::util::GetDeterministicSerializedModuleProto( + computation.proto())); + return serialized_proto; + } case EmitMode::kStableHloReadable: return hloToStablehlo(&computation.proto(), /* emit_bytecode = */ false); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 15f38cae2333..cf6a5a4105d2 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -497,8 +497,8 @@ std::vector XLAGraphExecutor::GetTensors( async != nullptr ? async->tensors_data : absl::Span()); - std::vector literals = - GetValueOrThrow(ReleaseGilAndTransferData(tensors_data)); + XLA_ASSIGN_OR_THROW(std::vector literals, + ReleaseGilAndTransferData(tensors_data)); return FetchTensors(tensors, literals, async != nullptr ? &async->indices : nullptr); @@ -846,12 +846,12 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( // OutputHandler creates sharded data for sharded // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. - std::vector outputs = - GetValueOrThrow( - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options)); + XLA_ASSIGN_OR_THROW( + std::vector outputs, + runtime::GetComputationClientOrDie()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options)); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -913,8 +913,8 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( // Get program output shape. // TODO(lsy323): Get shape info from MLIR Module. - xla::ProgramShape program_shape = - GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(device.type())); @@ -946,8 +946,9 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( } } - std::vector result_data = - GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation( + XLA_ASSIGN_OR_THROW( + std::vector result_data, + runtime::GetComputationClientOrDie()->ExecuteComputation( *computations[0], UnwrapXlaData(arguments), device.toString())); return WrapXlaData(result_data); @@ -1123,12 +1124,12 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // OutputHandler creates sharded data for sharded // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. - std::vector outputs = - GetValueOrThrow( - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options)); + XLA_ASSIGN_OR_THROW( + std::vector outputs, + runtime::GetComputationClientOrDie()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options)); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteReplicated", 1); TF_VLOG(3) << "Executing IR graph hash " @@ -1139,14 +1140,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) << " on device " << async->device << " ..."; - std::vector outputs = - GetValueOrThrow( - runtime::GetComputationClientOrDie()->ExecuteComputation( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), - async->device.toString(), - {/*explode_tuple=*/true, - /*eager_mode=*/use_eager_mode})); + XLA_ASSIGN_OR_THROW( + std::vector outputs, + runtime::GetComputationClientOrDie()->ExecuteComputation( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), async->device.toString(), + {/*explode_tuple=*/true, + /*eager_mode=*/use_eager_mode})); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteComputation", 1); TF_VLOG(3) << "Executing IR graph hash " @@ -1416,9 +1416,9 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( SetBufferDonors(&lowering_ctx, buffer_donor_indices); - xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla()); - xla::ProgramShape program_shape = - GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla()); + XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, + computation.GetProgramShape()); // TODO(yeounoh) enable wrapping with auto-sharding. bool should_wrap_parameter = @@ -1435,10 +1435,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( param_shardings = XlaHelpers::ExtractInputShardings(computation); } - computation = GetValueOrThrow( + XLA_ASSIGN_OR_THROW( + computation, XlaHelpers::WrapXlaComputation(computation, program_shape.parameters(), param_shardings, buffer_donor_indices)); - program_shape = GetValueOrThrow(computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(program_shape, computation.GetProgramShape()); } xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(coll.device.type())); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 1f8327acc362..cbabaf3f146b 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -62,8 +62,8 @@ ConditionMaskData CreateConditionMaskData(xla::XlaOp condition) { xla::XlaOp GetPromotedMask(xla::XlaOp mask, const xla::Shape& input_shape) { const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask); - xla::Shape promoted_mask_shape = - GetValueOrThrow(XlaHelpers::GetPromotedShape(mask_shape, input_shape)); + XLA_ASSIGN_OR_THROW(xla::Shape promoted_mask_shape, + XlaHelpers::GetPromotedShape(mask_shape, input_shape)); return XlaHelpers::ImplicitBroadcast(mask, mask_shape, promoted_mask_shape); } @@ -150,7 +150,9 @@ xla::XlaComputation MakeScatterComputation( if (combiner != nullptr) { result = combiner(p0, result); } - return GetValueOrThrow(cb.Build(result)); + XLA_ASSIGN_OR_THROW(xla::XlaComputation scatter_computation, + cb.Build(result)); + return scatter_computation; } xla::XlaOp CreateIndexAlongDim( @@ -543,8 +545,8 @@ std::vector CreateBroadcastTensors( for (const xla::XlaOp operand : operands) { const xla::Shape& operand_shape = ShapeHelper::ShapeOfXlaOp(operand); operand_shapes.push_back(operand_shape); - result_shape = GetValueOrThrow( - XlaHelpers::GetPromotedShape(result_shape, operand_shape)); + XLA_ASSIGN_OR_THROW(result_shape, XlaHelpers::GetPromotedShape( + result_shape, operand_shape)); } std::vector result; for (size_t i = 0; i < operands.size(); ++i) { @@ -1366,54 +1368,59 @@ std::vector BuildBoxSelectionLoop(int64_t num_boxes, // 3. The actual IoU threshold matrix. init_values[2] = iou_threshold_mask; - return GetValueOrThrow(xla::WhileLoopHelper( - [=](absl::Span values, xla::XlaBuilder* builder) { - xla::XlaOp box_index = values[0]; - // Check: current loop counter is within bounds, i.e. has a - // corresponding box. - return xla::Lt(box_index, - xla::ConstantR0(builder, num_boxes)); - }, - [=](absl::Span values, xla::XlaBuilder* builder) { - const xla::XlaOp ONE = xla::One(builder, XLAIndexType); - const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); - - xla::XlaOp box_index = values[0]; - xla::XlaOp state = values[1]; - xla::XlaOp iou_threshold_mask = values[2]; - - // Retrieve the IoU mask row corresponding to this box. - xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice( - iou_threshold_mask, {box_index, ZERO}, {1, num_boxes}); - - // Update the current state with the IoU mask. - // Basically, sets to false every box X whose IoU with the current box - // is less-than or equal than the given threshold. - xla::XlaOp updated_state = xla::And( - state, - // Update the mask so that if we select this box - // (i.e. state[box] == true), we don't de-select it. - xla::DynamicUpdateSlice( - // Before that, we need to pre-process the mask. - // 1. Negate the mask: if this box is selected, we only want - // those that have a low intersection ratio. - // 2. Reshape it to: [num_boxes]. - xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}), - xla::ConstantR1(builder, {true}), {box_index})); - - // Flag: should this box (loop counter) be included in the output? - xla::XlaOp should_include = xla::DynamicSlice(state, {box_index}, {1}); - // Pick the new values of state, depending on whether we should include - // this box or not. - xla::XlaOp new_state = - xla::Select(xla::BroadcastInDim(should_include, {num_boxes}, {0}), - updated_state, state); - - xla::XlaOp next_box_index = box_index + ONE; - return std::vector{next_box_index, new_state, - iou_threshold_mask}; - }, - init_values, "BoxSelectionLoop", builder)); + XLA_ASSIGN_OR_THROW( + std::vector result, + xla::WhileLoopHelper( + [=](absl::Span values, xla::XlaBuilder* builder) { + xla::XlaOp box_index = values[0]; + // Check: current loop counter is within bounds, i.e. has a + // corresponding box. + return xla::Lt(box_index, + xla::ConstantR0(builder, num_boxes)); + }, + [=](absl::Span values, xla::XlaBuilder* builder) { + const xla::XlaOp ONE = xla::One(builder, XLAIndexType); + const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType); + + xla::XlaOp box_index = values[0]; + xla::XlaOp state = values[1]; + xla::XlaOp iou_threshold_mask = values[2]; + + // Retrieve the IoU mask row corresponding to this box. + xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice( + iou_threshold_mask, {box_index, ZERO}, {1, num_boxes}); + + // Update the current state with the IoU mask. + // Basically, sets to false every box X whose IoU with the current + // box is less-than or equal than the given threshold. + xla::XlaOp updated_state = xla::And( + state, + // Update the mask so that if we select this box + // (i.e. state[box] == true), we don't de-select it. + xla::DynamicUpdateSlice( + // Before that, we need to pre-process the mask. + // 1. Negate the mask: if this box is selected, we only + // want + // those that have a low intersection ratio. + // 2. Reshape it to: [num_boxes]. + xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}), + xla::ConstantR1(builder, {true}), {box_index})); + + // Flag: should this box (loop counter) be included in the output? + xla::XlaOp should_include = + xla::DynamicSlice(state, {box_index}, {1}); + // Pick the new values of state, depending on whether we should + // include this box or not. + xla::XlaOp new_state = xla::Select( + xla::BroadcastInDim(should_include, {num_boxes}, {0}), + updated_state, state); + + xla::XlaOp next_box_index = box_index + ONE; + return std::vector{next_box_index, new_state, + iou_threshold_mask}; + }, + init_values, "BoxSelectionLoop", builder)); + return result; } xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores, From 004f19e7e6ae0a9c0b76c5e6118f74c282aad156 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 29 Aug 2025 18:08:58 -0300 Subject: [PATCH 079/133] Remove CUDA specific logic from runtime. (#9598) This PR removes CUDA specific logic from `torch_xla/csrc/runtime` directory, as well as uses of deleted functions and environment variables from outside. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - Removed environment variable `ZERO_COPY_ENABLED`, which was used to use DLPack for moving without copying tensors from PyTorch CUDA to PyTorch/XLA XLA:CUDA device - Removed Python API function `_get_stream_for_cuda_device`, which was used in `dlpack.py` for DLPack related logic on CUDA capsules - Removed `ComputationClient::GetCudaStreamForDevice()`, which was used by the Python API above - Removed `PjRtComputationClient::RegisterCustomCall()`, since it only worked when `platform == "CUDA"` - Removed `GetGpuAllocatorConfig()` - Removed `from_xla_cuda_to_cuda()` DLPack function - Removed CUDA branch from `InitializePjRt()` --- test/dynamo/test_dynamo.py | 6 -- torch_xla/_dynamo/dynamo_bridge.py | 17 ++--- torch_xla/core/xla_env_vars.py | 1 - torch_xla/csrc/init_python_bindings.cpp | 5 -- torch_xla/csrc/runtime/computation_client.h | 2 - torch_xla/csrc/runtime/env_vars.h | 4 -- .../csrc/runtime/ifrt_computation_client.cpp | 2 - .../csrc/runtime/ifrt_computation_client.h | 4 -- .../csrc/runtime/pjrt_computation_client.cpp | 42 ----------- .../csrc/runtime/pjrt_computation_client.h | 15 +--- torch_xla/csrc/runtime/pjrt_registry.cpp | 71 ------------------- torch_xla/utils/dlpack.py | 38 ++-------- 12 files changed, 12 insertions(+), 195 deletions(-) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 572d255514a6..2c05adf7716c 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -157,9 +157,6 @@ def _choose_proper_device(self, initialize_on_cuda): self.skipTest( "Skip this test because it requires xr.device_type()=='CUDA' and torch.cuda.is_available()." ) - os.environ.update({ - xenv.ZERO_COPY_ENABLED: "1", - }) return "cuda:0" @skipOnNeuron @@ -205,9 +202,6 @@ def test_simple_model(self): "1", ) def test_simple_model_automoves_tensors(self, zero_copy_enabled): - os.environ.update({ - xenv.ZERO_COPY_ENABLED: zero_copy_enabled, - }) x = torch.tensor(100.0, requires_grad=True, device="cuda:0") y = torch.tensor(200.0, requires_grad=True, device="cuda:0") original_device = x.device diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index 7cae4f7392e5..bb69e91d9518 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -148,19 +148,10 @@ def _maybe_move_tensors_to_device(tensors: tuple, if dynamo_debug: print("Moving Tensor {} to device {}".format(tensor, target_device)) - zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False) - if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla': - # If the input cuda tensor requires gradient, we need to call detach. Otherwise, we'd get the error "RuntimeError: Can't export tensors that require gradient, use tensor.detach()" - moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach()) - elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda': - # `torch_xla.sync()` is need to make sure the pjrt buffer is valid. - torch_xla.sync() - moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor) - else: - # Have to move to CPU before moving it to target device. - cpu_device: torch.device = torch.device("cpu") - moved_tensor = tensor.to(cpu_device) - moved_tensor = moved_tensor.to(target_device) + # Have to move to CPU before moving it to target device. + cpu_device: torch.device = torch.device("cpu") + moved_tensor = tensor.to(cpu_device) + moved_tensor = moved_tensor.to(target_device) # Explicitly have to copy requires_grad attribute because it's dropped # with torch.to(..) diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py index 32adf48f0c4a..2d256c77a540 100644 --- a/torch_xla/core/xla_env_vars.py +++ b/torch_xla/core/xla_env_vars.py @@ -30,4 +30,3 @@ RANK = 'RANK' WORLD_SIZE = 'WORLD_SIZE' LOCAL_WORLD_SIZE = 'LOCAL_WORLD_SIZE' -ZERO_COPY_ENABLED = 'ZERO_COPY_ENABLED' diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ffff87ee0a14..57e5e16a877b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1771,11 +1771,6 @@ void InitXlaModuleBindings(py::module m) { []() { return runtime::GetComputationClientOrDie()->GetPlatformVersion(); }) - .def("_get_stream_for_cuda_device", - [](const int device_id) { - return runtime::GetComputationClientOrDie()->GetCudaStreamForDevice( - device_id); - }) .def("_xla_num_devices", []() -> int64_t { if (UseVirtualDevice()) { diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 6d05137a89a3..79ff199eb2ff 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -375,8 +375,6 @@ class ComputationClient { virtual absl::StatusOr LookupAddressableDevice( int local_device_id) const = 0; - virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0; - virtual size_t GetNumLocalDevices() const = 0; virtual size_t GetNumDevices() const = 0; diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index 827c4822d491..ee1ce63687b3 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -10,14 +10,12 @@ namespace env { inline constexpr char kEnvLocalWorker[] = "LOCAL_WORKER"; inline constexpr char kEnvTpuConfig[] = "TPU_CONFIG"; inline constexpr char kEnvNumTpu[] = "TPU_NUM_DEVICES"; -inline constexpr char kEnvNumGpu[] = "GPU_NUM_DEVICES"; inline constexpr char kEnvNumCpu[] = "CPU_NUM_DEVICES"; inline constexpr char kEnvTpuvmMode[] = "TPUVM_MODE"; inline constexpr char kEnvPjRtDevice[] = "PJRT_DEVICE"; inline constexpr char kEnvPjRtTpuMaxInflightComputations[] = "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS"; inline constexpr char kEnvPjrtAsyncCpuClient[] = "PJRT_CPU_ASYNC_CLIENT"; -inline constexpr char kEnvPjrtAsyncGpuClient[] = "PJRT_GPU_ASYNC_CLIENT"; inline constexpr char kEnvTpuLibraryPath[] = "TPU_LIBRARY_PATH"; inline constexpr char kEnvInferredTpuLibraryPath[] = "PTXLA_TPU_LIBRARY_PATH"; inline constexpr char kEnvXpuLibraryPath[] = "XPU_LIBRARY_PATH"; @@ -25,8 +23,6 @@ inline constexpr char kEnvNeuronLibraryPath[] = "NEURON_LIBRARY_PATH"; inline constexpr char kEnvPjrtDistServiceAddr[] = "PJRT_DIST_SERVICE_ADDR"; inline constexpr char kEnvPjRtLocalProcessCount[] = "PJRT_LOCAL_PROCESS_COUNT"; inline constexpr char kEnvPjRtLocalRank[] = "PJRT_LOCAL_PROCESS_RANK"; -inline constexpr char kEnvPjrtAllocatorCudaAsync[] = - "PJRT_ALLOCATOR_CUDA_ASYNC"; inline constexpr char kEnvPjrtAllocatorPreallocate[] = "PJRT_ALLOCATOR_PREALLOCATE"; inline constexpr char kEnvPjrtAllocatorFraction[] = "PJRT_ALLOCATOR_FRACTION"; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index db9ec8dab512..54eccb2727f5 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -161,8 +161,6 @@ IfrtComputationClient::Create() { } IfrtComputationClient::~IfrtComputationClient() { - // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient - // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. client_ = nullptr; coordinator_ = nullptr; } diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index e1bcc751bbf3..8b45922c397f 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -110,10 +110,6 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - std::intptr_t GetCudaStreamForDevice(int local_device_id) const override { - XLA_ERROR() << __FUNCTION__ << " not implemented"; - } - std::vector GetLocalDevices() const override; std::vector GetAllDevices() const override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 7e2833fc8f16..280b50964d82 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -23,7 +23,6 @@ #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" -#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" @@ -152,8 +151,6 @@ PjRtComputationClient::Create() { } PjRtComputationClient::~PjRtComputationClient() { - // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient - // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. client_ = nullptr; coordinator_ = nullptr; } @@ -1038,45 +1035,6 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo( }; } -void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name, - void* function_ptr, - const std::string& platform) { - if (platform != "CUDA") { - XLA_ERROR() << "Custom call targets can only be registered for " - "PJRT CUDA runtime."; - return; - } - - auto* c_api_client = dynamic_cast(client_.get()); - if (!c_api_client) { - XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(fn_name, function_ptr, platform); - return; - } - const PJRT_Api* pjrt_api = c_api_client->pjrt_c_api(); - - // See openxla reference: - // https://github.com/openxla/xla/blob/b604c8d87df842002a7a8de79a434026329fbcb2/xla/pjrt/c/pjrt_c_api_gpu_test.cc#L414 - const PJRT_Extension_Base* next = - reinterpret_cast(pjrt_api->extension_start); - while (next != nullptr && - next->type != - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { - next = next->next; - } - XLA_CHECK(next) << "Custom call extension not found"; - PJRT_Gpu_Register_Custom_Call_Args args; - args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name.c_str(); - args.function_name_size = fn_name.size(); - args.api_version = 0; - args.handler_execute = function_ptr; - PJRT_Error* error = - reinterpret_cast(next)->custom_call(&args); - if (error) { - XLA_ERROR() << error->status; - } -} - void PjRtComputationClient::OnReadyCallback( ComputationClient::DataPtr data, const std::function& callback) { std::shared_ptr buffer; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 3c13d3489cae..d550f1cce0cb 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -118,17 +118,6 @@ class PjRtComputationClient : public ComputationClient { xla::PjRtLocalDeviceId(local_device_id)); } - std::intptr_t GetCudaStreamForDevice(int local_device_id) const override { - absl::StatusOr pjrt_device = - client_->LookupAddressableDevice( - xla::PjRtLocalDeviceId(local_device_id)); - XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device."; - absl::StatusOr stream = - pjrt_device.value()->GetStreamForExternalReadyEvents(); - XLA_CHECK(stream.ok()) << "Failed to get a stream."; - return stream.value(); - } - std::vector GetLocalDevices() const override; std::vector GetAllDevices() const override; @@ -169,7 +158,9 @@ class PjRtComputationClient : public ComputationClient { absl::Span devices) const; void RegisterCustomCall(const std::string& fn_name, void* function_ptr, - const std::string& platform) override; + const std::string& platform) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + }; void OnReadyCallback(DataPtr data, const std::function& callback) override; diff --git a/torch_xla/csrc/runtime/pjrt_registry.cpp b/torch_xla/csrc/runtime/pjrt_registry.cpp index 44603efca833..162e6dca9d2d 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cpp +++ b/torch_xla/csrc/runtime/pjrt_registry.cpp @@ -15,7 +15,6 @@ #include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" -#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/tfrt_cpu_pjrt_client.h" @@ -44,23 +43,6 @@ class LibraryPlugin : public PjRtPlugin { std::unordered_map> pjrt_plugins_ = {{"LIBRARY", std::make_shared()}}; -xla::GpuAllocatorConfig GetGpuAllocatorConfig() { - auto allocator_config = xla::GpuAllocatorConfig{}; - if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorPreallocate, "").empty() && - sys_util::GetEnvString(env::kEnvPjrtAllocatorFraction, "").empty()) { - return allocator_config; - } - if (sys_util::GetEnvBool(env::kEnvPjrtAllocatorCudaAsync, false)) { - allocator_config.kind = xla::GpuAllocatorConfig::Kind::kCudaAsync; - } - allocator_config.preallocate = - sys_util::GetEnvBool(env::kEnvPjrtAllocatorPreallocate, true); - allocator_config.memory_fraction = - sys_util::GetEnvDouble(env::kEnvPjrtAllocatorFraction, 0.75); - return allocator_config; -} - absl::StatusOr> GetPjRtPlugin( const std::string& device_type) { auto entry = pjrt_plugins_.find(device_type); @@ -167,59 +149,6 @@ InitializePjRt(const std::string& device_type) { } else if (device_type == "TPU_LEGACY") { return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( "TPU_LEGACY client is no longer available.")); - } else if (device_type == "CUDA") { - TORCH_WARN("The XLA:CUDA device is deprecated in release 2.8. ", - "Future releases might remove XLA:CUDA support entirely. ", - "Use the PyTorch native CUDA backend, instead.") - TF_VLOG(1) << "Initializing PjRt GPU client..."; - bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true); - int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0); - int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank); - int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); - - TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" - << global_process_rank << ", num_nodes=" << global_world_size - << ", local_process_rank=" << local_process_rank - << ", local_world_size=" << local_world_size - << ", spmd case=" << sys_util::GetEnvBool("XLA_USE_SPMD", false) - << ", PJRT_LOCAL_PROCESS_RANK=" - << sys_util::GetEnvString(env::kEnvPjRtLocalRank, "") - << ", RANK=" << sys_util::GetEnvString("RANK", "") - << ", LOCAL_WORLD_SIZE=" - << sys_util::GetEnvString("LOCAL_WORLD_SIZE", "") - << ", WORLD_SIZE=" << sys_util::GetEnvString("WORLD_SIZE", ""); - std::optional> allowed_devices; - if (local_world_size > 1) { - allowed_devices = std::set{local_process_rank}; - } - - std::shared_ptr kv_store; - if (global_world_size > 1) { - // Use the distributed key-value store from DistributedRuntimeClient. - std::string master_addr = - runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); - std::string port = runtime::sys_util::GetEnvString( - "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - XLA_ASSIGN_OR_RETURN( - coordinator, - XlaCoordinator::Create(global_process_rank, global_world_size, - master_addr, port)); - std::shared_ptr distributed_client = - coordinator->GetClient(); - kv_store = xla::GetDistributedKeyValueStore(distributed_client, - /*key_prefix=*/"gpu:"); - } - - xla::GpuClientOptions options; - options.allocator_config = GetGpuAllocatorConfig(); - options.node_id = global_process_rank; - options.num_nodes = global_world_size; - options.allowed_devices = allowed_devices; - options.platform_name = "gpu"; - options.should_stage_host_to_device_transfers = true; - options.kv_store = kv_store; - XLA_ASSIGN_OR_RETURN(client, xla::GetStreamExecutorGpuClient(options)); } else if (device_type == "XPU") { TF_VLOG(1) << "Initializing PjRt XPU client..."; XLA_RETURN_IF_ERROR(pjrt::LoadPjrtPlugin( diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py index d66bafe749d9..8fefa5d65b3c 100644 --- a/torch_xla/utils/dlpack.py +++ b/torch_xla/utils/dlpack.py @@ -13,40 +13,12 @@ def to_dlpack(xla_tensor: Any): def from_dlpack(ext_tensor: Any): if hasattr(ext_tensor, '__dlpack_device__') and hasattr( ext_tensor, '__dlpack__'): - device_type, device_id = ext_tensor.__dlpack_device__() - if device_type == DLDeviceType.kDLGPU: - stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id) - dlpack = ext_tensor.__dlpack__(stream=stream) - else: - dlpack = ext_tensor.__dlpack__() + device_type, _ = ext_tensor.__dlpack_device__() + if device_type != DLDeviceType.kDLCPU: + raise ValueError( + "PyTorch/XLA DLPack implementation currently only supports CPU.") + dlpack = ext_tensor.__dlpack__() else: dlpack = ext_tensor return torch_xla._XLAC._from_dlpack(dlpack) - - -def from_xla_cuda_to_cuda(tensor): - assert torch.cuda.is_available() - assert tensor.device.type == "xla", "The tensor is not an XLA tensor" - is_xla_cuda = True if xu.getenv_as("PJRT_DEVICE", str, - "").lower() == "cuda" else False - assert is_xla_cuda, "The XLA tensor is not on CUDA" - # consumer is torch, producer is torch_xla - - # Similar logic as torch.utils.dlpack.from_dlpack - # https://github.com/pytorch/pytorch/blob/b0ef363972203b163cddc95e4c6054b8221c2300/torch/utils/dlpack.py#L114-L115 - # The array API specify that the default legacy stream must be passed - # with a value of 1 for CUDA - device_id = tensor.device.index - stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id) - stream = 1 if stream == 0 else stream - assert stream is None or type(stream) is int - external_stream = torch.cuda.ExternalStream(stream) - current_stream = torch.cuda.current_stream() - if external_stream != current_stream: - event = torch.cuda.Event() - event.record(current_stream) - external_stream.wait_event(event) - dlpack = to_dlpack(tensor) - cuda_tensor = torch.utils.dlpack.from_dlpack(dlpack) - return cuda_tensor From 763e5b78d4fcd74a9e812256656c075f99d9a781 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 29 Aug 2025 18:29:02 -0300 Subject: [PATCH 080/133] Remove `gpu_custom_call` logic. (#9600) This PR removes the implementation of `gpu_custom_call`. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - Delete both `ops/gpu_custom_call.cpp` and `ops/gpu_custom_call.h` - (`tensor_methods.{h,cpp}`) Remove `tensor_methods::gpu_custom_call` - (`ops/xla_ops.{h,cpp}`) Remove `OpKindWrapper xla_gpu_custom_call` global variable - (`init_python_bindings.cpp`) Remove the Python API function `_xla_gpu_custom_call` - (`init_python_bindings.cpp`) Make `XlaCustomCall` function into a TPU specific function `TpuCustomCall` --- torch_xla/csrc/init_python_bindings.cpp | 22 +++----------- torch_xla/csrc/ops/gpu_custom_call.cpp | 37 ----------------------- torch_xla/csrc/ops/gpu_custom_call.h | 25 ---------------- torch_xla/csrc/ops/xla_ops.cpp | 1 - torch_xla/csrc/ops/xla_ops.h | 3 +- torch_xla/csrc/tensor_methods.cpp | 40 ------------------------- torch_xla/csrc/tensor_methods.h | 5 ---- torch_xla/csrc/xla_lower_util.cpp | 25 ---------------- torch_xla/csrc/xla_lower_util.h | 4 --- 9 files changed, 5 insertions(+), 157 deletions(-) delete mode 100644 torch_xla/csrc/ops/gpu_custom_call.cpp delete mode 100644 torch_xla/csrc/ops/gpu_custom_call.h diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 57e5e16a877b..9ce45e8761a9 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -38,7 +38,6 @@ #include "pybind11/pytypes.h" #include "pybind11/stl.h" #include "pybind11/stl_bind.h" -#include "status.h" #include "torch_xla/csrc/XLANativeFunctions.h" #include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_fallback.h" @@ -345,10 +344,10 @@ std::vector> CreateReduceGroups(const py::list& groups) { return replica_groups; } -std::vector XlaCustomCall( +std::vector TpuCustomCall( const std::vector& inputs, const std::string& payload, const std::vector>& output_shapes, - const std::vector& output_dtypes, bool is_tpu) { + const std::vector& output_dtypes) { std::vector dtypes; dtypes.reserve(output_dtypes.size()); for (auto& dtype : output_dtypes) { @@ -356,11 +355,7 @@ std::vector XlaCustomCall( } XLA_ASSIGN_OR_THROW(std::vector xla_inputs, bridge::GetXlaTensors(inputs)); - if (is_tpu) { - return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call( - xla_inputs, payload, output_shapes, dtypes)); - } - return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call( + return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call( xla_inputs, payload, output_shapes, dtypes)); } @@ -3058,8 +3053,7 @@ void InitXlaModuleBindings(py::module m) { const std::vector>& output_shapes, const std::vector& output_dtypes) -> std::vector { - return XlaCustomCall(inputs, payload, output_shapes, output_dtypes, - /*is_tpu=*/true); + return TpuCustomCall(inputs, payload, output_shapes, output_dtypes); }) .def("_has_cuda_support", []() { @@ -3069,14 +3063,6 @@ void InitXlaModuleBindings(py::module m) { return false; #endif }) - .def("_xla_gpu_custom_call", - [](const std::vector& inputs, const std::string& payload, - const std::vector>& output_shapes, - const std::vector& output_dtypes) - -> std::vector { - return XlaCustomCall(inputs, payload, output_shapes, output_dtypes, - /*is_tpu=*/false); - }) .def("_xla_register_custom_call_target", [](const std::string& fn_name, const py::capsule& function_ptr, const std::string& platform) { diff --git a/torch_xla/csrc/ops/gpu_custom_call.cpp b/torch_xla/csrc/ops/gpu_custom_call.cpp deleted file mode 100644 index 26581f94899b..000000000000 --- a/torch_xla/csrc/ops/gpu_custom_call.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "torch_xla/csrc/ops/gpu_custom_call.h" - -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/xla_ops.h" -#include "torch_xla/csrc/xla_lower_util.h" - -namespace torch_xla { - -GpuCustomCall::GpuCustomCall(torch::lazy::OpList inputs, - xla::Shape output_shape, - const std::string& payload) - : XlaNode(xla_gpu_custom_call, inputs, output_shape, - /*num_outputs=*/output_shape.tuple_shapes_size(), - torch::lazy::MHash(payload)), - payload_(payload) {} - -torch::lazy::NodePtr GpuCustomCall::Clone(torch::lazy::OpList operands) const { - return torch_xla::MakeNode(operands, xla_shape(), payload_); -} - -XlaOpVector GpuCustomCall::Lower(LoweringContext* loctx) const { - std::vector inputs; - inputs.reserve(operands().size()); - for (auto& operand : operands()) { - inputs.push_back(loctx->GetOutputOp(operand)); - } - auto output = BuildGpuCustomCall(inputs, xla_shape(), payload_); - return ReturnOps(output, loctx); -} - -std::string GpuCustomCall::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", " << payload_; - return ss.str(); -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/gpu_custom_call.h b/torch_xla/csrc/ops/gpu_custom_call.h deleted file mode 100644 index fa08d62be676..000000000000 --- a/torch_xla/csrc/ops/gpu_custom_call.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_ -#define XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_ - -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { -class GpuCustomCall : public XlaNode { - public: - // Make a GPU custom call with payload, e.g., Triton. - GpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape, - const std::string& payload); - - torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - private: - std::string payload_; -}; - -} // namespace torch_xla - -#endif // XLA_TORCH_XLA_CSRC_OPS_GPU_CUSTOM_CALL_H_ diff --git a/torch_xla/csrc/ops/xla_ops.cpp b/torch_xla/csrc/ops/xla_ops.cpp index 9187ee64fa95..a253d9cad8b9 100644 --- a/torch_xla/csrc/ops/xla_ops.cpp +++ b/torch_xla/csrc/ops/xla_ops.cpp @@ -39,6 +39,5 @@ const OpKindWrapper xla_unselect("xla::unselect"); const OpKindWrapper xla_update_slice("xla::update_slice"); const OpKindWrapper xla_custom_sharding("xla::custom_sharding"); const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call"); -const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call"); } // namespace torch_xla diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index 86ab2c57d4de..042de15e5cc7 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -64,8 +64,7 @@ extern const OpKindWrapper xla_unselect; extern const OpKindWrapper xla_update_slice; extern const OpKindWrapper xla_custom_sharding; extern const OpKindWrapper xla_tpu_custom_call; -extern const OpKindWrapper xla_gpu_custom_call; } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_ diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index bfd67b59de24..e7814ce517d5 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -65,7 +65,6 @@ #include "torch_xla/csrc/ops/generic.h" #include "torch_xla/csrc/ops/generic_slice.h" #include "torch_xla/csrc/ops/get_dimensions_size.h" -#include "torch_xla/csrc/ops/gpu_custom_call.h" #include "torch_xla/csrc/ops/hardtanh_backward.h" #include "torch_xla/csrc/ops/index_ops.h" #include "torch_xla/csrc/ops/index_select.h" @@ -767,45 +766,6 @@ void custom_sharding_( input->SetShardingSpec(*sharding_spec); } -std::vector gpu_custom_call( - const std::vector& inputs, const std::string& payload, - const std::vector>& output_shapes, - const std::vector& output_dtypes) { - XLA_CHECK(inputs.size() > 0) << "inputs are empty"; - - std::vector values; - values.reserve(inputs.size()); - for (const auto& input : inputs) { - values.push_back(input->GetIrValue()); - } - - XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size()); - std::vector output_xla_shapes; - output_xla_shapes.reserve(output_shapes.size()); - for (size_t i = 0; i < output_shapes.size(); ++i) { - output_xla_shapes.push_back(xla::ShapeUtil::MakeShape( - MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())), - output_shapes[i])); - } - - auto node = torch_xla::MakeNode( - values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload); - - std::vector outputs; - outputs.reserve(output_shapes.size()); - for (size_t i = 0; i < output_shapes.size(); ++i) { - outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i), - output_dtypes[i], - /*delay_eager_execution=*/true)); - } - XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get(); - if (graph_executor->UseEagerMode()) { - // Execute the HLO that will run the `custom` and in one hlo - graph_executor->ApplyEagerSync(outputs); - } - return outputs; -} - std::vector tpu_custom_call( const std::vector& inputs, const std::string& payload, const std::vector>& output_shapes, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index c28d7f2165e6..597640bf4c49 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -103,11 +103,6 @@ void custom_sharding_( const std::shared_ptr& spec, const CustomSharding::Type& type = CustomSharding::Type::kSharding); -std::vector gpu_custom_call( - const std::vector& inputs, const std::string& payload, - const std::vector>& output_shapes, - const std::vector& output_dtypes); - std::vector tpu_custom_call( const std::vector& inputs, const std::string& payload, const std::vector>& output_shapes, diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index cbabaf3f146b..0b41ea258a49 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1281,31 +1281,6 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type, output_shape); } -std::vector BuildGpuCustomCall( - const std::vector& inputs, const xla::Shape& output_shape, - const std::string& payload) { - std::vector input_shapes; - input_shapes.reserve(inputs.size()); - for (const auto& input : inputs) { - input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input)); - } - - XLA_CHECK(inputs.size() > 0) << "inputs are empty"; - xla::XlaOp outputs = xla::CustomCallWithLayout( - inputs[0].builder(), - /*call_target_name=*/"triton_kernel_call", inputs, output_shape, - input_shapes, payload, false, {}, nullptr, - xla::CustomCallSchedule::SCHEDULE_NONE, - xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING); - std::vector result; - int num_outputs = output_shape.tuple_shapes_size(); - result.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - result.push_back(xla::GetTupleElement(outputs, i)); - } - return result; -} - std::vector BuildTpuCustomCall( const std::vector& inputs, const xla::Shape& output_shape, const std::string& payload) { diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 60ebad6dcd6c..f2dc8a1915e9 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -162,10 +162,6 @@ std::vector BuildTpuCustomCall( xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores, xla::XlaOp iou_threshold); -std::vector BuildGpuCustomCall( - const std::vector& inputs, const xla::Shape& output_shape, - const std::string& payload); - } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_ From 8fb90c8a0cc9802ca8e96803d90af841704c2344 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 2 Sep 2025 11:21:24 -0300 Subject: [PATCH 081/133] Remove functions that throw status error. (#9602) Follow-up: #9580 This PR finalizes the work that started with #9580 for replacing `OkOrThrow()` and `GetValueOrThrow()` in favor of the macros introduced in #9588 (which were also part of that work). In summary, is the last one after: - #9580: initial work that was broken down into the PRs below - #9588: actual first PR merged - #9590 - #9591 - #9592 - #9593 - #9594 - #9595 - #9596 - #9602: last PR **Key Changes:** - (`test/cpp/test_status_common.h`) Remove tests for `OkOrThrow()` and `GetValueOrThrow()` - (`torch_xla/csrc/status.{h,cpp}`) Remove the implementation of those functions --- test/cpp/test_status_common.h | 101 ---------------------------------- torch_xla/csrc/status.cpp | 7 --- torch_xla/csrc/status.h | 29 ---------- 3 files changed, 137 deletions(-) diff --git a/test/cpp/test_status_common.h b/test/cpp/test_status_common.h index 17b0ef29f5ff..e09d2e58d28f 100644 --- a/test/cpp/test_status_common.h +++ b/test/cpp/test_status_common.h @@ -80,8 +80,6 @@ class StatusTest : public testing::TestWithParam { namespace cpp_test { // Prefix of the C++ stacktrace PyTorch adds to the error message. -constexpr inline char kTorchCppStacktracePrefixDeprecated[] = - "Exception raised from OkOrThrow at torch_xla/csrc/status.cpp:"; constexpr inline char kTorchCppStacktracePrefix[] = "Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:"; @@ -102,50 +100,6 @@ inline std::string GetStatusPropagationTrace(const absl::Status& status) { : ""; } -TEST_P(StatusTest, OkOrThrowWithOkStatus) { - absl::Status ok_status = absl::OkStatus(); - EXPECT_NO_THROW(OkOrThrow(ok_status)); -} - -TEST_P(StatusTest, OkOrThrowWithErrorStatus) { - try { - absl::Status error_status = absl::InvalidArgumentError(kMessage); - OkOrThrow(error_status); - } catch (const c10::Error& error) { - if (IsShowCppStacktracesMode()) { - EXPECT_THAT(std::string_view(error.what()), - ::testing::StartsWith(absl::StrCat( - kMessage, "\n\n", kTorchCppStacktracePrefixDeprecated))); - } else { - EXPECT_EQ(std::string_view(error.what_without_backtrace()), - std::string_view(kMessage)); - } - } -} - -TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) { - int value = 42; - absl::StatusOr status_or = value; - int result = GetValueOrThrow(std::move(status_or)); - EXPECT_EQ(result, value); -} - -TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) { - try { - absl::StatusOr error_status = absl::InvalidArgumentError(kMessage); - int value = GetValueOrThrow(error_status); - } catch (const c10::Error& error) { - if (IsShowCppStacktracesMode()) { - EXPECT_THAT(std::string_view(error.what()), - ::testing::StartsWith(absl::StrCat( - kMessage, "\n\n", kTorchCppStacktracePrefixDeprecated))); - } else { - EXPECT_EQ(std::string_view(error.what_without_backtrace()), - std::string_view(kMessage)); - } - } -} - TEST_P(StatusTest, MaybeWithLocationPropagatesErrorStatus) { absl::Status error_status = absl::InvalidArgumentError(kMessage); absl::Status result = @@ -345,61 +299,6 @@ TEST_P(StatusTest, MacroErrorWithLocation) { } } -TEST_P(StatusTest, OkOrThrowWithErrorPropagationWithNewMessage) { - int32_t errline0 = __LINE__ + 2; - auto innerfn = [&]() -> absl::Status { - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage)); - }; - - int32_t errline1 = __LINE__ + 2; - auto midfn = [&]() -> absl::Status { - XLA_RETURN_IF_ERROR(innerfn(), kNewMessage); - return absl::OkStatus(); - }; - - int32_t errline2 = __LINE__ + 2; - auto outerfn = [&]() -> absl::Status { - XLA_RETURN_IF_ERROR(midfn()); - return absl::OkStatus(); - }; - - try { - OkOrThrow(outerfn()); - } catch (const c10::Error& error) { - if (IsShowCppStacktracesMode()) { - // Expected Error Message Prefix - // ============================= - // - // New test error kMessage - // - // Status Propagation Stacktrace: - // From: ./test/cpp/test_status_common.h:329 (error: Test error - // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New - // test error kMessage) From: ./test/cpp/test_status_common.h:342 - // - // C++ Stacktrace: - // - std::ostringstream oss; - oss << kNewMessage; - oss << "\n\n"; - oss << "Status Propagation Trace:"; - oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" - << errline0 << " (error: " << kMessage << ")"; - oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" - << errline1 << " (error: " << kNewMessage << ")"; - oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" - << errline2; - oss << "\n\n"; - oss << kTorchCppStacktracePrefixDeprecated; - EXPECT_THAT(std::string_view(error.what()), - ::testing::StartsWith(oss.str())); - } else { - EXPECT_EQ(std::string_view(error.what_without_backtrace()), - std::string_view(kNewMessage)); - } - } -} - TEST_P(StatusTest, MacroThrowIfErrorWithErrorPropagationWithNewMessage) { int32_t errline0 = __LINE__ + 2; auto innerfn = [&]() -> absl::Status { diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index 56636874f0b9..44f732fd6e97 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -129,13 +129,6 @@ void status_internal::ThrowStatusError(const absl::Status& status, LineBreakIfCppStacktracesEnabled())); } -void OkOrThrow(const absl::Status& status) { - TORCH_CHECK(status.ok(), absl::StrCat(BuildStatusErrorMessage(status), - LineBreakIfCppStacktracesEnabled())); -} - -void GetValueOrThrow(const absl::Status& status) { OkOrThrow(status); } - void status_internal::OkOrDie(const absl::Status& status, const char* file, const int32_t line, const char* function, std::string_view message) { diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index 87a9227672bb..bbf3bad1ed56 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -297,35 +297,6 @@ void OkOrDie(const absl::Status& status, const char* file, const int32_t line, // It doesn't add a trailing line break. std::string BuildStatusErrorMessage(const absl::Status& status); -// Throws an exception if `status` has a non-ok code. -// -// Ideally, this function should be used only used in the project's -// boundary, e.g. when we need to throw an exception for the user to see. -void OkOrThrow(const absl::Status& status); - -// Either returns the value `status` holds, if it's an ok-status, or throw an -// exception from its error status. -template -T& GetValueOrThrow(absl::StatusOr& status) { - OkOrThrow(status.status()); - return status.value(); -} - -template -const T& GetValueOrThrow(const absl::StatusOr& status) { - OkOrThrow(status.status()); - return status.value(); -} - -template -T GetValueOrThrow(absl::StatusOr&& status) { - OkOrThrow(status.status()); - return std::move(status).value(); -} - -// `GetValueOrThrow` overload for `Status`. -void GetValueOrThrow(const absl::Status& status); - } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_STATUS_H_ From 05d9cba860ba9e522db531d06d18f1351e04607b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 2 Sep 2025 13:16:23 -0300 Subject: [PATCH 082/133] Remove CUDA logic from C++ files in `torch_xla/csrc` directory. (#9603) This PR removes CUDA specific code from C++ files in `torch_xla/csrc` directory. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - (`init_python_bindings.cpp`) Removed `_has_cuda_support` Python API - (`dl_convertor.cpp`) Removed CUDA handling of DLPack capsules - (`tensor_impl.cpp`) Removed special handling of `Autocast` dispatch key for XLA:CUDA device - Also added a check, crashing on `XLA:CUDA` device (shouldn't be supported anymore) --- torch_xla/__init__.py | 3 +-- torch_xla/csrc/dl_convertor.cpp | 7 ------- torch_xla/csrc/init_python_bindings.cpp | 8 -------- torch_xla/csrc/random.cpp | 11 +---------- torch_xla/csrc/tensor_impl.cpp | 16 ++++++---------- 5 files changed, 8 insertions(+), 37 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 3f5f71ba6e5c..20b6f2adaad4 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -259,8 +259,7 @@ def _init_xla_lazy_backend(): from .experimental import plugins from ._internal import neuron, xpu # Additional built-in plugins -if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS', - '0' if _XLAC._has_cuda_support() else '1') == '1': +if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS', '1') == '1': plugins.use_dynamic_plugins() plugins.register_installed_plugins() diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index c4f8fc38efa2..c6a68a65f609 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -51,8 +51,6 @@ void DLPackTensorDeleter(DLManagedTensor* t) { DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) { if (device.client()->platform_id() == xla::CpuId()) { return DLDeviceType::kDLCPU; - } else if (device.client()->platform_id() == xla::CudaId()) { - return DLDeviceType::kDLCUDA; } XLA_ERROR() << "Device " << device.DebugString() << " cannot be used as a DLPack device."; @@ -176,11 +174,6 @@ absl::StatusOr DeviceForDLDevice(const DLDevice& context) { xla::CpuId()); return runtime::GetComputationClientOrDie()->LookupAddressableDevice( context.device_id); - case DLDeviceType::kDLCUDA: - XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(), - xla::CudaId()); - return runtime::GetComputationClientOrDie()->LookupAddressableDevice( - context.device_id); default: return tsl::errors::InvalidArgument( "Unknown/unsupported DLPack device type %d", context.device_type); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 9ce45e8761a9..1d409850b808 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -3055,14 +3055,6 @@ void InitXlaModuleBindings(py::module m) { -> std::vector { return TpuCustomCall(inputs, payload, output_shapes, output_dtypes); }) - .def("_has_cuda_support", - []() { -#ifdef GOOGLE_CUDA - return true; -#else - return false; -#endif - }) .def("_xla_register_custom_call_target", [](const std::string& fn_name, const py::capsule& function_ptr, const std::string& platform) { diff --git a/torch_xla/csrc/random.cpp b/torch_xla/csrc/random.cpp index 24e09e1188d7..a2f4699b933a 100644 --- a/torch_xla/csrc/random.cpp +++ b/torch_xla/csrc/random.cpp @@ -16,16 +16,7 @@ namespace torch_xla { namespace { -std::string GetDefaultGitGeneratorName() { - XlaDeviceType hw_type = - static_cast(bridge::GetCurrentDevice().type()); - switch (hw_type) { - case XlaDeviceType::CUDA: - return "three_fry"; - default: - return "default"; - } -} +std::string GetDefaultGitGeneratorName() { return "default"; } xla::BitGeneratorTy GetBitGenerator() { static const std::string* bit_generator = diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index a5527a671a7c..ee5ae3bce804 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -9,6 +9,7 @@ #include #include +#include "absl/log/absl_check.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir_builder.h" @@ -71,16 +72,11 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor) GetTypeMeta(tensor), bridge::XlaDeviceToAtenDevice(tensor.GetDevice())), tensor_(c10::make_intrusive(std::move(tensor))) { - // Update the Autocast key based off the backend device. - // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU - // so we must manually update Autocast to AutocastCUDA on XLA:GPU. - torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice(); - auto dev_type = static_cast(current_device.type()); - if (dev_type == XlaDeviceType::CUDA) { - auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA); - auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA); - key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks; - } + auto dev_type = static_cast(bridge::GetCurrentDevice().type()); + ABSL_CHECK(dev_type != XlaDeviceType::CUDA) + << "XLA:CUDA is not supported anymore. " + "If you are seeing this error, report a bug to the PyTorch/XLA GitHub " + "repository: https://github.com/pytorch/xla"; const_cast(this)->SetupSizeProperties(); set_sizes_and_strides(sym_sizes_, c10::fromIntArrayRefSlow( sizes_and_strides_.strides_arrayref())); From c0eeb5791b2e43eaef6ff11beb3e119b5a66b16d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 2 Sep 2025 13:16:38 -0300 Subject: [PATCH 083/133] Remove CUDA specific path from internal Python packages. (#9606) This PR removes CUDA specific code from internal Python packages, such as `_dynamo`, files in `_internal`, and the main `__init__.py` file. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - (`torch_xla/__init__.py`) Removed GPU specific OpenXLA flag - (`torch_xla/_dynamo/dynamo_bridge.py`) Removed CUDA tensor movement - As far as I know, mainly created for the zero overhead CUDA tensor movement --- torch_xla/__init__.py | 2 -- torch_xla/_dynamo/dynamo_bridge.py | 58 ------------------------------ torch_xla/_internal/gpu.py | 15 -------- torch_xla/_internal/pjrt.py | 6 +--- 4 files changed, 1 insertion(+), 80 deletions(-) delete mode 100644 torch_xla/_internal/gpu.py diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 20b6f2adaad4..05072113cce9 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -31,8 +31,6 @@ def _set_missing_flags(flags, sets): def _setup_xla_flags(): flags = os.environ.get('XLA_FLAGS', '').split(' ') flags = _set_missing_flags(flags, (('xla_cpu_enable_fast_math', 'false'),)) - flags = _set_missing_flags(flags, - (('xla_gpu_force_compilation_parallelism', '8'),)) os.environ['XLA_FLAGS'] = ' '.join(flags) diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index bb69e91d9518..906205429cdd 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -119,48 +119,6 @@ def _get_input_arg_device(input_args: tuple) -> torch.device: return device -# Returns True if all the input args are on a CUDA device. -def _args_on_cuda(input_args: tuple) -> bool: - input_device: torch.device = _get_input_arg_device(input_args) - if input_device is None: - return False - - return input_device.type == "cuda" - - -# Given an input list, moves the tensors to the given target_device. -# The output order will be the same as the input. Non tensors will also still -# be in the list. -def _maybe_move_tensors_to_device(tensors: tuple, - target_device: torch.device) -> tuple: - assert target_device, "Moving tensors to None device not supported" - - moved_tensors = [] - for tensor in tensors: - if not isinstance(tensor, torch.Tensor): - moved_tensors.append(tensor) - continue - - if tensor.device == target_device: - moved_tensors.append(tensor) - continue - - if dynamo_debug: - print("Moving Tensor {} to device {}".format(tensor, target_device)) - - # Have to move to CPU before moving it to target device. - cpu_device: torch.device = torch.device("cpu") - moved_tensor = tensor.to(cpu_device) - moved_tensor = moved_tensor.to(target_device) - - # Explicitly have to copy requires_grad attribute because it's dropped - # with torch.to(..) - moved_tensor.requires_grad = tensor.requires_grad - moved_tensors.append(moved_tensor) - - return tuple(moved_tensors) - - def _split_xla_args_tensor_sym_constant(args): tensors = deque(maxlen=len(args)) constants = [] @@ -552,14 +510,6 @@ def optimized_mod(*args: tuple): special_return_handler, xla_args_need_update) = extract_graph_helper( xla_model, sym_constants_to_graph_vars) - original_device: torch.device = _get_input_arg_device(args) - is_cuda_args: bool = False - if original_device: - is_cuda_args = original_device.type == "cuda" - - if is_cuda_args: - args = _maybe_move_tensors_to_device(args, torch_xla.device()) - if not config.skip_input_data_check: # `torch_xla.sync()` needs to be blocking since we want to access args's # XLADatas and they can't be placeholder. @@ -610,11 +560,7 @@ def optimized_mod(*args: tuple): # First few elements might be xla_args that needs to be in place updated result = res[len(xla_args_need_update):] - result = none_remover.add_nones(result) - if is_cuda_args: - result = _maybe_move_tensors_to_device(tuple(result), original_device) - if len(result) == 1: return result[0] else: @@ -802,10 +748,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args): - if _args_on_cuda(xla_args): - xla_args = tuple( - _maybe_move_tensors_to_device(xla_args, torch_xla.device())) - # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its # value reference before actually computing it. for a in xla_args: diff --git a/torch_xla/_internal/gpu.py b/torch_xla/_internal/gpu.py deleted file mode 100644 index 45452a3c6c1f..000000000000 --- a/torch_xla/_internal/gpu.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -import torch_xla.core.xla_env_vars as xenv - - -def num_local_processes() -> int: - """Returns number of processes to create on this host. - - Raises: - AssertionError: if GPU_NUM_DEVICES environment variable - is not configured - """ - assert xenv.GPU_NUM_DEVICES in os.environ, \ - "Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client" - os.environ[xenv.LOCAL_WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES] - return int(os.environ[xenv.LOCAL_WORLD_SIZE]) diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 25a0ee36c36e..578e7f77fa6a 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -12,7 +12,7 @@ import torch_xla.core.xla_env_vars as xenv import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_backend -from torch_xla._internal import tpu, gpu, neuron +from torch_xla._internal import tpu, neuron from torch_xla import runtime import torch_xla.utils.utils as xu from torch_xla.experimental import plugins @@ -149,8 +149,6 @@ def run_multiprocess(fn: Callable[..., R], num_processes = plugins.default().physical_chip_count() elif runtime.device_type() == 'TPU': num_processes = tpu.num_local_processes() - elif runtime.device_type() == 'CUDA': - num_processes = gpu.num_local_processes() elif runtime.device_type() == 'NEURON': num_processes = neuron.num_local_processes() else: @@ -220,8 +218,6 @@ def _initialize_single_process(local_rank: int, local_world_size: int): def spawn_threads(fn: Callable, args: Tuple = ()) -> None: """Run function in one process with one thread per addressable device.""" - assert runtime.device_type() not in ( - 'CUDA'), "spawn_threads does not support GPU device" spawn_fn = _SpawnFn(fn, *args) _run_thread_per_device( local_rank=0, From 89f929b6642148cc969f706c3818b9e82e115665 Mon Sep 17 00:00:00 2001 From: Jialei Chen Date: Tue, 2 Sep 2025 09:53:44 -0700 Subject: [PATCH 084/133] Move `_jax_forward` and `_jax_backward` inside `j2t_autograd` to avoid cache collisions (#9585) --- torchax/torchax/interop.py | 60 +++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index a87efe9dfe74..34ab79b10838 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -237,6 +237,36 @@ def j2t_autograd(fn, call_jax=call_jax): the PyTorch autograd framework by saving the residuals into the context object. """ + # NOTE(qihqi): This function cannot be inlined from the callsite + # Becuase if it does, then it won't hit the compilation cache for + # call_jax. Call jax uses functions' id as key. + # It is nested inside j2t_autograd to ensure it gets a unique ID for each + # wrapped pure function, preventing cache collisions between different pure modules. + def _jax_forward(fn, other, tree_def, tensors): + """JAX function to compute output and vjp function. + + primals should be a tuple (args, kwargs). + """ + import jax + from jax.tree_util import tree_flatten, tree_unflatten + + def fn_wrapper(*tensors): + # Reconstruct the original args and kwargs + flat_inputs = util.merge(tensors, other) + args, kwargs = tree_unflatten(tree_def, flat_inputs) + return fn(*args, **kwargs) + + return jax.vjp(fn_wrapper, *tensors) + + def _jax_backward(vjp_spec, saved_tensors, grad_out): + """JAX function to compute input gradients. + + Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. + """ + from jax.tree_util import tree_unflatten + fun_vjp = tree_unflatten(vjp_spec, saved_tensors) + return fun_vjp(grad_out) + @wraps(fn) def inner(*args, **kwargs): from jax.tree_util import tree_flatten @@ -290,36 +320,6 @@ def backward(ctx, *grad_out): return inner -# NOTE(qihqi): This function cannot be inlined from the callsite -# Becuase if it does, then it won't hit the compilation cache for -# call_jax. Call jax uses functions' id as key. -def _jax_forward(fn, other, tree_def, tensors): - """JAX function to compute output and vjp function. - - primals should be a tuple (args, kwargs). - """ - import jax - from jax.tree_util import tree_flatten, tree_unflatten - - def fn_wrapper(*tensors): - # Reconstruct the original args and kwargs - flat_inputs = util.merge(tensors, other) - args, kwargs = tree_unflatten(tree_def, flat_inputs) - return fn(*args, **kwargs) - - return jax.vjp(fn_wrapper, *tensors) - - -def _jax_backward(vjp_spec, saved_tensors, grad_out): - """JAX function to compute input gradients. - - Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. - """ - from jax.tree_util import tree_unflatten - fun_vjp = tree_unflatten(vjp_spec, saved_tensors) - return fun_vjp(grad_out) - - fori_loop = torch_view(jax.lax.fori_loop) From 647804ca96d2446226b8834497149a4dd06c8e02 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 2 Sep 2025 18:48:44 -0300 Subject: [PATCH 085/133] Remove remaining GPU/CUDA mentions in `torch_xla` directory. (#9608) This PR removes the remaining CUDA specific code from the PyTorch/XLA package (i.e. `torch_xla` directory) as well a few other related files. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - (`CONTRIBUTING.md`) Removed mention to CUDA specific environment variables - (`configuration.yaml`) Removed description of CUDA specific environment variables - (`docs/source/learn/_pjrt.md`) Removed PjRt documentation on CUDA - (`torch_xla/amp`) Removed CUDA specific branches, as well as `GradScaler` - (`torch_xla/core/xla_env_vars.py`) Removed CUDA specific environment variables - (`torch_xla/utils/checkpoint.py`) Fixed incorrect function name --- CONTRIBUTING.md | 6 -- configuration.yaml | 9 -- docs/source/learn/_pjrt.md | 63 -------------- docs/source/perf/amp.md | 53 ------------ test/test_autocast.py | 2 +- test/test_train_mp_imagenet_amp.py | 4 +- test/test_train_mp_mnist_amp.py | 7 +- torch_xla/_internal/pjrt.py | 2 +- torch_xla/amp/__init__.py | 1 - torch_xla/amp/autocast_mode.py | 67 +-------------- torch_xla/amp/grad_scaler.py | 82 ------------------- torch_xla/core/xla_env_vars.py | 2 - torch_xla/core/xla_model.py | 8 +- torch_xla/debug/profiler.py | 4 +- .../fsdp/xla_fully_sharded_data_parallel.py | 2 +- .../distributed_checkpoint/manager.py | 1 - torch_xla/runtime.py | 2 - torch_xla/utils/checkpoint.py | 2 +- 18 files changed, 17 insertions(+), 300 deletions(-) delete mode 100644 torch_xla/amp/grad_scaler.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a6eb0af8a54d..6c05fd88f747 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -291,12 +291,6 @@ To run the tests, follow __one__ of the options below: export PJRT_DEVICE=TPU ``` -* Run on GPU: - - ```shell - export PJRT_DEVICE=CUDA GPU_NUM_DEVICES=${NUM_GPU} - ``` - For more detail on configuring the runtime, please refer to [this doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#quickstart) If you are planning to be building from source and hence using the latest _PyTorch/TPU_ code base, diff --git a/configuration.yaml b/configuration.yaml index c1760d608ae9..6c1a1844ca67 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -15,11 +15,6 @@ variables: - Whether or not to create an async PJRT client for the CPU device(s). type: bool default_value: false - PJRT_GPU_ASYNC_CLIENT: - description: - - Whether or not to create an async PJRT client for the GPU device(s). - type: bool - default_value: false PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS: description: - Max inflight computations that the PJRT client can handle for TPU. @@ -229,10 +224,6 @@ variables: description: - Number of CPU devices being used by this instance of XRT. type: int - GPU_NUM_DEVICES: - description: - - Number of GPU devices being used by this instance of XRT. - type: int debug_variables: XLA_FNTRACKER_FILE: description: diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index 3b0c0eeb9dff..16300239353a 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -188,69 +188,6 @@ time. See the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/run-in-container) for more information. -### GPU - -### Single-node GPU training - -To use GPUs with PJRT, simply set `PJRT_DEVICE=CUDA` and configure -`GPU_NUM_DEVICES` to the number of devices on the host. For example: - - PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1 - -You can also use `torchrun` to initiate the single-node multi-GPU -training. For example, - - PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 - -In the above example, `--nnodes` means how many machines (physical -machines or VMs) to be used (it is 1 since we do single-node training). -`--nproc-per-node` means how many GPU devices to be used. - -### Multi-node GPU training - -**Note that this feature only works for cuda 12+**. Similar to how -PyTorch uses multi-node training, you can run the command as below: - - PJRT_DEVICE=CUDA torchrun \ - --nnodes=${NUMBER_GPU_VM} \ - --node_rank=${CURRENT_NODE_RANK} \ - --nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \ - --rdzv_endpoint= multinode_training.py - -- `--nnodes`: how many GPU machines to be used. -- `--node_rank`: the index of the current GPU machines. The value can - be 0, 1, ..., \${NUMBER_GPU_VM}-1. -- `--nproc_per_node`: the number of GPU devices to be used on the - current machine. -- `--rdzv_endpoint`: the endpoint of the GPU machine with - node_rank==0, in the form `host:port`. The `host` will be the - internal IP address. The `port` can be any available port on the - machine. For single-node training/inference, this parameter can be - omitted. - -For example, if you want to train on 2 GPU machines: machine_0 and -machine_1, on the first GPU machine machine_0, run - - # PJRT_DEVICE=CUDA torchrun \ - --nnodes=2 \ - --node_rank=0 \ - --nproc_per_node=4 \ - --rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 - -On the second GPU machine, run - - # PJRT_DEVICE=CUDA torchrun \ - --nnodes=2 \ - --node_rank=1 \ - --nproc_per_node=4 \ - --rdzv_endpoint=":12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1 - -the difference between the 2 commands above are `--node_rank` and -potentially `--nproc_per_node` if you want to use different number of -GPU devices on each machine. All the rest are identical. For more -information about `torchrun`, please refer to this -[page](https://pytorch.org/docs/stable/elastic/run.html). - ## Differences from XRT Although in most cases we expect PJRT and XRT to work mostly diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md index 4ad48753d45c..36d777fd865f 100644 --- a/docs/source/perf/amp.md +++ b/docs/source/perf/amp.md @@ -95,59 +95,6 @@ unlisted ops run if they're downstream from autocasted ops. `stack`, `cat`, `index_copy` -## AMP for XLA:GPU - -AMP on XLA:GPU devices reuse Pytorch's AMP rules. See [Pytorch's AMP -documentation](https://pytorch.org/docs/stable/amp.html) for CUDA -specific behavior. A simple CUDA AMP example is below: - -``` python -from torch_xla.amp import syncfree -import torch_xla.core.xla_model as xm - -# Creates model and optimizer in default precision -model = Net().to('xla') -# Pytorch/XLA provides sync-free optimizers for improved performance -optimizer = syncfree.SGD(model.parameters(), ...) -scaler = GradScaler() - -for input, target in data: - optimizer.zero_grad() - - # Enables autocasting for the forward pass - with autocast(torch_xla.device()): - output = model(input) - loss = loss_fn(output, target) - - # Exits the context manager before backward pass - scaler.scale(loss).backward() - gradients = xm._fetch_gradients(optimizer) - xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size()) - scaler.step(optimizer) - scaler.update() -``` - -`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the -XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is -only used with CUDA devices, then `torch.cuda.amp.autocast` can be -directly used, but requires `torch` is compiled with `cuda` support for -datatype of `torch.bfloat16`. We recommend using -`autocast(torch_xla.device())` on XLA:GPU as it does not require -`torch.cuda` support for any datatypes, including `torch.bfloat16`. - -### AMP for XLA:GPU Best Practices - -1. `autocast` should wrap only the forward pass(es) and loss - computation(s) of the network. Backward ops run in the same type - that autocast used for the corresponding forward ops. -2. Do not set `XLA_USE_F16` flag when using AMP on Cuda devices. This - will override the per-operator precision settings provided by AMP - and cause all operators to execute in float16. -3. Use gradient scaling to prevent float16 gradients from underflowing. -4. Pytorch/XLA provides modified version of - [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) - that avoid the additional sync between device and host. - ## Examples Our [mnist training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) diff --git a/test/test_autocast.py b/test/test_autocast.py index ca1f26c05ec1..32f72ae9762a 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -12,7 +12,7 @@ import collections import unittest from torch.testing._internal.autocast_test_lists import AutocastTestLists -from torch_xla.amp import autocast, GradScaler +from torch_xla.amp import autocast class AutocastTPUTestLists: diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 290857281fd7..31aaccf179b9 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -67,7 +67,7 @@ import torch_xla.utils.utils as xu import torch_xla.core.xla_model as xm import torch_xla.test.test_utils as test_utils -from torch_xla.amp import autocast, GradScaler +from torch_xla.amp import autocast try: from torch_xla.amp import syncfree except ImportError: @@ -220,8 +220,6 @@ def train_imagenet(): if FLAGS.amp: if device_hw == 'TPU': scaler = None - elif device_hw == 'CUDA': - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) def train_loop_fn(loader, epoch): tracker = xm.RateTracker() diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 0bd393b21f2e..db4516e7d7a4 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -38,7 +38,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils -from torch_xla.amp import autocast, GradScaler +from torch_xla.amp import autocast try: from torch_xla.amp import syncfree except ImportError: @@ -143,11 +143,8 @@ def train_mnist(flags, **kwargs): if device_hw == 'TPU': scaler = None - elif device_hw == 'CUDA': - # GradScaler only used for GPU - scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad) else: - print("Only TPU or GPU supported for AMP.") + print("Only TPU supported for AMP.") sys.exit(1) def train_loop_fn(loader): diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 578e7f77fa6a..20b0a56da774 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -205,7 +205,7 @@ def spawn(fn: Callable, return _run_singleprocess(spawn_fn) elif nprocs is not None: raise ValueError( - 'Unsupported nprocs (%d). Please use nprocs=1 or None (default). If None, spawn will use all available devices. Use the environment variable X_NUM_DEVICES (where X is CPU, GPU, TPU, NEURONCORE, etc) to limit the number of devices used.' + 'Unsupported nprocs (%d). Please use nprocs=1 or None (default). If None, spawn will use all available devices. Use the environment variable X_NUM_DEVICES (where X is CPU, TPU, NEURONCORE, etc) to limit the number of devices used.' % nprocs) run_multiprocess(spawn_fn, start_method=start_method) diff --git a/torch_xla/amp/__init__.py b/torch_xla/amp/__init__.py index 739f55cc0dcf..5dfca306b7a7 100644 --- a/torch_xla/amp/__init__.py +++ b/torch_xla/amp/__init__.py @@ -1,2 +1 @@ from .autocast_mode import autocast # noqa: F401 -from .grad_scaler import GradScaler # noqa: F401 diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py index 867dddd07bb5..d06836ee58a4 100644 --- a/torch_xla/amp/autocast_mode.py +++ b/torch_xla/amp/autocast_mode.py @@ -10,8 +10,7 @@ class autocast(torch.amp.autocast_mode.autocast): r""" `torch.autocast` for XLA backend devices. See :class:`torch.autocast`. ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to - ``torch.autocast("xla", **kwargs)`` for XLA:GPU and XLA:TPU for dtype torch.bfloat16, - ``torch.autocast("cuda", **kwargs)`` for XLA:GPU and other dtypes. + ``torch.autocast("xla", **kwargs)`` for XLA:TPU for dtype torch.bfloat16. """ def __init__(self, @@ -20,34 +19,11 @@ def __init__(self, dtype: torch.dtype = None, cache_enabled: bool = True): # `torch_xla.amp.autocast` is intended for XLA backend, with AutocastXLA dispatch key. - assert 'xla' in device.__str__( - ), "torch_xla.autocast is available for XLA:TPU, XLA:GPU" + assert 'xla' in str(device), "torch_xla.autocast is available for XLA:TPU" self._enabled = enabled self._xla_device = xm.xla_device_hw(device) - if self._xla_device == 'CUDA': - backend = 'cuda' - self._xla_bfloat16 = False # True if xla backend with bfloat16 dtype. - if dtype is None: - dtype = torch.float16 - elif dtype == torch.bfloat16 and not torch.cuda.is_available(): - if xr.is_bf16_supported(): - # XLA:GPU with bfloat16 should run on `xla` backend - # unless torch.autocast is compiled with cuda. - backend = 'xla' - self._xla_bfloat16 = True - else: - # This has been the default behavior for unsupported bfloat16 dtype - dtype = torch.float16 - error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this HW.\n" - error_message += ("Using the default cuda autocast dtype float16.") - self._dtype = dtype - super().__init__( - backend, - enabled=enabled, - dtype=self._dtype, - cache_enabled=cache_enabled) - elif self._xla_device == 'TPU' or self._xla_device == 'NEURON': + if self._xla_device == 'TPU' or self._xla_device == 'NEURON': if dtype is None: dtype = torch.bfloat16 if dtype != torch.bfloat16: @@ -63,39 +39,4 @@ def __init__(self, dtype=self._dtype, cache_enabled=cache_enabled) else: - print( - 'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.' - ) - - def __enter__(self): - # This ensures that xla autocast is enabled even for XLA:GPU, which calls - # `torch.amp.autocast_mode.autocast` with `cuda` backend. - if self._xla_device == 'CUDA': - self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined] - self.prev_dtype = torch.get_autocast_xla_dtype( - ) # type: ignore[attr-defined] - if self._xla_bfloat16: - # autocast_xla flags will be set by `torch.autocast` and we need to - # set autocast flags as we call into `torch.autocast` apis. - torch.set_autocast_enabled(self._enabled) - torch.set_autocast_gpu_dtype(self._dtype) - else: - torch.set_autocast_xla_enabled(self._enabled) - torch.set_autocast_xla_dtype(self._dtype) - return super().__enter__() - - def __exit__(self, exc_type: Any, exc_val: Any, - exc_tb: Any): # type: ignore[override] - if self._xla_device == 'CUDA': - if self._xla_bfloat16: - # autocast_xla flags will be set by `torch.autocast` and we need to - # set autocast flags as we call into `torch.autocast` apis. - torch.set_autocast_enabled(self.prev) - torch.set_autocast_gpu_dtype(self.prev_dtype) - else: - torch.set_autocast_xla_enabled(self.prev) - torch.set_autocast_xla_dtype(self.prev_dtype) - return super().__exit__(exc_type, exc_val, exc_tb) - - def __call__(self, func): - return super().__call__(func) + print('Warning: AMP only supported for XLA:TPU. Ignoring autocast.') diff --git a/torch_xla/amp/grad_scaler.py b/torch_xla/amp/grad_scaler.py deleted file mode 100644 index 62ebf560c8aa..000000000000 --- a/torch_xla/amp/grad_scaler.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch_xla -import torch_xla.core.xla_model as xm -import torch_xla.core.xla_builder as xb -import torch_xla.core.xla_op_registry as xor -import inspect - - -class GradScaler(torch.cuda.amp.GradScaler): - """ - An torch_xla variant of torch.cuda.amp.GradScaler that helps perform the steps of gradient scaling - conveniently. - Args: - init_scale (float, optional, default=2.**16): Initial scale factor. - growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during - :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. - backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during - :meth:`update` if inf/NaN gradients occur in an iteration. - growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients - that must occur for the scale to be multiplied by ``growth_factor``. - enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply - invokes the underlying ``optimizer.step()``, and other methods become no-ops. - use_zero_grad (bool, optional, default=False): If ``True``, enables the torch_xla specific zero gradients - optimization that performs ``optimizer.step()`` with gradients set to zero instead of skipping it when - inf/NaN gradients occur. This may improve the performance by removing the barrier in GradScaler. - """ - - def __init__( - self, - init_scale=2.0**16, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, - enabled=True, - use_zero_grad=False, - ): - super().__init__( - init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - enabled=enabled, - ) - - def get_scaling_factor(a): - - def if_true(a): - return xb.Op.zero(a.builder()) - - def if_false(a): - return xb.Op.one(a.builder()) - - cond = a != xb.Op.zero(a.builder()) - return cond.mkconditional((a,), if_true, if_false) - - self.get_scaling_factor = xor.register("get_scaling_factor", - get_scaling_factor) - self.use_zero_grad = use_zero_grad - - def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): - retval = None - is_syncfree_optim = "found_inf" in inspect.signature( - optimizer.step).parameters - if is_syncfree_optim: - found_inf = torch.stack( - tuple(optimizer_state["found_inf_per_device"].values())).sum() - kwargs['found_inf'] = found_inf - retval = optimizer.step(*args, **kwargs) - elif self.use_zero_grad: - found_inf = torch.stack( - tuple(optimizer_state["found_inf_per_device"].values())).sum() - scaling_factor = self.get_scaling_factor(found_inf) - for grad in xm._fetch_gradients(optimizer): - grad.nan_to_num_() - grad.mul_(scaling_factor) - retval = optimizer.step(*args, **kwargs) - else: - torch_xla.sync() - if not sum( - v.item() for v in optimizer_state["found_inf_per_device"].values()): - retval = optimizer.step(*args, **kwargs) - return retval diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py index 2d256c77a540..0bf232a2374d 100644 --- a/torch_xla/core/xla_env_vars.py +++ b/torch_xla/core/xla_env_vars.py @@ -1,6 +1,5 @@ TPUVM_MODE = 'TPUVM_MODE' TPU_NUM_DEVICES = 'TPU_NUM_DEVICES' -GPU_NUM_DEVICES = 'GPU_NUM_DEVICES' CPU_NUM_DEVICES = 'CPU_NUM_DEVICES' CLOUD_TPU_TASK_ID = 'CLOUD_TPU_TASK_ID' ACCELERATOR_TYPE = 'ACCELERATOR_TYPE' @@ -24,7 +23,6 @@ TPU_VISIBLE_CHIPS = 'TPU_VISIBLE_CHIPS' TPU_PROCESS_PORT = 'TPU_PROCESS_PORT' PJRT_CPU_ASYNC_CLIENT = 'PJRT_CPU_ASYNC_CLIENT' -PJRT_GPU_ASYNC_CLIENT = 'PJRT_GPU_ASYNC_CLIENT' PJRT_DIST_SERVICE_ADDR = 'PJRT_DIST_SERVICE_ADDR' LOCAL_RANK = 'LOCAL_RANK' RANK = 'RANK' diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 3dbad1a963eb..d5de93c618de 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -128,7 +128,7 @@ def xla_device(n: Optional[int] = None, specified, the specific XLA device instance will be returned. Otherwise the first device (default 0) will be returned. devkind (string..., optional): If specified, device type such as `TPU`, - `CUDA`, `CPU`, or custom PJRT device. Deprecated. + `CPU`, or custom PJRT device. Deprecated. Returns: A `torch.device` with the requested instance of an XLA device. @@ -152,7 +152,7 @@ def xla_real_devices(devices: Optional[List[torch.device]] = None) -> List[str]: devices: The list of torch devices such as ['xla:0', 'xla:1']. Returns: - A list of real devices' name such as ['CUDA:0', 'CUDA:1']. + A list of real devices' name such as ['CPU:0', 'CPU:1']. """ if not devices: devices = get_xla_supported_devices() @@ -210,7 +210,7 @@ def xla_replication_devices( format(len(local_devices), len(kind_devices))) replication_devices = [] for device in torch_xla._XLAC._xla_get_all_devices(): - # device is like 'CUDA:0' + # device is like 'CPU:0' xdev = _utils.parse_xla_device(device) if not xdev: raise RuntimeError('Invalid device format: {}'.format(device)) @@ -240,7 +240,7 @@ def set_replication(device: torch.device, devctx = _get_device_context(device=device) devices = [str(x) for x in devices] if devices: - # sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3'] + # sample replication_devices: ['CPU:0', 'CPU:1', 'CPU:2', 'CPU:3'] replication_devices = xla_replication_devices(devices) torch_xla._XLAC._xla_set_replication_devices(replication_devices) devctx.device_index = devices.index(device) diff --git a/torch_xla/debug/profiler.py b/torch_xla/debug/profiler.py index 4e046248b42f..ffbd754b1c76 100644 --- a/torch_xla/debug/profiler.py +++ b/torch_xla/debug/profiler.py @@ -72,7 +72,7 @@ def trace(service_addr: str, in case of failures. host_tracer_level (int): CPU tracing level. Values are: 1 - critical info only, 2 - info, 3 - verbose. - device_tracer_level (int): Device (TPU/GPU) tracing level. Values are: 1 - + device_tracer_level (int): Device (TPU) tracing level. Values are: 1 - enabled, 0 - disabled. delay_ms (int): Specifies the services to start profiling delay_ms milliseconds after the current time. @@ -218,7 +218,7 @@ def reset(self): def start_trace(log_dir: Union[os.PathLike, str]) -> None: """Starts a profiler trace. - The trace will capture CPU, GPU, and/or TPU activity, including Python + The trace will capture CPU, and/or TPU activity, including Python functions and PyTorch/XLA on-device operations. Use :func:`stop_trace` to end the trace and save the results to ``log_dir``. diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index c5605d2b3ed2..e7006213ea62 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -280,7 +280,7 @@ def custom_auto_wrap_policy( >>> # responsible for initializing a module, such as with reset_parameters >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) - >>> print(next(fsdp_model.parameters()).device) # current CUDA device + >>> print(next(fsdp_model.parameters()).device) >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index bcb958a6d2e5..d1ef2799a652 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -149,7 +149,6 @@ def __init__(self, self.chkpt_on_preemption = chkpt_on_preemption # Create a new group if none is provided - # TODO(jonbolin): Verify subgroup on GPU backend self.pg = process_group or dist.new_group() # Thread pool to run the async checkpoints. `_async_sem` is used to guard diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index d83bcba8a1dd..4e6352bc5527 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -162,8 +162,6 @@ def host_index() -> int: return plugins.default().host_index() elif device_type() == 'TPU': return tpu.worker_id() - - # TODO: Update this when we support multi-host GPU return 0 diff --git a/torch_xla/utils/checkpoint.py b/torch_xla/utils/checkpoint.py index 220dbe011882..5fa15f825c54 100644 --- a/torch_xla/utils/checkpoint.py +++ b/torch_xla/utils/checkpoint.py @@ -54,7 +54,7 @@ def set_device_states(devices: List[torch.device], assert all(isinstance(v, state_0_type) for v in states), f"all device states should have the same type" - device_module = xm if state_0_type == int else get_device_module(*states) + device_module = xm if state_0_type == int else _get_device_module(*states) for device, state in zip(devices, states): device_module.set_rng_state(state, device=device) From 94fdadcb088f3ff621772ee481e4e07e49daa226 Mon Sep 17 00:00:00 2001 From: qihqi Date: Tue, 2 Sep 2025 16:10:40 -0700 Subject: [PATCH 086/133] Update version to 0.0.6 (#9611) --- torchax/torchax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index fe4c1c8ff046..240cd70175a7 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -8,7 +8,7 @@ from torchax import tensor from contextlib import contextmanager -__version__ = "0.0.5" +__version__ = "0.0.6" VERSION = __version__ __all__ = [ From ddf75a17c8b8e2d4b82b95071cc14eff98ff3c71 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Sep 2025 12:41:39 -0300 Subject: [PATCH 087/133] Remove CUDA from PyTorch/XLA build. (#9609) This PR removes the CUDA specific build related source code. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - (`.bazelrc`) Removed CUDA specific environment variables - (`WORKSPACE`) Removed CUDA dependencies and CUDA specific OpenXLA patches - (`BUILD` files) Removed CUDA specific build dependencies --- .bazelrc | 6 +++--- BUILD | 9 +-------- WORKSPACE | 16 +++++++-------- openxla_patches/gpu_nvml.diff | 26 ------------------------- openxla_patches/gpu_race_condition.diff | 14 ------------- torch_xla/csrc/runtime/BUILD | 12 ++---------- 6 files changed, 13 insertions(+), 70 deletions(-) delete mode 100644 openxla_patches/gpu_nvml.diff delete mode 100644 openxla_patches/gpu_race_condition.diff diff --git a/.bazelrc b/.bazelrc index 3dec0dc40643..9c2667a8ac1b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -91,7 +91,7 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING #build:tpu --@xla//xla/python:enable_tpu=true build:tpu --define=with_tpu_support=true -# Run tests serially with TPU and GPU (only 1 device is available). +# Run tests serially with TPU (only 1 device is available). test:tpu --local_test_jobs=1 ######################################################################### @@ -100,11 +100,11 @@ test:tpu --local_test_jobs=1 common --experimental_repo_remote_exec # Inherit environmental variables that are used in testing. -test --test_env=TPU_NUM_DEVICES --test_env=GPU_NUM_DEVICES --test_env=CPU_NUM_DEVICES --test_env=XRT_LOCAL_WORKER +test --test_env=TPU_NUM_DEVICES --test_env=CPU_NUM_DEVICES --test_env=XRT_LOCAL_WORKER test --test_env=XRT_TPU_CONFIG --test_env=XRT_DEVICE_MAP --test_env=XRT_WORKERS --test_env=XRT_MESH_SERVICE_ADDRESS test --test_env=XRT_SHARD_WORLD_SIZE --test_env=XRT_MULTI_PROCESSING_DEVICE --test_env=XRT_HOST_ORDINAL --test_env=XRT_SHARD_ORDINAL test --test_env=XRT_START_LOCAL_SERVER --test_env=TPUVM_MODE --test_env=PJRT_DEVICE --test_env=PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS -test --test_env=PJRT_CPU_ASYNC_CLIENT --test_env=PJRT_GPU_ASYNC_CLIENT --test_env=TPU_LIBRARY_PATH --test_env=PJRT_DIST_SERVICE_ADDR +test --test_env=PJRT_CPU_ASYNC_CLIENT --test_env=TPU_LIBRARY_PATH --test_env=PJRT_DIST_SERVICE_ADDR test --test_env=PJRT_LOCAL_PROCESS_RANK # This environmental variable is important for properly integrating with XLA. diff --git a/BUILD b/BUILD index 1b82e9d4b975..128f83dcd56e 100644 --- a/BUILD +++ b/BUILD @@ -1,8 +1,3 @@ -load( - "@xla//xla/tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) - load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") @@ -41,9 +36,7 @@ cc_binary( "@torch//:libtorch", "@torch//:libtorch_cpu", "@torch//:libtorch_python", - ] + if_cuda_is_configured([ - "@xla//xla/stream_executor:cuda_platform", - ]), + ], ) test_suite( diff --git a/WORKSPACE b/WORKSPACE index 8222c5797bba..70b7d9cc098d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -56,8 +56,6 @@ http_archive( ], patch_tool = "patch", patches = [ - "//openxla_patches:gpu_nvml.diff", - "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:no_fortify.diff", ], strip_prefix = "xla-" + xla_hash, @@ -142,16 +140,16 @@ load("@xla//:workspace0.bzl", "xla_workspace0") xla_workspace0() +# Even though we don't support XLA:CUDA anymore, we still need to keep the +# following. The reason being that `pjrt_computation_client_test` depends on +# `@xla//xla/tools`, which calls: +# +# ``` +# load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")` +# ``` load( "@xla//third_party/gpus:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") - -load( - "@xla//third_party/nccl:nccl_configure.bzl", - "nccl_configure", -) - -nccl_configure(name = "local_config_nccl") diff --git a/openxla_patches/gpu_nvml.diff b/openxla_patches/gpu_nvml.diff deleted file mode 100644 index fd38807775a9..000000000000 --- a/openxla_patches/gpu_nvml.diff +++ /dev/null @@ -1,26 +0,0 @@ -iff --git a/xla/service/gpu/model/gpu_collective_performance_model.cc b/xla/service/gpu/model/gpu_collective_performance_model.cc -index 496969f545..2d9f73ee36 100644 ---- a/xla/service/gpu/model/gpu_collective_performance_model.cc -+++ b/xla/service/gpu/model/gpu_collective_performance_model.cc -@@ -34,7 +34,7 @@ limitations under the License. - - #if GOOGLE_CUDA - #include "third_party/gpus/cuda/include/cuda.h" --#include "third_party/gpus/cuda/nvml/include/nvml.h" -+#include "third_party/gpus/cuda/include/nvml.h" - #endif // GOOGLE_CUDA - namespace xla { - namespace gpu { -diff --git a/xla/service/gpu/model/gpu_collective_performance_model.h b/xla/service/gpu/model/gpu_collective_performance_model.h -index 01c3f3eb45..f44057602b 100644 ---- a/xla/service/gpu/model/gpu_collective_performance_model.h -+++ b/xla/service/gpu/model/gpu_collective_performance_model.h -@@ -32,7 +32,7 @@ limitations under the License. - #include - #endif - --#include "third_party/gpus/cuda/nvml/include/nvml.h" -+#include "third_party/gpus/cuda/include/nvml.h" - // Below is a list of function pointers to be used - // for querying device properties through nvml library. - #define NVML_FUNCTOR(name, rettype, args) \ \ No newline at end of file diff --git a/openxla_patches/gpu_race_condition.diff b/openxla_patches/gpu_race_condition.diff deleted file mode 100644 index 082376116a3e..000000000000 --- a/openxla_patches/gpu_race_condition.diff +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc -index 9279bd877..fab926a7c 100644 ---- a/xla/service/gpu/gpu_executable.cc -+++ b/xla/service/gpu/gpu_executable.cc -@@ -669,8 +669,7 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( - #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - - // Force synchronous execution if the allocator requires it. -- const bool block_host_until_done = -- !memory_allocator->AllowsAsynchronousDeallocation(); -+ const bool block_host_until_done = true; - - // Lock the GPU with a shared lock so that we don't interfere with autotuning - // that may be running during JIT compilation while allowing multiple XLA \ No newline at end of file diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index b381d3feff7c..4f0f3bf384ed 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -2,10 +2,6 @@ load( "//bazel:rules_def.bzl", "ptxla_cc_test", ) -load( - "@xla//xla/tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) licenses(["notice"]) # Apache 2.0 @@ -134,7 +130,6 @@ cc_library( "@xla//xla:shape_util", "@xla//xla/hlo/builder:xla_computation", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_wrapper_impl", "@xla//xla/pjrt:pjrt_c_api_client", @@ -216,10 +211,9 @@ cc_library( "@torch//:headers", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:initialize", + "@xla//xla/pjrt/distributed:in_memory_key_value_store", "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt:tfrt_cpu_pjrt_client", - "@xla//xla/pjrt/gpu:se_gpu_pjrt_client", - "@xla//xla/service:gpu_plugin", ], ) @@ -295,9 +289,7 @@ cc_library( deps = [ "@xla//xla/backends/profiler/cpu:host_tracer", "@xla//xla/backends/profiler/cpu:metadata_collector", - ] + if_cuda_is_configured([ - "@xla//xla/backends/profiler/gpu:device_tracer", - ]), + ], alwayslink = True, ) From 8ff2ee68693f4f04f3fc757de96f5e25765e58f7 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Sep 2025 13:50:58 -0300 Subject: [PATCH 088/133] Remove CUDA from `benchmarks` directory. (#9610) This PR removes the CUDA specific code from the `benchmarks` directory. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - Removed the `keep_model_data_on_cuda` parameter - Used in combination with zero-overhead CUDA to XLA:CUDA data movement, removed in [#9598][1] and [#9603][2] - Deleted `llama.py`, `nightly.sh`, `run_benchmark.sh`, `run_single_graph_bm.sh`, and `run_top_tier_bm.sh` - All of them ran benchmarks comparing PyTorch Inductor with XLA:CUDA, specifically [1]: https://github.com/pytorch/xla/pull/9598 [2]: https://github.com/pytorch/xla/pull/9603 --- benchmarks/README.md | 10 +- benchmarks/benchmark_experiment.py | 27 +- benchmarks/benchmark_model.py | 11 +- benchmarks/experiment_runner.py | 5 - benchmarks/llama.py | 269 ------------------- benchmarks/nightly.sh | 258 ------------------ benchmarks/result_analyzer.py | 53 ++-- benchmarks/run_benchmark.sh | 79 ------ benchmarks/run_single_graph_bm.sh | 26 -- benchmarks/run_top_tier_bm.sh | 25 -- benchmarks/torchbench_model.py | 11 +- benchmarks/util.py | 13 +- benchmarks/verifier.py | 2 +- test/benchmarks/test_benchmark_experiment.py | 5 +- test/benchmarks/test_experiment_runner.py | 106 ++------ 15 files changed, 62 insertions(+), 838 deletions(-) delete mode 100644 benchmarks/llama.py delete mode 100755 benchmarks/nightly.sh delete mode 100644 benchmarks/run_benchmark.sh delete mode 100755 benchmarks/run_single_graph_bm.sh delete mode 100755 benchmarks/run_top_tier_bm.sh diff --git a/benchmarks/README.md b/benchmarks/README.md index 9476ecbcb50c..fe4bdc309b4f 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -77,7 +77,7 @@ Disable autoboost selecting clock rate based on thermal, and power budget effect Run the `experiment_runner.py` from the `pytorch` directory, which should be the parent of the `xla` directory. -The following example runs the alexnet benchmark on GPU through the +The following example runs the alexnet benchmark on CPU through the Pytorch/XLA-dynamo path and through the Inductor-dynamo with 5 repetitions each. The results will be stored in a json file (eg results.jsonl) in `experiment_results`. @@ -88,7 +88,7 @@ python xla/benchmarks/experiment_runner.py \ --xla=PJRT --xla=None \ --test=eval --test=train \ --suite-name=torchbench \ - --accelerator=cuda \ + --accelerator=cpu \ --output-dirname=experiment_results \ --repeat=5 \ --print-subprocess \ @@ -118,7 +118,7 @@ python xla/benchmarks/experiment_runner.py \ --suite-name=torchbench \ --progress-bar \ --model-config='{"model_name":"BERT_pytorch"}' \ - --experiment-config='{"accelerator":"cuda","xla":"PJRT","xla_flags":null,"dynamo":"openxla","torch_xla2":null,"test":"train","keep_model_data_on_cuda":false,"enable_functionalization":false}' \ + --experiment-config='{"accelerator":"cpu","xla":"PJRT","xla_flags":null,"dynamo":"openxla","torch_xla2":null,"test":"train","enable_functionalization":false}' \ --repeat 1 ``` @@ -135,13 +135,13 @@ works only for inference now. ``` cd pytorch -PJRT_DEVICE=CUDA python3 new_xla/benchmarks/experiment_runner.py \ +PJRT_DEVICE=CPU python3 new_xla/benchmarks/experiment_runner.py \ --xla=PJRT \ --dynamo=openxla \ --test=eval \ --filter=BERT_pytorch$ \ --suite-name=torchbench \ - --accelerator=cuda \ + --accelerator=cpu \ --progress-bar \ --output-dirname=/tmp/output \ --repeat=2 \ diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py index e1fab48334a8..daffdce2f7f5 100644 --- a/benchmarks/benchmark_experiment.py +++ b/benchmarks/benchmark_experiment.py @@ -20,13 +20,12 @@ def list_experiment_configs(self): # Start with default config. config_choices = { - "accelerator": ["cpu", "cuda", "tpu"], + "accelerator": ["cpu", "tpu"], "xla": [None, "PJRT", "XRT"], "xla_flags": [None], "dynamo": [None, "inductor", "openxla"], "torch_xla2": [None], # options only apply to torch_xla2 "test": ["eval", "train"], - "keep_model_data_on_cuda": [False], "enable_functionalization": [False], } @@ -46,10 +45,6 @@ def list_experiment_configs(self): if self._args.xla_flags: config_choices["xla_flags"] = list( map(parse_none_str, set(self._args.xla_flags))) - if self._args.keep_model_data_on_cuda: - config_choices["keep_model_data_on_cuda"] = [ - self._args.keep_model_data_on_cuda - ] if self._args.enable_functionalization: config_choices["enable_functionalization"] = [ self._args.enable_functionalization @@ -85,7 +80,6 @@ def _is_available(self, cfg_xla = experiment_config["xla"] cfg_test = experiment_config["test"] cfg_torch_xla2 = experiment_config["torch_xla2"] - cfg_keep_model_data_on_cuda = experiment_config["keep_model_data_on_cuda"] # Check that dynamo refers to an existing backend. if cfg_dynamo is not None and cfg_dynamo not in dynamo.list_backends( @@ -118,16 +112,16 @@ def _is_available(self, if cfg_accelerator == "tpu": if cfg_xla is None: return False - elif cfg_accelerator in ("cpu", "cuda"): + elif cfg_accelerator == "cpu": if cfg_xla == "XRT": return False + elif cfg_accelerator == "cuda": + if cfg_xla is not None: + # PyTorch/XLA with CUDA backend is no longer supported. + return False else: raise NotImplementedError - # cfg_keep_model_data_on_cuda is only avaible when using dynamo - if cfg_keep_model_data_on_cuda and cfg_dynamo != "openxla": - return False - return True def load_experiment(self, @@ -140,7 +134,6 @@ def load_experiment(self, test = experiment_config["test"] batch_size = experiment_config.get("batch_size", self._args.batch_size) torch_xla2 = experiment_config["torch_xla2"] - keep_model_data_on_cuda = experiment_config["keep_model_data_on_cuda"] enable_functionalization = experiment_config["enable_functionalization"] return BenchmarkExperiment( accelerator=accelerator, @@ -148,7 +141,6 @@ def load_experiment(self, xla_flags=xla_flags, dynamo=dynamo, torch_xla2=torch_xla2, - keep_model_data_on_cuda=keep_model_data_on_cuda, test=test, batch_size=batch_size, enable_functionalization=enable_functionalization, @@ -159,14 +151,12 @@ class BenchmarkExperiment: def __init__(self, accelerator: str, xla: Optional[str], xla_flags: Optional[str], dynamo: str, torch_xla2: bool, - keep_model_data_on_cuda: bool, test: str, batch_size: str, - enable_functionalization: bool): + test: str, batch_size: str, enable_functionalization: bool): self.accelerator = accelerator self.xla = xla self.xla_flags = xla_flags self.dynamo = dynamo self.torch_xla2 = torch_xla2 - self.keep_model_data_on_cuda = keep_model_data_on_cuda self.test = test self.batch_size = batch_size self.accelerator_model = get_accelerator_model(self.accelerator) @@ -191,8 +181,6 @@ def update_process_env(self, process_env: Dict[str, str]): if is_xla_device_available("TPU"): process_env["TPU_NUM_DEVICES"] = "1" process_env["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011" - elif is_xla_device_available("CUDA"): - process_env["GPU_NUM_DEVICES"] = "1" elif self.xla is None: # In non-xla CPU training experiments, an env var is still needed if an # xla device exists, or there will be "Missing XLA configuration" error. @@ -246,7 +234,6 @@ def to_dict(self): d["xla_flags"] = self.xla_flags d["dynamo"] = self.dynamo d["torch_xla2"] = self.torch_xla2 - d["keep_model_data_on_cuda"] = self.keep_model_data_on_cuda d["test"] = self.test d["batch_size"] = self.batch_size d["enable_functionalization"] = self.enable_functionalization diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py index 2b2f6c1957b7..008a4539c7a9 100644 --- a/benchmarks/benchmark_model.py +++ b/benchmarks/benchmark_model.py @@ -103,7 +103,6 @@ def prepare_for_experiment( else: raise NotImplementedError - keep_model_data_on_cuda = self.benchmark_experiment.keep_model_data_on_cuda if self.benchmark_experiment.torch_xla2: import torch_xla2.export import torch_xla2 @@ -125,7 +124,7 @@ def prepare_for_experiment( self.module = lambda *x: jax_func(weights, x) self.example_inputs = move_to_device( self.example_inputs, device, torch_xla2=True) - elif not keep_model_data_on_cuda: + else: self.module = self.module.to(self.device) self.example_inputs = move_to_device( self.example_inputs, self.device, torch_xla2=False) @@ -137,14 +136,6 @@ def prepare_for_experiment( logger.info(f"Running torch.compile with opts {compilation_opts}") self.model_iter_fn = torch.compile(self.model_iter_fn, **compilation_opts) - if keep_model_data_on_cuda: - - def assert_func(t): - assert t.device.type.lower( - ) == 'cuda', 'When keep_model_data_on_cuda is set, the input data should remain on the CUDA device.' - - pytree.tree_map_only(torch.Tensor, assert_func, self.example_inputs) - def pick_grad(self): if self.benchmark_experiment.test == "eval": return torch.no_grad() diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index b784af68e47b..04a5524ad387 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -936,11 +936,6 @@ def __str__(self): help="""Collect CUDA and CPU times per operation. This will also gather CPU fallbacks.""", ) - parser.add_argument( - "--keep-model-data-on-cuda", - action="store_true", - help="""Whether to keep the model and data on CUDA and not to move to an XLA device. This is to be used with PyTorch/XLA dynamo. When set, PyTorch/XLA dynamo bridge move the model and data to the XLA device.""", - ) parser.add_argument( "--xla-flags", type=str, diff --git a/benchmarks/llama.py b/benchmarks/llama.py deleted file mode 100644 index 53e88ddb0d2f..000000000000 --- a/benchmarks/llama.py +++ /dev/null @@ -1,269 +0,0 @@ -import argparse -import datetime -import logging -import json -import os -import re -import subprocess -import sys - -from enum import Enum - -logger = logging.getLogger(__name__) - - -def get_info_from_result_file(results_dir: str) -> tuple[str, str, float]: - results_file = os.path.join(results_dir, 'results.jsonl') - if not os.path.exists(results_file): - sys.exit(f"Results file {results_file} not found. " - "Please run experiment_runner.py first.") - accelerator_model = None - with open(results_file, 'r') as f: - first_line = f.readline() - acc_match = re.search(r'"accelerator_model": "([^"]+)"', first_line) - time_match = re.search(r'"timestamp": ([0-9.]+)', first_line) - if acc_match and time_match: - accelerator_model = acc_match.group(1) - timestamp = float(time_match.group(1)) - else: - sys.exit(f"Cannot find a timestamp and a matching accelerator " - "in {results_file}.") - logger.debug(f"Found accelerator_model='{accelerator_model}' and " - f"timestamp={timestamp} in {results_file}.") - return accelerator_model, timestamp - - -def set_up_llama_repo(workspace_dir: str) -> str: - llama_dir = os.path.join(workspace_dir, 'llama-inference') - if os.path.exists(llama_dir): - logger.debug(f'llama_dir={llama_dir} already exists; no setting up to do.') - return llama_dir - - logger.debug(f'Setting up llama repo at {llama_dir}.') - subprocess.check_call([ - 'git', 'clone', 'https://github.com/pytorch-tpu/llama.git', '--branch', - 'llama2-google-next-inference', llama_dir - ]) - subprocess.check_call( - ['pip', 'install', '-r', - os.path.join(llama_dir, 'requirements.txt')]) - subprocess.check_call(['pip', 'install', '-e', llama_dir]) - - # Create model JSON files - model_configs = { - '7b.json': { - "dim": 4096, - "multiple_of": 256, - "n_heads": 32, - "n_layers": 32, - "norm_eps": 1e-05, - "vocab_size": -1 - }, - '13b.json': { - "dim": 5120, - "multiple_of": 256, - "n_heads": 40, - "n_layers": 40, - "norm_eps": 1e-05, - "vocab_size": -1 - }, - '70b.json': { - "dim": 8192, - "multiple_of": 4096, - "ffn_dim_multiplier": 1.3, - "n_heads": 64, - "n_kv_heads": 8, - "n_layers": 80, - "norm_eps": 1e-05, - "vocab_size": -1 - } - } - for filename, config in model_configs.items(): - filepath = os.path.join(llama_dir, filename) - with open(filepath, 'w') as f: - json.dump(config, f) - f.write("\n") - return llama_dir - - -def parse_log_file(log_file: str): - latencies = [] - with open(log_file, 'r') as f: - for line in f: - if ('Totally decoded ' not in line or 'tokens in' not in line or - ' seconds' not in line): - continue - parts = line.strip().split() - tokens = float(parts[2]) - seconds = float(parts[5]) - latency_per_token = seconds / tokens - latencies.append(latency_per_token) - logger.debug(f'{log_file}: Found latencies={latencies}') - return latencies - - -def benchmark_has_already_run(results_file: str, model_name: str, xla: str, - dynamo: str, batch_size: int): - with open(results_file, 'r') as f: - for line in f: - # Grep for relevant lines to avoid parsing the entire JSONL file. - if f'"model_name": "{model_name}"' not in line: - continue - r = json.loads(line.rstrip('\n|\r')) - # yapf: disable - if all( - r.get(k1, {}).get(k2) == v - for (k1, k2, v) in [ - ('experiment', 'accelerator', 'cuda'), - ('experiment', 'batch_size', batch_size), - ('experiment', 'dynamo', dynamo), - ('experiment', 'test', 'eval'), - ('experiment', 'xla', xla), - ('experiment', 'xla_flags', None), - ('model', 'model_name', model_name), - ]): - return True - # yapf: enable - return False - - -def run_benchmarks(args, llama_dir: str, results_dir: str, - accelerator_model: str, timestamp: float): - os.chdir(llama_dir) - for size in ['7b', '13b', '70b']: - params_json = 'params.json' - if os.path.exists(params_json): - os.remove(params_json) - os.symlink(f'{size}.json', params_json) - model_name = f"llama2.{size}" - for dynamo in [None, 'inductor', 'openxla']: - backend = dynamo if dynamo else 'lazytensor' - xla = None if dynamo == 'inductor' else 'PJRT' - summary = f"{model_name} eval {backend} batch {args.batch_size}" - - results_file = os.path.join(results_dir, 'results.jsonl') - if benchmark_has_already_run(results_file, model_name, xla, dynamo, - args.batch_size): - logger.info(f"SKIP already completed benchmark -- {summary}") - continue - - logger.info(f"RUN {summary}") - log_file = os.path.join(results_dir, - f'llama-inference.{backend}.{size}.log') - - cmd = [ - 'python', 'example_text_completion.py', '1', '--ckpt_dir', '.', - '--tokenizer_path', - os.path.join(llama_dir, 't5_tokenizer/spiece.model'), '--max_seq_len', - '2048', '--max_gen_len', '1000', f'--max_batch_size', - f'{args.batch_size}', '--mp', 'True', f'--repeat', f'{args.repeat}', - f'--dynamo', f'"{dynamo}"' if dynamo else "''" - ] - - run_env = os.environ.copy() - if dynamo == 'inductor': - run_env['CUDA_VISIBLE_DEVICES'] = '0' - run_env['USE_CUDA'] = '1' - else: - run_env['PJRT_DEVICE'] = 'CUDA' - run_env['GPU_NUM_DEVICES'] = '1' - - run_ok = True - with open(log_file, 'w') as f: - try: - subprocess.check_call(cmd, stdout=f, stderr=f, env=run_env) - except subprocess.CalledProcessError: - logger.warning(f"Run failed -- see {log_file}.") - run_ok = False - - result = { - 'model': { - 'suite_name': 'llama2', - 'model_name': model_name, - }, - 'experiment': { - 'accelerator': 'cuda', - 'accelerator_model': accelerator_model, - 'xla': xla, - 'xla_flags': None, - 'dynamo': dynamo, - 'test': 'eval', - 'batch_size': args.batch_size, - }, - 'repeat': args.repeat, - 'iterations_per_run': 1, - 'metrics': { - # Filled in below. - }, - 'timestamp': timestamp, - } - if run_ok: - latencies = parse_log_file(log_file) - result['metrics']['total_time'] = latencies - else: - result['metrics']['error'] = f"Run failed -- see {log_file}." - - with open(results_file, mode="a", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False) - f.write("\n") - - -def parse_args(): - # Helper class for --log-level flag. - class LogLevel(Enum): - critical = logging.CRITICAL - error = logging.ERROR - warning = logging.WARNING - info = logging.INFO - debug = logging.DEBUG - - @staticmethod - def parse(s: str): - try: - return LogLevel[s] - except KeyError: - raise ValueError() - - def __str__(self): - return self.name - - parser = argparse.ArgumentParser(description='Run Llama inference benchmarks') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size.') - parser.add_argument( - '--log-level', - default=LogLevel.info, - choices=list(LogLevel), - type=LogLevel.parse, - help='Log level') - parser.add_argument( - '--repeat', type=int, default=8, help='Number of repetitions') - parser.add_argument( - '--workspace_dir', type=str, required=True, help='Workspace directory.') - args = parser.parse_args() - - return args - - -def main(): - args = parse_args() - logging.basicConfig(level=args.log_level.value, force=True) - args.workspace_dir = os.path.expanduser(args.workspace_dir) - if not os.path.exists(args.workspace_dir): - sys.exit(f"Workspace directory {args.workspace_dir} not found.") - - # Sanity check: we should already be inside the appropriate venv. - workspace_dir = os.path.realpath(args.workspace_dir) - logger.debug(f'workspace_dir realpath: {workspace_dir}') - if sys.prefix != os.path.join(workspace_dir, 'env'): - sys.exit( - "Error: must run under the Python venv from the given --workspace_dir.") - - results_dir = os.path.join(workspace_dir, 'experiment_results') - accelerator_model, timestamp = get_info_from_result_file(results_dir) - llama_dir = set_up_llama_repo(workspace_dir) - - run_benchmarks(args, llama_dir, results_dir, accelerator_model, timestamp) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nightly.sh b/benchmarks/nightly.sh deleted file mode 100755 index 64b34055cbf9..000000000000 --- a/benchmarks/nightly.sh +++ /dev/null @@ -1,258 +0,0 @@ -#!/bin/bash -# Pytorch/XLA Nightly Benchmark Runner. - -set -ex - -ACCELERATOR=a100 -OUTPUT_DIR=${HOME:?} -WORKSPACE=$(date --utc +%Y-%m-%d) -REPEAT=8 -ENABLE_PROFILING= - -while getopts 'A:O:PR:T:W:' OPTION -do - case ${OPTION?} in - A) - ACCELERATOR=${OPTARG:?} - ;; - O) - OUTPUT_DIR=${OPTARG:?} - ;; - P) - ENABLE_PROFILING=1 - ;; - R) - REPEAT=${OPTARG:?} - ;; - T) - # Avoid printing the token; re-enable printing later. - { set +x; } 2>/dev/null - export HUGGING_FACE_HUB_TOKEN=${OPTARG:?} - set -x - ;; - W) - WORKSPACE=${OPTARG:?} - ;; - esac -done - -NIGHTLY_RUNS=nightly_runs -NIGHTLY_RESULTS=nightly_results -if [[ ${ENABLE_PROFILING?} ]]; then - NIGHTLY_RUNS=nightly_profiling_runs - NIGHTLY_RESULTS=nightly_profiling_results -fi -WORKSPACE_DIR=${OUTPUT_DIR:?}/${NIGHTLY_RUNS:?}/${WORKSPACE:?} -BM_DIR=${WORKSPACE_DIR:?}/pytorch/xla/benchmarks - -# Intermediate results, which are processed to generate reports. -WORKSPACE_RESULTS_DIR=${WORKSPACE_DIR:?}/experiment_results - -# Final data files and reports go here. -NIGHTLY_RESULTS_DIR=${OUTPUT_DIR:?}/${NIGHTLY_RESULTS:?} - -# Init workspace -# -# Sometimes a run fails halfway. Typically this is because -# experiment_runner crashes. We then fix the problem and -# run the script again, which skips the build phase. -IS_FRESH_RUN=1 # Set to null below; read with ${IS_FRESH_RUN?}. -if [ -d ${WORKSPACE_DIR:?} ]; then - IS_FRESH_RUN= -fi - -if [[ ${IS_FRESH_RUN?} ]]; then - rm -rf ${HOME:?}/.cache/bazel -fi - -mkdir -p ${WORKSPACE_DIR:?} -cd ${WORKSPACE_DIR:?} - -ENV_DIR=env -if [[ ${IS_FRESH_RUN?} ]]; then - python3 -m venv ${ENV_DIR:?} -fi -source ${ENV_DIR:?}/bin/activate - -# Download and build everything -if [[ ${IS_FRESH_RUN?} ]]; then - # Install deps - pip install --upgrade pip - - TIMESTAMP=$(date +%s) - # Clone repos first so that their HEAD is as close as possible to $TIMESTAMP. - git clone https://github.com/pytorch/pytorch.git - git clone https://github.com/pytorch/xla.git pytorch/xla - git clone https://github.com/pytorch/vision.git - git clone https://github.com/pytorch/audio.git - git clone https://github.com/pytorch/benchmark.git - - # Set up pytorch - cd pytorch - pip install -r requirements.txt - make triton - USE_CUDA=1 python setup.py develop - cd .. - - # Set up pytorch/xla - cd pytorch/xla - # Query local compute capability. If that fails, assign a sane default. - LOCAL_CAP=compute_$(nvidia-smi --query-gpu=compute_cap --format=csv | \ - tail -1 | sed 's/\.//g' | grep -E '^[0-9]{2}$' || echo '80') - python setup.py develop - cd ../.. - - # Set up torchbench deps. - cd vision - python setup.py develop - cd .. - cd audio - python setup.py develop - cd .. - - # Set up torchbench - cd benchmark - USE_CUDA=1 python install.py - cd .. - - # Apply local patches - cd benchmark - git apply ../pytorch/xla/benchmarks/patches/mismatched_batch_size.patch - cd .. -else - # Grab the timestamp from the first result, if it exists. - # Otherwise take the current timestamp. - TIMESTAMP=$(head -1 ${WORKSPACE_RESULTS_DIR:?}/results.jsonl | \ - sed -E 's|.*\"timestamp\": ([0-9.]+).*|\1|' | \ - grep -E '^[0-9.]+$' || date +%s) -fi - -# Stabilize clock freqs -sudo nvidia-smi --lock-gpu-clocks=1200,1200 - -# Note: this doesn't work on GCP because it's a VM. -# Moreover, we should look into disabling turbo boost if possible. -# sudo cpupower frequency-set --governor performance - -PROFILING_FLAGS= -if [[ ${ENABLE_PROFILING?} ]]; then - PROFILING_FLAGS="--dump-dynamo-counters \ - --collect-dynamo-counters \ - --dump-pytorch-profiles \ - --dump-pytorch-xla-metrics \ - --profile-cuda-cpu \ - --profile-cuda-cpu-individual-ops" -fi - -# Run the experiments -cd pytorch -# Note: to avoid running in Eager mode (i.e. --xla=None --dynamo=None), -# we split experiment_runner.py's invocation in two. -# -# Inference + Training: XLA Lazy tensors, XLA+XLA_Eval Dynamo. -python xla/benchmarks/experiment_runner.py \ - --test=eval --test=train \ - --xla=PJRT \ - --dynamo=None --dynamo=openxla \ - --suite-name=torchbench --accelerator=cuda \ - --output-dirname=${WORKSPACE_RESULTS_DIR:?} \ - --repeat=${REPEAT:?} --print-subprocess \ - --timestamp=${TIMESTAMP:?} ${PROFILING_FLAGS?} -# Inference + Training: Inductor Dynamo. -python xla/benchmarks/experiment_runner.py \ - --test=eval --test=train \ - --xla=None \ - --dynamo=inductor \ - --suite-name=torchbench --accelerator=cuda \ - --output-dirname=${WORKSPACE_RESULTS_DIR:?} \ - --repeat=${REPEAT:?} --print-subprocess \ - --timestamp=${TIMESTAMP:?} ${PROFILING_FLAGS?} -cd .. - -# Run Llama2 benchmarks. -python ${BM_DIR:?}/llama.py --workspace_dir=${WORKSPACE_DIR:?} - -# Gather results and generate reports -REPORTS_DIR=${NIGHTLY_RESULTS_DIR:?}/reports/${WORKSPACE:?} -mkdir -p ${REPORTS_DIR:?} -cp ${WORKSPACE_RESULTS_DIR:?}/results.jsonl \ - ${NIGHTLY_RESULTS_DIR:?}/${WORKSPACE:?}.jsonl - -PYTORCH_GIT_REV=$(git -C pytorch rev-parse --short HEAD) -XLA_GIT_TAG=$(git -C pytorch/xla describe --tags --always) -GIT_TAGS="PT: ${PYTORCH_GIT_REV:?} XLA: ${XLA_GIT_TAG:?}" - -COMMON_TITLE_PREFIX= -if [[ ${ENABLE_PROFILING?} ]]; then - COMMON_TITLE_PREFIX="[Profiling ON] " -fi - -INFERENCE_BACKENDS_CMD='--backends inductor openxla+dynamo openxla+lazytensor' -TRAINING_BACKENDS_CMD='--backends inductor openxla+dynamo openxla+lazytensor' - -# Skip result files coming from one-off runs. -INPUT_JSONL_FILES=$(ls ${NIGHTLY_RESULTS_DIR:?}/*.jsonl | \ - grep '[0-9]\+-[0-9]\+-[0-9]\+\.jsonl') - -for testname in inference training; do - for report in latest histogram speedup; do - for format in csv svg; do - for tier in '' 1; do - TITLE_PREFIX= - TIER_CMD= - TIER_FILE_SUFFIX= - if [[ ${tier?} ]]; then - TITLE_PREFIX="${COMMON_TITLE_PREFIX?}Tier${tier?} " - TIER_CMD=--filter-by-tier=${tier:?} - TIER_FILE_SUFFIX=-tier${tier:?} - fi - - TITLE="(${testname:?})" - WIDTH=9 - HEIGHT=7 - if [ "${report:?}" == "latest" ]; then - TITLE="${WORKSPACE:?} (${testname:?}) ${GIT_TAGS:?}" - if [[ -z ${tier?} ]]; then - WIDTH=15 - HEIGHT=8 - fi - fi - BACKENDS_CMD= - if [ "${testname:?}" = 'inference' ]; then - BACKENDS_CMD="${INFERENCE_BACKENDS_CMD:?}" - else - BACKENDS_CMD="${TRAINING_BACKENDS_CMD:?}" - fi - python ${BM_DIR:?}/aggregate.py --accelerator=${ACCELERATOR:?} \ - --report=${report:?} --test=${testname:?} --format=${format:?} \ - --title="${TITLE_PREFIX?}${TITLE:?}" \ - --fig-height=${HEIGHT:?} --fig-width=${WIDTH:?} \ - ${TIER_CMD?} \ - ${BACKENDS_CMD:?} -- \ - ${INPUT_JSONL_FILES:?} \ - > ${REPORTS_DIR:?}/${ACCELERATOR:?}-${testname:?}-${report:?}${TIER_FILE_SUFFIX?}.${format:?} - done - done - done -done - -# Generate Llama2 output. -for testname in inference; do - for report in latest_grouped; do - for format in csv svg tab; do - BACKENDS_CMD= - if [ "${testname:?}" = 'inference' ]; then - BACKENDS_CMD="${INFERENCE_BACKENDS_CMD:?}" - else - BACKENDS_CMD="${TRAINING_BACKENDS_CMD:?}" - fi - python ${BM_DIR:?}/aggregate.py --accelerator=${ACCELERATOR:?} \ - --report=${report:?} --test=${testname:?} --format=${format:?} \ - --title="${COMMON_TITLE_PREFIX?}Llama2 (${testname:?})" \ - --filter='^llama2\.' \ - ${BACKENDS_CMD:?} -- \ - ${INPUT_JSONL_FILES:?} \ - > ${REPORTS_DIR:?}/${ACCELERATOR:?}-${testname:?}-${report:?}-llama2.${format:?} - done - done -done diff --git a/benchmarks/result_analyzer.py b/benchmarks/result_analyzer.py index 69f6b323206d..3da67fb7067b 100644 --- a/benchmarks/result_analyzer.py +++ b/benchmarks/result_analyzer.py @@ -57,7 +57,6 @@ def run_csv(self): "xla_flags": pd.Series(dtype="str"), "dynamo": pd.Series(dtype="str"), "torch_xla2": pd.Series(dtype="str"), - "keep_model_data_on_cuda": pd.Series(dtype="bool"), "test": pd.Series(dtype="str"), "batch_size": pd.Series(dtype="int"), "repeat": pd.Series(dtype="int"), @@ -122,10 +121,6 @@ def extract_metrics_jsonl(self, file: str): dynamo_value = "None" if dynamo is None else dynamo torch_xla2 = dataline["experiment"]["torch_xla2"] torch_xla2_value = "None" if torch_xla2 is None else torch_xla2 - keep_model_data_on_cuda = dataline["experiment"][ - "keep_model_data_on_cuda"] - keep_model_data_on_cuda_value = "None" if keep_model_data_on_cuda is None else str( - keep_model_data_on_cuda) test = dataline["experiment"]["test"] test_value = "None" if test is None else test outputs_file = dataline["experiment"].get("outputs_file", None) @@ -146,7 +141,6 @@ def extract_metrics_jsonl(self, file: str): "xla": xla_value, "dynamo": dynamo_value, "torch_xla2": torch_xla2_value, - "keep_model_data_on_cuda": keep_model_data_on_cuda_value, "test": test_value, "outputs_file": outputs_file_value } @@ -180,38 +174,21 @@ def extract_metrics_csv(self, file: str, metric_df: Optional[pd.DataFrame]): timestamp = dataline[ "timestamp"] if "timestamp" in dataline else self.timestamp d = { - "timestamp": - timestamp, - "suite_name": - dataline["model"]["suite_name"], - "model_name": - dataline["model"]["model_name"], - "accelerator": - dataline["experiment"]["accelerator"], - "accelerator_model": - dataline["experiment"]["accelerator_model"], - "xla": - dataline["experiment"]["xla"], - "xla_flags": - dataline["experiment"]["xla_flags"], - "dynamo": - dataline["experiment"]["dynamo"], - "torch_xla2": - dataline["experiment"]["torch_xla2"], - "keep_model_data_on_cuda": - dataline["experiment"]["keep_model_data_on_cuda"], - "test": - dataline["experiment"]["test"], - "batch_size": - dataline["experiment"]["batch_size"], - "repeat": - dataline["repeat"], - "iterations_per_run": - dataline["iterations_per_run"], - "error_message": - None, - "outputs_file": - dataline["experiment"].get("outputs_file", ""), + "timestamp": timestamp, + "suite_name": dataline["model"]["suite_name"], + "model_name": dataline["model"]["model_name"], + "accelerator": dataline["experiment"]["accelerator"], + "accelerator_model": dataline["experiment"]["accelerator_model"], + "xla": dataline["experiment"]["xla"], + "xla_flags": dataline["experiment"]["xla_flags"], + "dynamo": dataline["experiment"]["dynamo"], + "torch_xla2": dataline["experiment"]["torch_xla2"], + "test": dataline["experiment"]["test"], + "batch_size": dataline["experiment"]["batch_size"], + "repeat": dataline["repeat"], + "iterations_per_run": dataline["iterations_per_run"], + "error_message": None, + "outputs_file": dataline["experiment"].get("outputs_file", ""), } if "error" in dataline["metrics"] and not self._args.hide_errors: diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh deleted file mode 100644 index 79b746c10ad4..000000000000 --- a/benchmarks/run_benchmark.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash -set -exo pipefail -CDIR="$(cd "$(dirname "$0")" ; pwd -P)" -LOGFILE=/tmp/benchmark_test.log - -# Note [Keep Going] -# -# Set the `CONTINUE_ON_ERROR` flag to `1` to make the CI tests continue on error. -# This will allow you to see all the failures on your PR, not stopping with the first -# test failure like the default behavior. -CONTINUE_ON_ERROR="${CONTINUE_ON_ERROR:-0}" -if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then - set +e -fi - -TESTGPUVM=None -TESTTPUVM=None -# NUMBER=0 - -while getopts 'G:T:' OPTION # N: -do - case $OPTION in - G) - TESTGPUVM=$OPTARG - ;; - T) - TESTTPUVM=$OPTARG - ;; - # N) - # NUMBER=$OPTARG - # ;; - esac -done -shift $(($OPTIND - 1)) - -# func for test after ssh to VM, create container and execute in container -function benchmarking_in_container { - sudo docker pull gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8 - sudo apt-get install -y apt-transport-https ca-certificates curl gnupg-agent software-properties-common - nvidia-smi - distribution=$(. /etc/os-release;echo $ID$VERSION_ID) - curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - - curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list - sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit - sudo systemctl restart docker - sudo docker run --gpus all -it -d gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8 bin/bash - sudo docker exec -it $(sudo docker ps | awk 'NR==2 { print $1 }') /bin/bash - # install torchbench - cd ~ - git clone -b xla_benchmark https://github.com/pytorch/benchmark.git - cd benchmark - # install deps - pip install --pre torchvision torchaudio -i https://download.pytorch.org/whl/nightly/cu118 - # git clone xla - cd ~ - git clone -b benchmark https://github.com/pytorch/xla.git xla - cd ~/xla/benchmarks - # dry run - python3 experiment_runner.py --suite-name=torchbench --accelerator=gpu --progress-bar --dry-run - # run bechmark - python3 experiment_runner.py --suite-name=torchbench --accelerator=gpu --progress-bar - # analyze result to csv - python3 result_analyzer.py -} - - - -if TESTGPUVM='1A100': - # ssh to 1-A100 GPUVM and test in container - gcloud compute ssh a100-manfei-1 --zone us-central1-c --project tpu-prod-env-one-vm -- -o ProxyCommand='corp-ssh-helper %h %p' --command=benchmarking_in_container -elif TESTGPUVM='8A100': - # SSH TO 8-A100 GPUVM and test in container - gcloud compute ssh manfei-a100-8-new --zone us-central1-c --project tpu-prod-env-one-vm -- -o ProxyCommand='corp-ssh-helper %h %p' --command=benchmarking_in_container -elif TESTGPUVM='4H100': - # ssh to 4-H100 GPUVM and test in container -elif TESTTPUVM='v5e8': - # ssh to v5e-8 TPUVM and test in container -elif TESTTPUVM='v5p8': - # ssh to v5p-8 TPUVM and test in container diff --git a/benchmarks/run_single_graph_bm.sh b/benchmarks/run_single_graph_bm.sh deleted file mode 100755 index 98e10a06d05b..000000000000 --- a/benchmarks/run_single_graph_bm.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -DATE=$(date +"%Y_%m_%d_%H_%M") - -OUT_PATH=xla/benchmarks/bm_results/single_graph/$DATE -mkdir -p $OUT_PATH - -python new_xla/benchmarks/experiment_runner.py \ - --dynamo=inductor --dynamo=openxla \ - --xla=None --xla=PJRT \ - --test=eval \ - --filter-by-single-graph \ - --pure-wall-time \ - --suite-name=torchbench \ - --accelerator=cuda \ - --output-dirname=$OUT_PATH \ - --repeat=5 \ - --print-subprocess \ - --no-resume \ - > $OUT_PATH/stdout.txt 2> $OUT_PATH/stderr.txt - -python3 xla/benchmarks/result_analyzer.py \ - --output-dirname=$OUT_PATH \ - --database=$OUT_PATH/$DATE.csv diff --git a/benchmarks/run_top_tier_bm.sh b/benchmarks/run_top_tier_bm.sh deleted file mode 100755 index 9b8e8eb8eb69..000000000000 --- a/benchmarks/run_top_tier_bm.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -DATE=$(date +"%Y_%m_%d_%H_%M") - -OUT_PATH=xla/benchmarks/bm_results/$DATE -mkdir -p $OUT_PATH - -python xla/benchmarks/experiment_runner.py \ - --dynamo=inductor --dynamo=openxla \ - --xla=None --xla=PJRT \ - --test=eval --test=train \ - --filter-by-tier=1 --filter-by-tier=2 --filter-by-tier=3 \ - --suite-name=torchbench \ - --accelerator=cuda \ - --output-dirname=$OUT_PATH \ - --repeat=5 \ - --print-subprocess \ - --no-resume \ - > $OUT_PATH/stdout.txt 2> $OUT_PATH/stderr.txt - -python3 xla/benchmarks/result_analyzer.py \ - --output-dirname=$OUT_PATH \ - --database=$OUT_PATH/$DATE.csv diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 55b7f5552762..75a64fa86fd4 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -273,13 +273,10 @@ def set_up(self): # Move the initialized model to XLA device if it's not there already. if self.benchmark_experiment.xla and not self.should_initialize_on_xla(): - # First, move the model and the inputs to CPU. - # This avoids having dupplicated data on CUDA. - keep_model_data_on_cuda = self.benchmark_experiment.keep_model_data_on_cuda - if self.is_accelerator_cuda() and not keep_model_data_on_cuda: - self.module = self.module.to("cpu") - self.example_inputs = move_to_device(self.example_inputs, "cpu") - cleanup(self.is_accelerator_cuda()) + assert not self.is_accelerator_cuda() + self.module = self.module.to("cpu") + self.example_inputs = move_to_device(self.example_inputs, "cpu") + cleanup() # Torchbench has quite different setup for yolov3, so directly passing # the right example_inputs diff --git a/benchmarks/util.py b/benchmarks/util.py index bdd965a46a90..3c13232af2f6 100644 --- a/benchmarks/util.py +++ b/benchmarks/util.py @@ -51,18 +51,9 @@ def deterministic_torch_manual_seed(*args, **kwargs): @functools.lru_cache(maxsize=3) def is_xla_device_available(devkind, use_xla2: bool = False): - if devkind not in ["CPU", "CUDA", "TPU"]: + if devkind not in ["CPU", "TPU"]: raise ValueError(devkind) # Checking the availability of a given device kind. - # - # We intentionally use subprocess instead of multiprocessing library. The - # reason being that we might initialize CUDA in the parent process and use - # CUDA in the child process. This is a known limitation of using CUDA and - # forking the process. - # - # In this case, subprocess works because it replaces the forked memory with - # the execution of the new program (fresh memory), avoiding the error. - # # For more information: https://github.com/pytorch/xla/pull/5960 CHECK_XLA_DEVICE_PY = "check_xla_device.py" python_file = os.path.join(os.path.dirname(__file__), CHECK_XLA_DEVICE_PY) @@ -80,7 +71,7 @@ def move_to_device(item, device, torch_xla2: bool = False): def move_to_device_func(tensor: torch.Tensor) -> torch.Tensor: # If `tensor` is an XLA tensor, first move it to CPU. We need to do - # that if we want to move the tensor to, say, CUDA. + # that if we want to move the tensor to TPU. if tensor.device.type == "xla": return tensor.cpu().to(device) return tensor.to(device) diff --git a/benchmarks/verifier.py b/benchmarks/verifier.py index d2e940711ddd..4fefc509cc5e 100644 --- a/benchmarks/verifier.py +++ b/benchmarks/verifier.py @@ -152,7 +152,7 @@ def maybe_synchronize(): # Delete the model for saving up memory. del model # Clean-up CUDA as well. - cleanup(cuda=True) + cleanup(cuda=experiment_config["accelerator"] == "cuda") def _apply_eager_config(experiment): diff --git a/test/benchmarks/test_benchmark_experiment.py b/test/benchmarks/test_benchmark_experiment.py index 2c5efcd05832..841beb519e09 100644 --- a/test/benchmarks/test_benchmark_experiment.py +++ b/test/benchmarks/test_benchmark_experiment.py @@ -7,16 +7,15 @@ class BenchmarkExperimentTest(unittest.TestCase): def test_to_dict(self): be = BenchmarkExperiment("cpu", "PJRT", "some xla_flags", "openxla", None, - False, "train", "123", False) + "train", "123", False) actual = be.to_dict() - self.assertEqual(10, len(actual)) + self.assertEqual(9, len(actual)) self.assertEqual("cpu", actual["accelerator"]) self.assertTrue("accelerator_model" in actual) self.assertEqual("PJRT", actual["xla"]) self.assertEqual("some xla_flags", actual["xla_flags"]) self.assertEqual("openxla", actual["dynamo"]) self.assertEqual(None, actual["torch_xla2"]) - self.assertEqual(False, actual["keep_model_data_on_cuda"]) self.assertEqual("train", actual["test"]) self.assertEqual("123", actual["batch_size"]) self.assertEqual(False, actual["enable_functionalization"]) diff --git a/test/benchmarks/test_experiment_runner.py b/test/benchmarks/test_experiment_runner.py index 4ce4167d0e46..e1c572e402f2 100644 --- a/test/benchmarks/test_experiment_runner.py +++ b/test/benchmarks/test_experiment_runner.py @@ -29,44 +29,15 @@ def test_dummy_dry_run(self): expected_in_stderr = [ "Number of selected experiment configs: 4", "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", ] for expected in expected_in_stderr: self.assertIn(expected, child.stderr) - @absltest.skipUnless(xr.device_type() in {'CUDA'}, 'Needs CUDA accelerator') - def test_dummy_dry_run_cuda(self): - child = subprocess.run([ - "python", - EXPERIMENT_RUNNER_PY, - "--dynamo=openxla", - "--dynamo=inductor", - "--xla=PJRT", - "--xla=None", - "--test=eval", - "--test=train", - "--suite-name=dummy", - "--accelerator=cuda", - "--dry-run", - ], - capture_output=True, - text=True) - expected_in_stderr = [ - "Number of selected experiment configs: 4", - "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - ] - for expected in expected_in_stderr: - self.assertIn(expected, child.stderr) - - @absltest.skipUnless(xr.device_type() in {'CUDA'}, 'Needs CUDA accelerator') - def test_dummy_dry_run_inductor_cuda(self): + def test_dummy_dry_run_inductor_cpu(self): child = subprocess.run([ "python", EXPERIMENT_RUNNER_PY, @@ -85,14 +56,13 @@ def test_dummy_dry_run_inductor_cuda(self): expected_in_stderr = [ "Number of selected experiment configs: 2", "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", ] for expected in expected_in_stderr: self.assertIn(expected, child.stderr) - @absltest.skipUnless(xr.device_type() in {'CUDA'}, 'Needs CUDA accelerator') - def test_dummy_openxla_train_cuda(self): + def test_dummy_openxla_train_cpu(self): child = subprocess.run([ "python", EXPERIMENT_RUNNER_PY, @@ -103,7 +73,7 @@ def test_dummy_openxla_train_cuda(self): "--test=eval", "--test=train", "--suite-name=dummy", - "--accelerator=cuda", + "--accelerator=cpu", "--filter=^dummy$", "--dry-run", ], @@ -112,21 +82,20 @@ def test_dummy_openxla_train_cuda(self): expected_in_stderr = [ "Number of selected experiment configs: 4", "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", ] for expected in expected_in_stderr: self.assertIn(expected, child.stderr) - @absltest.skipUnless(xr.device_type() in {'CUDA'}, 'Needs CUDA accelerator') - def test_dummy_dynamo_none_cuda(self): + def test_dummy_dynamo_none_cpu(self): child = subprocess.run([ "python", EXPERIMENT_RUNNER_PY, "--suite-name=dummy", - "--accelerator=cuda", + "--accelerator=cpu", "--xla=PJRT", "--xla=None", "--filter=^dummy$", @@ -137,39 +106,14 @@ def test_dummy_dynamo_none_cuda(self): expected_in_stderr = [ "Number of selected experiment configs: 8", "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": false}", - ] - for expected in expected_in_stderr: - self.assertIn(expected, child.stderr) - - @absltest.skipUnless(xr.device_type() in {'CUDA'}, 'Needs CUDA accelerator') - def test_dummy_dry_run_cuda_with_keep_model_data_on_cuda(self): - child = subprocess.run([ - "python", - EXPERIMENT_RUNNER_PY, - "--dynamo=openxla", - "--xla=PJRT", - "--test=eval", - "--test=train", - "--suite-name=dummy", - "--accelerator=cuda", - "--keep-model-data-on-cuda", - "--dry-run", - ], - capture_output=True, - text=True) - expected_in_stderr = [ - "Number of selected experiment configs: 2", - "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": true, \"enable_functionalization\": false}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cuda\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": true, \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": null, \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": false}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": null, \"xla_flags\": null, \"dynamo\": \"inductor\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": false}", ] for expected in expected_in_stderr: self.assertIn(expected, child.stderr) @@ -192,8 +136,8 @@ def test_dummy_dry_run_with_functionalization(self): expected_in_stderr = [ "Number of selected experiment configs: 2", "Number of selected model configs: 1", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": true}", - "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"keep_model_data_on_cuda\": false, \"enable_functionalization\": true}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"eval\", \"enable_functionalization\": true}", + "--model-config={\"model_name\": \"dummy\"} --experiment-config={\"accelerator\": \"cpu\", \"xla\": \"PJRT\", \"xla_flags\": null, \"dynamo\": \"openxla\", \"torch_xla2\": null, \"test\": \"train\", \"enable_functionalization\": true}", ] for expected in expected_in_stderr: self.assertIn(expected, child.stderr) From c48478aaa4e7b39d8862741d66280be9510a6c5d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Sep 2025 13:51:10 -0300 Subject: [PATCH 089/133] Remove CUDA tests from distributed tests. (#9612) This PR removes CUDA specific logic and tests from distributed tests. Including both multiprocessing and SPMD tests. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - Removed `skipIf` test decorations whenever the condition is checking CUDA - Removed `CUDA` from the list of allowed devices for a few of these tests --- test/eager/test_eager_all_reduce_in_place.py | 2 +- test/pjrt/test_ddp.py | 2 -- test/spmd/test_spmd_debugging.py | 30 ++++++++----------- test/spmd/test_train_spmd_linear_model.py | 3 -- .../test_xla_spmd_python_api_interaction.py | 14 +-------- test/test_assume_pure_spmd.py | 20 ------------- test/test_fsdp_auto_wrap.py | 9 ++---- test/test_mp_all_gather.py | 4 +-- test/test_mp_distributed_mm.py | 5 ++-- test/test_mp_early_exit.py | 4 +-- test/test_mp_reduce_scatter.py | 2 +- ...st_torch_distributed_fsdp_frozen_weight.py | 6 ++-- test/torch_distributed/test_ddp.py | 5 +--- ...orch_distributed_all_gather_xla_backend.py | 5 ++-- ...orch_distributed_all_reduce_xla_backend.py | 5 ++-- ...ributed_bucketed_all_reduce_xla_backend.py | 5 ++-- ...istributed_multi_all_reduce_xla_backend.py | 5 ++-- ..._distributed_reduce_scatter_xla_backend.py | 5 ++-- torch_xla/test/test_utils.py | 5 ---- 19 files changed, 36 insertions(+), 100 deletions(-) diff --git a/test/eager/test_eager_all_reduce_in_place.py b/test/eager/test_eager_all_reduce_in_place.py index 7ea68b7fb6e4..f6a9103aa41b 100644 --- a/test/eager/test_eager_all_reduce_in_place.py +++ b/test/eager/test_eager_all_reduce_in_place.py @@ -12,7 +12,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) not in ('TPU', 'NEURON'): return ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device) diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index d93bbe45c4d9..1c7d1bff7acd 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -33,8 +33,6 @@ def _ddp_init(index: int = ...): def test_ddp_init(self): pjrt.run_multiprocess(self._ddp_init) - @absltest.skipIf(xr.device_type() == 'CUDA', - "GPU device is not supported by pjrt.spawn_threads") def test_ddp_init_threaded(self): pjrt.spawn_threads(self._ddp_init) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 34221d375e9c..2f126f00955e 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -28,9 +28,8 @@ def setUpClass(cls): xr.use_spmd() super().setUpClass() - @unittest.skipIf( - xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`.") def test_debugging_spmd_single_host_tiled_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}' @@ -108,9 +107,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf( - xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`.") def test_single_host_partial_replication_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}' @@ -168,9 +166,8 @@ def test_single_host_partial_replication_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf( - xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`.") def test_single_host_replicated_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{replicated}' @@ -340,9 +337,8 @@ def test_single_host_replicated_cpu(self): # e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate} # e.g.: sharding={replicated} - @unittest.skipIf( - xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`.") def test_debugging_spmd_multi_host_tiled_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}' @@ -468,9 +464,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf( - xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`.") def test_multi_host_partial_replication_tpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}' @@ -560,9 +555,8 @@ def test_multi_host_partial_replication_tpu(self): fake_output = fake_capture.get() assert output == fake_output - @unittest.skipIf( - xr.device_type() == 'CPU', - f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.") + @unittest.skipIf(xr.device_type() == 'CPU', + f"Requires PJRT_DEVICE set to `TPU`.") @unittest.skipIf(xr.global_runtime_device_count() != 8, f"Limit test num_devices to 8 for function consistency") def test_multi_host_replicated_tpu(self): diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index f1f40e061fc7..45c8a8ea3907 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -20,9 +20,6 @@ # the gradient checkpointing A/B test run for it. SKIP_GRADIENT_CHECKPOINTING: bool = False -skipOnGpu = unittest.skipIf(xr.device_type() == 'CUDA', - 'https://github.com/pytorch/xla/issues/9128') - @contextmanager def extended_argv(args): diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index ba051964a108..06176df08f74 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -98,17 +98,6 @@ def test_global_runtime_device_count(self): self.assertGreaterEqual(xr.global_runtime_device_count(), 4) elif device_type == "CPU": self.assertEqual(xr.global_runtime_device_count(), 1) - elif device_type == 'CUDA': - command = 'nvidia-smi --list-gpus | wc -l' - result = subprocess.run( - command, - capture_output=True, - shell=True, - check=True, - text=True, - ) - expected_gpu_cnt = int(result.stdout) - self.assertEqual(xr.global_runtime_device_count(), expected_gpu_cnt) def test_addressable_runtime_device_count(self): device_type = os.environ['PJRT_DEVICE'] @@ -145,8 +134,7 @@ class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest): def setUpClass(cls): super().setUpClass() - @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'], - f"TPU/GPU autocast test.") + @unittest.skipIf(xr.device_type() not in ('TPU',), f"TPU autocast test.") def test_xla_autocast_api(self): device = torch_xla.device() t1 = torch.ones([2, 3], device=device, dtype=torch.float32) diff --git a/test/test_assume_pure_spmd.py b/test/test_assume_pure_spmd.py index f6320d755be1..dd7a918f5e55 100644 --- a/test/test_assume_pure_spmd.py +++ b/test/test_assume_pure_spmd.py @@ -37,10 +37,6 @@ def setUp(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required") - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) def test_assume_pure_works_with_mark_sharding(self): x = torch.randn((8, 4, 5, 128), device='xla') result = assume_pure(mark_sharding)(x, self.spmd_mesh, @@ -52,10 +48,6 @@ def test_assume_pure_works_with_mark_sharding(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required") - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) def test_assume_pure_works_with_mark_sharding_with_gradients(self): x = torch.randn((8, 4, 5, 128)).to('xla').requires_grad_(True) result = assume_pure(mark_sharding_with_gradients)( @@ -71,10 +63,6 @@ def test_assume_pure_works_with_mark_sharding_with_gradients(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required") - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) def test_assume_pure_works_with_mark_sharding_nested(self): mesh = get_2d_mesh("model", "batch") set_global_mesh(mesh) @@ -88,10 +76,6 @@ def test_assume_pure_works_with_mark_sharding_nested(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required") - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self): mesh = get_2d_mesh("model", "batch") set_global_mesh(mesh) @@ -109,10 +93,6 @@ def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required") - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) def test_convert_to_jax_mesh(self): jax_mesh = self.spmd_mesh.get_jax_mesh() self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape) diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py index 019612899697..2ab3b3de132b 100644 --- a/test/test_fsdp_auto_wrap.py +++ b/test/test_fsdp_auto_wrap.py @@ -30,10 +30,6 @@ def forward(self, x): hidden2 = self.fc2(x) return hidden1, hidden2 - @unittest.skipIf( - xr.device_type() == 'CUDA', - "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" - ) def test(self): dev = torch_xla.device() input = torch.zeros([16, 16], device=dev) @@ -49,13 +45,12 @@ def test(self): def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU',): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) else: print( - 'Default device {} is not a TPU or CUDA device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 93d64f47ef3e..8f1634c40afe 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -14,7 +14,7 @@ def _mp_fn(index): device = torch_xla.device() world_size = xr.world_size() input_list_size = 5 - if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): # Testing with a single replica group ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device) result = xm.all_gather(ordinal_tensor, dim=0) @@ -161,7 +161,7 @@ def _mp_fn(index): # TODO: add test for torch.compile when support for list input is ready else: - print(f'{device} is not a TPU or GPU device', file=sys.stderr) + print(f'{device} is not a TPU device', file=sys.stderr) if __name__ == '__main__': diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py index 7d6c7982cb2f..e660b6159d0a 100644 --- a/test/test_mp_distributed_mm.py +++ b/test/test_mp_distributed_mm.py @@ -9,7 +9,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU',): world_size = xr.world_size() torch_xla._XLAC._xla_set_mat_mul_precision('highest') torch.manual_seed(11) @@ -34,8 +34,7 @@ def _mp_fn(index): sys.exit(1) else: print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py index 89e46722e232..d5c2987c7d9a 100644 --- a/test/test_mp_early_exit.py +++ b/test/test_mp_early_exit.py @@ -13,7 +13,7 @@ def _mp_fn(): dist.init_process_group('xla', init_method='xla://') device = torch_xla.device() - if xm.xla_device_hw(device) in ['TPU', 'CUDA']: + if xm.xla_device_hw(device) in ('TPU',): train_loader = xu.SampleGenerator( data=torch.zeros(1, 12), sample_count=1024) train_loader = pl.MpDeviceLoader(train_loader, device) @@ -23,7 +23,7 @@ def _mp_fn(): if step > max_steps: break else: - print(f'{device} is not a TPU or GPU device', file=sys.stderr) + print(f'{device} is not a TPU device', file=sys.stderr) if __name__ == '__main__': diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 12fc7fdfe1c8..e9701ba89b05 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -13,7 +13,7 @@ def _mp_fn(index): shard_size = 2 input_list_size = 5 - if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'CPU']: + if xm.xla_device_hw(device) in ['TPU', 'CPU']: rand = torch.rand((32, shard_size * world_size, 32)) xrand = rand.to(device) diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index 98730dbf7009..51cce201727c 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -8,10 +8,8 @@ def _mp_fn(index): dev = torch_xla.device() - if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): - print( - 'Default device {} is not a TPU or CUDA device'.format(dev), - file=sys.stderr) + if xm.xla_device_hw(dev) not in ('TPU',): + print('Default device {} is not a TPU device'.format(dev), file=sys.stderr) return model = nn.Linear(1024, 1024) diff --git a/test/torch_distributed/test_ddp.py b/test/torch_distributed/test_ddp.py index 1d91f520d5aa..b8d482a53e2c 100644 --- a/test/torch_distributed/test_ddp.py +++ b/test/torch_distributed/test_ddp.py @@ -3,7 +3,6 @@ import sys import torch_xla import torch_xla.core.xla_model as xm -from torch_xla.test.test_utils import skipIfCUDA # Setup import folders. xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) @@ -25,7 +24,7 @@ def _ddp_correctness(rank, # We cannot run this guard before XMP, # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing. device = torch_xla.device() - if xm.xla_device_hw(device) not in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) not in ('TPU',): print( 'Default device {} is not a TPU device'.format(device), file=sys.stderr) @@ -39,8 +38,6 @@ def _ddp_correctness(rank, def test_ddp_correctness(self): torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug)) - # Ref: https://github.com/pytorch/xla/pull/8593 - @skipIfCUDA("GPU CI is failing") def test_ddp_correctness_with_gradient_as_bucket_view(self): torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug, True)) diff --git a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py index 7c30b211ad49..5201fd347ba6 100644 --- a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() @@ -30,8 +30,7 @@ def _mp_fn(index): assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}' else: print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py index 2fd71d2ed84e..434a4ba19e10 100644 --- a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): world_size = xr.world_size() dist.init_process_group('xla', init_method='xla://') # note that we can't use torch.tensor(torch.distributed.get_rank()) directly @@ -25,8 +25,7 @@ def _mp_fn(index): xla_rank_tensor.cpu() == expected), f'{xla_rank_tensor} != {expected}' else: print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py index c462f7552800..ba0da2efdc2e 100644 --- a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() @@ -35,8 +35,7 @@ def _mp_fn(index): scale)) == torch.tensor(True) else: print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py index 9089f9d799ff..7737a71591c5 100644 --- a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) in ('TPU', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() @@ -31,8 +31,7 @@ def _mp_fn(index): xinputs.cpu() == expected), f'trial {i}, {xinputs} != {expected}' else: print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py index 006d3fd33a95..ba07a8e77f3d 100644 --- a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py @@ -10,7 +10,7 @@ def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU',): world_size = xr.world_size() rank = xr.global_ordinal() @@ -27,8 +27,7 @@ def _mp_fn(index): assert torch.all(xoutput.cpu() == expected), f'{xoutput} != {expected}' else: print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': diff --git a/torch_xla/test/test_utils.py b/torch_xla/test/test_utils.py index 6e9f779c5f0a..92b5dcb111f2 100644 --- a/torch_xla/test/test_utils.py +++ b/torch_xla/test/test_utils.py @@ -11,11 +11,6 @@ import torch_xla.utils.utils as xu -def skipIfCUDA(reason): - accelerator = xr.device_type() or "" - return lambda f: unittest.skipIf(accelerator.lower() == "cuda", reason)(f) - - def mp_test(func): """Wraps a `unittest.TestCase` function running it within an isolated process. From e0de097f00643661b9dc6cfc99c1b5e86044eab1 Mon Sep 17 00:00:00 2001 From: wirthual Date: Thu, 4 Sep 2025 00:15:29 +0200 Subject: [PATCH 090/133] Make torch_xla package PEP 561 compliant (#9515) The repo seems to be using type hints very thoroghly, this PR adds changes to make the package PEP 561 compliant: https://peps.python.org/pep-0561/ This avoids errors like the following when using mypy: ```bash infinity_emb/inference/loading_strategy.py:13: error: Cannot find implementation or library stub for module named "torch_xla" [import-not-found] ``` Torch package does the [same](https://github.com/pytorch/pytorch/blob/8e07c9870d07c5a318ab21bb16b3fa27576851e6/setup.py#L1288) --- setup.py | 5 ++++- torch_xla/py.typed | 0 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 torch_xla/py.typed diff --git a/setup.py b/setup.py index 39bc9129e6a3..33642f9a3f6e 100644 --- a/setup.py +++ b/setup.py @@ -466,7 +466,10 @@ def _get_jax_install_requirements(): 'importlib_metadata>=4.6;python_version<"3.10"', ], package_data={ - 'torch_xla': ['lib/*.so*',], + 'torch_xla': [ + 'lib/*.so*', + 'py.typed', + ], }, entry_points={ 'console_scripts': [ diff --git a/torch_xla/py.typed b/torch_xla/py.typed new file mode 100644 index 000000000000..e69de29bb2d1 From 342de86f759586be7b48b455e1bb4d6e65d619cc Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Sep 2025 23:37:07 -0300 Subject: [PATCH 091/133] Remove other CUDA usage from PyTorch/XLA repository. (#9618) This PR removes CUDA specific logic from the remaining files in this repository. This is in line with the CUDA deprecation that started on release 2.8. **Key Changes:** - Removed CUDA branches for testing (e.g. `.circleci/common.sh`) - Removed files (e.g. documentation) for CUDA related matters (e.g. `docs/source/accelerators/gpu.md`) - Removed mentions to CUDA as a supported PyTorch/XLA accelerator - Removed CUDA specific parameters from CI configuration files (e.g.`.github/workflows/_test.yml`) - Removed CUDA specific parameters from artifacts build configuration files (e.g. `infra/tpu-pytorch-releases/artifacts_builds.tf`) --- .circleci/common.sh | 25 +----- .devcontainer/gpu-internal/devcontainer.json | 30 ------- .github/ISSUE_TEMPLATE.md | 2 +- .github/ISSUE_TEMPLATE/bug-report.md | 2 +- .github/ci.md | 37 +++----- .github/scripts/run_tests.sh | 5 -- .github/workflows/_test.yml | 53 ++++------- .github/workflows/setup/action.yml | 33 ------- CONTRIBUTING.md | 4 - README.md | 88 +++---------------- ...ributed-pytorch-xla-basics-with-pjrt.ipynb | 2 +- docs/source/accelerators/gpu.md | 6 -- docs/source/contribute/bazel.md | 6 +- docs/source/contribute/plugins.md | 3 +- docs/source/learn/_pjrt.md | 8 +- docs/source/perf/amp.md | 4 +- docs/source/perf/spmd_advanced.md | 4 +- docs/source/perf/spmd_gpu.md | 48 ---------- examples/train_resnet_amp.py | 3 +- infra/ansible/README.md | 4 +- infra/ansible/config/apt.yaml | 17 ---- infra/ansible/config/cuda_deps.yaml | 24 ----- infra/ansible/config/vars.yaml | 7 +- infra/ansible/playbook.yaml | 6 +- infra/tpu-pytorch-releases/README.md | 25 +++--- .../tpu-pytorch-releases/artifacts_builds.tf | 27 ------ .../dev_images.auto.tfvars | 12 --- infra/tpu-pytorch-releases/dev_images.tf | 6 +- 28 files changed, 74 insertions(+), 417 deletions(-) delete mode 100644 .devcontainer/gpu-internal/devcontainer.json delete mode 100644 docs/source/accelerators/gpu.md delete mode 100644 docs/source/perf/spmd_gpu.md delete mode 100644 infra/ansible/config/cuda_deps.yaml diff --git a/.circleci/common.sh b/.circleci/common.sh index 3093a8006942..50ec8eae1ade 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -158,26 +158,12 @@ function run_torch_xla_cpp_tests() { fi if [ "$USE_COVERAGE" != "0" ]; then - if [ -x "$(command -v nvidia-smi)" ]; then - PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L"" - cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov1.dat - PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS - cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov2.dat - lcov --add-tracefile /tmp/cov1.dat -a /tmp/cov2.dat -o /tmp/merged.dat - else - PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" - cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/merged.dat - fi + PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" + cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/merged.dat genhtml /tmp/merged.dat -o ~/htmlcov/cpp/cpp_lcov.info mv /tmp/merged.dat ~/htmlcov/cpp_lcov.info else - # Shard GPU testing - if [ -x "$(command -v nvidia-smi)" ]; then - PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L"" - PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS - else - PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" - fi + PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L"" fi popd } @@ -196,11 +182,6 @@ function run_torch_xla_tests() { RUN_CPP="${RUN_CPP_TESTS:0}" RUN_PYTHON="${RUN_PYTHON_TESTS:0}" - if [ -x "$(command -v nvidia-smi)" ]; then - num_devices=$(nvidia-smi --list-gpus | wc -l) - echo "Found $num_devices GPU devices..." - export GPU_NUM_DEVICES=$num_devices - fi export PYTORCH_TESTING_DEVICE_ONLY_FOR="xla" export CXX_ABI=$(python -c "import torch;print(int(torch._C._GLIBCXX_USE_CXX11_ABI))") diff --git a/.devcontainer/gpu-internal/devcontainer.json b/.devcontainer/gpu-internal/devcontainer.json deleted file mode 100644 index ce06bab9e2e7..000000000000 --- a/.devcontainer/gpu-internal/devcontainer.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "name": "gpu-internal", - "image": "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1", - "runArgs": [ - "--gpus=all", - "--net=host", - "--shm-size=16G" - ], - "containerEnv": { - "BAZEL_REMOTE_CACHE": "1", - "SILO_NAME": "cache-silo-${localEnv:USER}-gpuvm" - }, - "initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.8_cuda_12.1", - "customizations": { - "vscode": { - "extensions": [ - "llvm-vs-code-extensions.vscode-clangd", - "ms-vscode.cpptools-themes", - "BazelBuild.vscode-bazel", - "DevonDCarew.bazel-code", - "StackBuild.bazel-stack-vscode", - "StackBuild.bazel-stack-vscode-cc", - "xaver.clang-format", - "ryanluker.vscode-coverage-gutters", - "ms-azuretools.vscode-docker", - "ms-python.python" - ] - } - } -} \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 6c37920bd137..b44f8dca7ad2 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -13,5 +13,5 @@ Error messages and stack traces are also helpful. ## System Info -- reproducible on XLA backend [CPU/TPU/CUDA]: +- reproducible on XLA backend [CPU/TPU]: - torch_xla version: diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md index 54f785623a50..3c10b58bfe5a 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.md +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -46,7 +46,7 @@ Steps to reproduce the behavior: ## Environment - - Reproducible on XLA backend [CPU/TPU/CUDA]: + - Reproducible on XLA backend [CPU/TPU]: - torch_xla version: diff --git a/.github/ci.md b/.github/ci.md index 2cc72b5abf50..cc3994c884e7 100644 --- a/.github/ci.md +++ b/.github/ci.md @@ -44,20 +44,20 @@ fail. Steps for fixing and merging such breaking PyTorch change is as following: ### Running TPU tests on PRs -The `build_and_test.yml` workflow runs tests on the TPU in addition to CPU and -GPU. The set of tests run on the TPU is defined in `test/tpu/run_tests.sh`. +The `build_and_test.yml` workflow runs tests on the TPU in addition to CPU. +The set of tests run on the TPU is defined in `test/tpu/run_tests.sh`. ## CI Environment Before the CI in this repository runs, we build a base dev image. These are the same images we recommend in our VSCode `.devcontainer` setup and nightly build -to ensure consistency between environments. We produce variants with and without -CUDA, configured in `infra/ansible` (build config) and -`infra/tpu-pytorch-releases/dev_images.tf` (build triggers). +to ensure consistency between environments. We produce variants configured in +`infra/ansible` (build config) and `infra/tpu-pytorch-releases/dev_images.tf` +(build triggers). The CI runs in two environments: -1. Organization self-hosted runners for CPU and GPU: used for almost every step +1. Organization self-hosted runners for CPU: used for almost every step of the CI. These runners are managed by PyTorch and have access to the shared ECR repository. 1. TPU self-hosted runners: these are managed by us and are only available in @@ -68,24 +68,18 @@ The CI runs in two environments: We have two build paths for each CI run: -- `torch_xla`: we build the main package to support both TPU and GPU[^1], along +- `torch_xla`: we build the main package to support TPU, along with a CPU build of `torch` from HEAD. This build step exports the `torch-xla-wheels` artifact for downstream use in tests. - Some CI tests also require `torchvision`. To reduce flakiness, we compile `torchvision` from [`torch`'s CI pin][pytorch-vision-pin]. - C++ tests are piggybacked onto the same build and uploaded in the `cpp-test-bin` artifact. -- `torch_xla_cuda_plugin`: the XLA CUDA runtime can be built independently of - either `torch` or `torch_xla` -- it depends only on our pinned OpenXLA. Thus, - this build should be almost entirely cached, unless your PR changes the XLA - pin or adds a patch. -Both the main package build and plugin build are configured with ansible at -`infra/ansible`, although they run in separate stages (`stage=build_srcs` vs -`stage=build_plugin`). This is the same configuration we use for our nightly and -release builds. +The main package build is configured with ansible at `infra/ansible`. This is +the same configuration we use for our nightly and release builds. -The CPU and GPU test configs are defined in the same file, `_test.yml`. Since +The CPU test config is defined in the file `_test.yml`. Since some of the tests come from the upstream PyTorch repository, we check out PyTorch at the same git rev as the `build` step (taken from `torch_xla.version.__torch_gitrev__`). The tests are split up into multiple @@ -93,23 +87,16 @@ groups that run in parallel; the `matrix` section of `_test.yml` corresponds to in `.github/scripts/run_tests.sh`. CPU tests run immediately after the `torch_xla` build completes. This will -likely be the first test feedback on your commit. GPU tests will launch when -both the `torch_xla` and `torch_xla_cuda_plugin` complete. GPU compilation is -much slower due to the number of possible optimizations, and the GPU chips -themselves are quite outdated, so these tests will take longer to run than the -CPU tests. +likely be the first test feedback on your commit. ![CPU tests launch when `torch_xla` is complete](../docs/assets/ci_test_dependency.png) -![GPU tests also depend on CUDA -plugin](../docs/assets/ci_test_dependency_gpu.png) - For the C++ test groups in either case, the test binaries are pre-built during the build phase and packaged in `cpp-test-bin`. This will only be downloaded if necessary. -[^1]: Note: both GPU and TPU support require their respective plugins to be +[^1]: Note: TPU support require its respective plugins to be installed. This package will _not_ work on either out of the box. ### TPU CI diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index 7ae422c47953..65f46f9cf48c 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -77,11 +77,6 @@ PYTORCH_DIR=$1 XLA_DIR=$2 USE_COVERAGE="${3:-0}" -if [ -x "$(command -v nvidia-smi)" ]; then - num_devices=$(nvidia-smi --list-gpus | wc -l) - echo "Found $num_devices GPU devices..." - export GPU_NUM_DEVICES=$num_devices -fi export PYTORCH_TESTING_DEVICE_ONLY_FOR="xla" export CXX_ABI=$(python -c "import torch;print(int(torch._C._GLIBCXX_USE_CXX11_ABI))") diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 4ef00dcedaed..23ffe34f8a46 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -23,11 +23,6 @@ on: description: | Set the maximum (in minutes) how long the workflow should take to finish timeout-minutes: - install-cuda-plugin: - required: false - type: boolean - default: false - description: Whether to install CUDA plugin package torch-commit: required: true type: string @@ -46,7 +41,7 @@ jobs: runs-on: ${{ inputs.runner }} container: image: ${{ inputs.dev-image }} - options: "${{ inputs.install-cuda-plugin == true && '--gpus all' || '' }} --shm-size 16g" + options: "--shm-size 16g" strategy: fail-fast: false matrix: @@ -95,9 +90,7 @@ jobs: uses: ./.actions/.github/workflows/setup with: torch-commit: ${{ inputs.torch-commit }} - cuda: ${{ inputs.install-cuda-plugin && true || false }} wheels-artifact: torch-xla-wheels - cuda-plugin-artifact: ${{ inputs.install-cuda-plugin && 'cuda-plugin' || null }} - name: Fetch CPP test binaries if: inputs.has_code_changes == 'true' && matrix.run_cpp_tests uses: actions/download-artifact@v4 @@ -111,9 +104,6 @@ jobs: run: | chmod +x /tmp/test/bin/* ls -l /tmp/test/bin - - name: Check GPU - if: inputs.has_code_changes == 'true' && inputs.install-cuda-plugin - run: nvidia-smi - name: Install test deps if: inputs.has_code_changes == 'true' shell: bash @@ -164,35 +154,24 @@ jobs: exit 0 fi docker cp "${pid}":/home/jenkins/htmlcov "${GITHUB_WORKSPACE}" - if [ -n "${GPU_FLAG:-}" ]; then - if [ -n "${PYTHON_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_python_coverage_${PYTHON_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_python_coverage_${PYTHON_TEST_NAME}.out - fi - if [ -n "${CPP_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_cpp_coverage_${CPP_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/gpu_cpp_coverage_${CPP_TEST_NAME}.out - fi - else - if [ -n "${PYTHON_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out - fi + if [ -n "${PYTHON_TEST_NAME}" ]; then + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out + fi - if [ -n "${CPP_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out - fi + if [ -n "${CPP_TEST_NAME}" ]; then + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out + fi - if [ "${CPP_TEST_NAME}" == "cpp_tests" ]; then - ABS_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "commit_id": '\"${GITHUB_SHA}\"', "ref": "HEAD", "source": "https://github.com/pytorch/xla", "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' - echo $ABS_METADATA > abs_metadata.json - gsutil cp abs_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json + if [ "${CPP_TEST_NAME}" == "cpp_tests" ]; then + ABS_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "commit_id": '\"${GITHUB_SHA}\"', "ref": "HEAD", "source": "https://github.com/pytorch/xla", "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' + echo $ABS_METADATA > abs_metadata.json + gsutil cp abs_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json - INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '${CIRCLE_BUILD_NUM}', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' - echo $INC_METADATA > inc_metadata.json - gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json - fi + INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '${CIRCLE_BUILD_NUM}', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' + echo $INC_METADATA > inc_metadata.json + gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json fi - name: Report no code changes if: inputs.has_code_changes == 'false' diff --git a/.github/workflows/setup/action.yml b/.github/workflows/setup/action.yml index 574b85e5b0d5..e1d6fdb8599d 100644 --- a/.github/workflows/setup/action.yml +++ b/.github/workflows/setup/action.yml @@ -3,20 +3,10 @@ inputs: torch-commit: type: string description: PyTorch commit to check out, if provided - cuda: - type: boolean - description: Whether to set up CUDA library paths - default: false wheels-artifact: type: string description: | Artifact containing `torch` (cpu) and `torch-xla` wheels to install - cuda-plugin-artifact: - type: string - description: Artifact containing `torch-xla-cuda-plugin` to install - cuda-torch-artifact: - type: string - description: Artifact containing CUDA build of `torch` runs: using: "composite" steps: @@ -26,12 +16,6 @@ runs: run: | ls -la rm -rvf ${GITHUB_WORKSPACE}/* - - name: Setup CUDA environment - shell: bash - run: | - echo "PATH=$PATH:/usr/local/cuda-12.3/bin" >> $GITHUB_ENV - echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3/lib64" >> $GITHUB_ENV - if: ${{ inputs.cuda }} - name: Setup gcloud shell: bash run: | @@ -59,23 +43,6 @@ runs: name: ${{ inputs.wheels-artifact }} path: /tmp/wheels/ if: ${{ inputs.wheels-artifact }} - - name: Fetch CUDA plugin - uses: actions/download-artifact@v4 - with: - name: ${{ inputs.cuda-plugin-artifact }} - path: /tmp/wheels/ - if: ${{ inputs.cuda-plugin-artifact }} - - name: Remove CPU `torch` build - shell: bash - run: | - rm -rf /tmp/wheels/torch-* - if: ${{ inputs.cuda-torch-artifact }} - - name: Fetch CUDA `torch` build - uses: actions/download-artifact@v4 - with: - name: ${{ inputs.cuda-torch-artifact }} - path: /tmp/wheels/ - if: ${{ inputs.cuda-torch-artifact }} - name: Install wheels shell: bash run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6c05fd88f747..b8d233c87002 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -238,10 +238,6 @@ first time, you may need to build everything again, for example, after a python setup.py develop ``` -### Additional steps for GPU - -Please refer to this [guide](https://github.com/pytorch/xla/blob/master/plugins/cuda/README.md). - ## Before Creating a Pull Request In `pytorch/xla` repo we enforce coding style for both C++ and Python files. diff --git a/README.md b/README.md index 989858ef16fd..d02ae1a0968e 100644 --- a/README.md +++ b/README.md @@ -95,24 +95,23 @@ batch size 1024: Our github contains many useful docs on working with different aspects of PyTorch XLA, here is a list of useful docs spread around our repository: - [docs/source/learn](https://github.com/pytorch/xla/tree/master/docs/source/learn): docs for learning concepts associated with XLA, troubleshooting, pjrt, eager mode, and dynamic shape. -- [docs/source/accelerators](https://github.com/pytorch/xla/tree/master/docs/source/accelerators): references to `GPU` and `TPU` accelerator documents. +- [docs/source/accelerators](https://github.com/pytorch/xla/tree/master/docs/source/accelerators): references to `TPU` accelerator documents. - [docs/source/perf](https://github.com/pytorch/xla/tree/master/docs/source/perf): documentation about performance specific aspects of PyTorch/XLA such as: `AMP`, `DDP`, `Dynamo`, Fori loop, `FSDP`, quantization, recompilation, and `SPMD` - [docs/source/features](https://github.com/pytorch/xla/tree/master/docs/source/features): documentation on distributed torch, pallas, scan, and stable hlo. - [docs/source/contribute](https://github.com/pytorch/xla/tree/master/docs/source/contribute): documents on setting up PyTorch for development, and guides for lowering operations. - PJRT plugins: - [CPU](https://github.com/pytorch/xla/blob/master/plugins/cpu/README.md) - - [CUDA](https://github.com/pytorch/xla/blob/master/plugins/cuda/README.md) - [torchax/docs](https://github.com/pytorch/xla/tree/master/torchax/docs): torchax documents - [torchax/examples](https://github.com/pytorch/xla/tree/master/torchax/examples): torchax examples ## Getting Started Following here are guides for two modes: -- Single process: one Python interpreter controlling a single GPU/TPU at a time -- Multi process: N Python interpreters are launched, corresponding to N GPU/TPUs +- Single process: one Python interpreter controlling a single TPU at a time +- Multi process: N Python interpreters are launched, corresponding to N TPUs found on the system -Another mode is SPMD, where one Python interpreter controls all N GPU/TPUs found on +Another mode is SPMD, where one Python interpreter controls all N TPUs found on the system. Multi processing is more complex, and is not compatible with SPMD. This tutorial does not dive into SPMD. For more on that, check our [SPMD guide](https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_basic.md). @@ -223,7 +222,7 @@ If you're using `DistributedDataParallel`, make the following changes: Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at [PyTorch.org](http://pytorch.org/xla/). See the [API Guide](API_GUIDE.md) for best practices when writing networks that run on -XLA devices (TPU, CUDA, CPU and...). +XLA devices (TPU, CPU and...). Our comprehensive user guides are available at: @@ -234,13 +233,9 @@ Our comprehensive user guides are available at: ## PyTorch/XLA tutorials -* [Cloud TPU VM - quickstart](https://cloud.google.com/tpu/docs/run-calculation-pytorch) -* [Cloud TPU Pod slice - quickstart](https://cloud.google.com/tpu/docs/pytorch-pods) -* [Profiling on TPU - VM](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) -* [GPU guide](docs/gpu.md) +* [Cloud TPU VM quickstart](https://cloud.google.com/tpu/docs/run-calculation-pytorch) +* [Cloud TPU Pod slice quickstart](https://cloud.google.com/tpu/docs/pytorch-pods) +* [Profiling on TPU VM](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) ## Reference implementations @@ -259,12 +254,10 @@ Cloud TPU plugin corresponding to your installed `torch_xla`, install the option pip install torch_xla[tpu] ``` -GPU release builds and GPU/TPU nightly builds are available in our public GCS bucket. +TPU nightly builds are available in our public GCS bucket. -| Version | Cloud GPU VM Wheels | +| Version | Cloud TPU Nightly Wheels | | --- | ----------- | -| 2.7 (CUDA 12.6 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.6/torch_xla-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.7 (CUDA 12.6 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.6/torch_xla-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl` | | nightly (Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp311-cp311-linux_x86_64.whl` | | nightly (Python 3.12) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp312-cp312-linux_x86_64.whl` | | nightly (Python 3.13) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev-cp312-cp312-linux_x86_64.whl` | @@ -296,27 +289,6 @@ The torch wheel version `2.9.0.dev20250423+cpu` can be found at https://download | 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` | | 2.1 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl` | -
- -| Version | GPU Wheel | -| --- | ----------- | -| 2.5 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` | -| 2.5 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.5 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` | -| 2.5 (CUDA 12.4 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` | -| 2.5 (CUDA 12.4 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.5 (CUDA 12.4 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` | -| 2.4 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl` | -| 2.4 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.4 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl` | -| 2.3 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| 2.3 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.3 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl` | -| 2.2 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| 2.2 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` | -| 2.1 + CUDA 11.8 | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl` | -| nightly + CUDA 12.0 >= 2023/06/27| `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` | - ### Docker @@ -337,46 +309,6 @@ To use the above dockers, please pass `--privileged --net host --shm-size=16G` a ```bash docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash ``` -
- -| Version | GPU CUDA 12.6 Docker | -| --- | ----------- | -| 2.7 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_cuda_12.6` | - - -
- - -| Version | GPU CUDA 12.4 Docker | -| --- | ----------- | -| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4` | -| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4` | - -
- - -| Version | GPU CUDA 12.1 Docker | -| --- | ----------- | -| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1` | -| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1` | -| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1` | -| 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1` | -| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1` | -| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1` | -| nightly at date | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD` | - -
- -| Version | GPU CUDA 11.8 + Docker | -| --- | ----------- | -| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8` | -| 2.0 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8` | - -
- - -To run on [compute instances with -GPUs](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus). ## Troubleshooting diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index 8d4fbd95bff7..f06e4a9b9f03 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -461,7 +461,7 @@ " torch.manual_seed(42)\n", " model = nn.Linear(128, 10).to(device)\n", "\n", - " # Optional for TPU v4 and GPU\n", + " # Optional for TPU v4\n", " xm.broadcast_master_param(model)\n", "\n", " model = DDP(model, gradient_as_bucket_view=True)\n", diff --git a/docs/source/accelerators/gpu.md b/docs/source/accelerators/gpu.md deleted file mode 100644 index 56abb192a704..000000000000 --- a/docs/source/accelerators/gpu.md +++ /dev/null @@ -1,6 +0,0 @@ -# Learn about GPUs - -For information on GPUs on Google Cloud, see: - -- [About GPUs on Google Cloud](https://cloud.google.com/compute/docs/gpus/overview) -- [GPU machine types](https://cloud.google.com/compute/docs/gpus) diff --git a/docs/source/contribute/bazel.md b/docs/source/contribute/bazel.md index 0e41ec837057..69e1d5954c82 100644 --- a/docs/source/contribute/bazel.md +++ b/docs/source/contribute/bazel.md @@ -22,9 +22,7 @@ http_archive( ], patch_tool = "patch", patches = [ - "//openxla_patches:gpu_nvml.diff", - "//openxla_patches:gpu_race_condition.diff", - "//openxla_patches:count_down.diff", + "//openxla_patches:no_fortify.diff", ], strip_prefix = "xla-" + xla_hash, urls = [ @@ -223,7 +221,7 @@ The `xla_client` tests are pure hermetic tests that can be easily executed. The `torch_xla` plugin tests are more complex: they require `torch` and `torch_xla` to be installed, and they cannot run in parallel, since they are using either XRT server/client on the same -port, or because they use a GPU or TPU device and there's only one +port, or because they use a TPU device and there's only one available at the time. For that reason, all tests under `torch_xla/csrc/` are bundled into a single target `:main` that runs them all sequentially. diff --git a/docs/source/contribute/plugins.md b/docs/source/contribute/plugins.md index 40ae841e8d7b..84ca6fe1c9ea 100644 --- a/docs/source/contribute/plugins.md +++ b/docs/source/contribute/plugins.md @@ -1,8 +1,7 @@ # Custom Hardware Plugins PyTorch/XLA supports custom hardware through OpenXLA's PJRT C API. The -PyTorch/XLA team directly supports plugins for Cloud TPU (`libtpu`) and -GPU ([OpenXLA](https://github.com/openxla/xla/tree/main/xla/pjrt/gpu)). +PyTorch/XLA team directly supports plugins for Cloud TPU (`libtpu`). The same plugins may also be used by JAX and TF. ## Implementing a PJRT Plugin diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index 16300239353a..2f4f446991de 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -38,7 +38,7 @@ the `runtime` tag. ## TL;DR - To use the PJRT preview runtime, set the `PJRT_DEVICE` environment - variable to `CPU`, `TPU`, or `CUDA` + variable to `CPU`, or `TPU` - In XRT, all distributed workloads are multiprocess, with one process per device. On TPU v2 and v3 in PJRT, workloads are multiprocess and multithreaded (4 processes with 2 threads each), so your workload @@ -57,7 +57,7 @@ the `runtime` tag. - To use `torch.distributed`, import `torch_xla.experimental.pjrt_backend` and use the `xla://` `init_method`. - - These steps are optional for GPU and TPU v4. + - These steps are optional for TPU v4. Sample diff from XRT to PJRT: @@ -84,7 +84,7 @@ def _mp_fn(index): torch.manual_seed(42) model = nn.Linear(128, 10).to(device) -+ # Optional for TPU v4 and GPU ++ # Optional for TPU v4 + xm.broadcast_master_param(model) model = DDP(model, gradient_as_bucket_view=True) @@ -119,7 +119,7 @@ if __name__ == '__main__': ## Benefits - Simple runtime configuration: just set `PJRT_DEVICE` to `TPU`, - `CPU`, or `CUDA` and start using XLA! Or, let PJRT select a device + or `CPU` and start using XLA! Or, let PJRT select a device automatically based on your environment. - Improved performance: reduced overhead from gRPC means faster end-to-end execution. On TorchBench 2.0, we observed a \>35% diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md index 36d777fd865f..223e338f2135 100644 --- a/docs/source/perf/amp.md +++ b/docs/source/perf/amp.md @@ -2,7 +2,7 @@ Pytorch/XLA's AMP extends [Pytorch's AMP package](https://pytorch.org/docs/stable/amp.html) with support for -automatic mixed precision on `XLA:GPU` and `XLA:TPU` devices. AMP is +automatic mixed precision on `XLA:TPU` devices. AMP is used to accelerate training and inference by executing certain operations in `float32` and other operations in a lower precision datatype (`float16` or `bfloat16` depending on hardware support). This @@ -99,4 +99,4 @@ unlisted ops run if they're downstream from autocasted ops. Our [mnist training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py) -demonstrate how AMP is used on both TPUs and GPUs. +demonstrate how AMP is used on TPUs. diff --git a/docs/source/perf/spmd_advanced.md b/docs/source/perf/spmd_advanced.md index 7005ee5dd4c0..2a056dc3d693 100644 --- a/docs/source/perf/spmd_advanced.md +++ b/docs/source/perf/spmd_advanced.md @@ -110,7 +110,7 @@ torch.ops.xla.dynamo_mark_sharding(output, device_ids, mesh_shape, axis_names, p ### SPMD Debugging Tool -We provide a `shard placement visualization debug tool` for PyTorch/XLA SPMD user on TPU/GPU/CPU with single-host/multi-host: you could use `visualize_tensor_sharding` to visualize sharded tensor, or you could use `visualize_sharding` to visualize sharing string. Here are two code examples on TPU single-host(v4-8) with `visualize_tensor_sharding` or `visualize_sharding`: +We provide a `shard placement visualization debug tool` for PyTorch/XLA SPMD user on TPU/CPU with single-host/multi-host: you could use `visualize_tensor_sharding` to visualize sharded tensor, or you could use `visualize_sharding` to visualize sharing string. Here are two code examples on TPU single-host(v4-8) with `visualize_tensor_sharding` or `visualize_sharding`: - Code snippet used `visualize_tensor_sharding` and visualization result: ```python @@ -141,7 +141,7 @@ generated_table = visualize_sharding(sharding, use_color=False) visualize_sharding example on TPU v4-8(single-host) -You could use these examples on TPU/GPU/CPU single-host and modify it to run on multi-host. And you could modify it to sharding-style `tiled`, `partial_replication` and `replicated`. +You could use these examples on TPU/CPU single-host and modify it to run on multi-host. And you could modify it to sharding-style `tiled`, `partial_replication` and `replicated`. ### Auto-Sharding We are introducing a new PyTorch/XLA SPMD feature, called ``auto-sharding``, [RFC](https://github.com/pytorch/xla/issues/6322). This is an experimental feature in `r2.3` and `nightly`, that supports `XLA:TPU` and a single TPUVM host. diff --git a/docs/source/perf/spmd_gpu.md b/docs/source/perf/spmd_gpu.md deleted file mode 100644 index cda25723aaad..000000000000 --- a/docs/source/perf/spmd_gpu.md +++ /dev/null @@ -1,48 +0,0 @@ -# Running SPMD on GPU - -PyTorch/XLA supports SPMD on NVIDIA GPU (single-node or multi-nodes). -The training/inference script remains the same as the one used for TPU, -such as this [ResNet -script](https://github.com/pytorch/xla/blob/1dc78948c0c9d018d8d0d2b4cce912552ab27083/test/spmd/test_train_spmd_imagenet.py). -To execute the script using SPMD, we leverage `torchrun`: - - PJRT_DEVICE=CUDA \ - torchrun \ - --nnodes=${NUM_GPU_MACHINES} \ - --node_rank=${RANK_OF_CURRENT_MACHINE} \ - --nproc_per_node=1 \ - --rdzv_endpoint=":" \ - training_or_inference_script_using_spmd.py - -- `--nnodes`: how many GPU machines to be used. -- `--node_rank`: the index of the current GPU machines. The value can - be 0, 1, ..., \${NUMBER_GPU_VM}-1. -- `--nproc_per_node`: the value must be 1 due to the SPMD requirement. -- `--rdzv_endpoint`: the endpoint of the GPU machine with - node_rank==0, in the form `host:port`. The host will be the internal - IP address. The `port` can be any available port on the machine. For - single-node training/inference, this parameter can be omitted. - -For example, if you want to train a ResNet model on 2 GPU machines using -SPMD, you can run the script below on the first machine: - - XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \ - torchrun \ - --nnodes=2 \ - --node_rank=0 \ - --nproc_per_node=1 \ - --rdzv_endpoint=":12355" \ - pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128 - -and run the following on the second machine: - - XLA_USE_SPMD=1 PJRT_DEVICE=CUDA \ - torchrun \ - --nnodes=2 \ - --node_rank=1 \ - --nproc_per_node=1 \ - --rdzv_endpoint=":12355" \ - pytorch/xla/test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 128 - -For more information, please refer to the [SPMD support on GPU -RFC](https://github.com/pytorch/xla/issues/6256). diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py index f5ca308bed75..f63c3cad8544 100644 --- a/examples/train_resnet_amp.py +++ b/examples/train_resnet_amp.py @@ -22,8 +22,7 @@ def train_loop_fn(self, loader, epoch): with autocast(torch_xla.device()): output = self.model(data) loss = self.loss_fn(output, target) - # TPU amp uses bf16 hence gradient scaling is not necessary. If runnign with XLA:GPU - # check https://github.com/pytorch/xla/blob/master/docs/amp.md#amp-for-xlagpu. + # TPU amp uses bf16 hence gradient scaling is not necessary. loss.backward() self.run_optimizer() tracker.add(self.batch_size) diff --git a/infra/ansible/README.md b/infra/ansible/README.md index 9094f645de30..9ce34d962cff 100644 --- a/infra/ansible/README.md +++ b/infra/ansible/README.md @@ -23,11 +23,11 @@ behavior (installed pip/apt packages and set environment variables): * `stage`: build or release. Different packages are installed depending on the chosen stage. * `arch`: aarch64 or amd64. Architecture of the built image and wheels. -* `accelerator`: tpu or cuda. Available accelerator. +* `accelerator`: tpu. Available accelerator. The variables can be passed through `-e` flag: `-e "="`. -Example: `ansible-playbook playbook.yaml -e "stage=build arch=amd64 accelerator=tpu"` +Example: `ansible-playbook playbook.yaml -e "stage=build arch=amd64"` ## Config structure diff --git a/infra/ansible/config/apt.yaml b/infra/ansible/config/apt.yaml index d026fea3e037..ae3d95468344 100644 --- a/infra/ansible/config/apt.yaml +++ b/infra/ansible/config/apt.yaml @@ -20,13 +20,6 @@ apt: - lcov - less - build_cuda: - - "cuda-libraries-{{ cuda_version | replace('.', '-') }}" - - "cuda-toolkit-{{ cuda_version | replace('.', '-') }}" - - "cuda-minimal-build-{{ cuda_version | replace('.', '-') }}" - - "{{ cuda_deps['libcudnn'][cuda_version] }}" - - "{{ cuda_deps['libcudnn-dev'][cuda_version] }}" - build_aarch64: - scons @@ -39,23 +32,13 @@ apt: - patch - vim - release_cuda: - - "cuda-libraries-{{ cuda_version | replace('.', '-') }}" - - "cuda-minimal-build-{{ cuda_version | replace('.', '-') }}" - - "{{ cuda_deps['libcudnn'][cuda_version] }}" - # Specify objects with string fields `url` and `keyring`. # The keyring path should start with /usr/share/keyrings/ for debian and ubuntu. signing_keys: - url: https://apt.llvm.org/llvm-snapshot.gpg.key keyring: /usr/share/keyrings/llvm.pgp - # Get the recent key version from - # https://docs.nvidia.com/cuda/cuda-installation-guide-linux/#network-repo-installation-for-debian. - - url: "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub" - keyring: /usr/share/keyrings/cuda.pgp repos: # signed-by path should match the corresponding keyring path above. - "deb [signed-by=/usr/share/keyrings/llvm.pgp] http://apt.llvm.org/{{ llvm_debian_repo }}/ llvm-toolchain-{{ llvm_debian_repo }}-{{ clang_version }} main" - "deb-src [signed-by=/usr/share/keyrings/llvm.pgp] http://apt.llvm.org/{{ llvm_debian_repo }}/ llvm-toolchain-{{ llvm_debian_repo }}-{{ clang_version }} main" - - "deb [signed-by=/usr/share/keyrings/cuda.pgp] https://developer.download.nvidia.com/compute/cuda/repos/{{ cuda_repo }}/x86_64/ /" diff --git a/infra/ansible/config/cuda_deps.yaml b/infra/ansible/config/cuda_deps.yaml deleted file mode 100644 index 3732bb0f93ec..000000000000 --- a/infra/ansible/config/cuda_deps.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# Versions of cuda dependencies for given cuda versions. -# Note: wrap version in quotes to ensure they're treated as strings. -cuda_deps: - # List all libcudnn8 versions with `apt list -a libcudnn8` - libcudnn: - "12.8": libcudnn9-cuda-12=9.1.1.17-1 - "12.6": libcudnn9-cuda-12=9.1.1.17-1 - "12.4": libcudnn9-cuda-12=9.1.1.17-1 - "12.3": libcudnn9-cuda-12=9.1.1.17-1 - "12.1": libcudnn8=8.9.2.26-1+cuda12.1 - "12.0": libcudnn8=8.8.0.121-1+cuda12.0 - "11.8": libcudnn8=8.7.0.84-1+cuda11.8 - "11.7": libcudnn8=8.5.0.96-1+cuda11.7 - "11.2": libcudnn8=8.1.1.33-1+cuda11.2 - libcudnn-dev: - "12.8": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.6": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.4": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.3": libcudnn9-dev-cuda-12=9.1.1.17-1 - "12.1": libcudnn8-dev=8.9.2.26-1+cuda12.1 - "12.0": libcudnn8-dev=8.8.0.121-1+cuda12.0 - "11.8": libcudnn8-dev=8.7.0.84-1+cuda11.8 - "11.7": libcudnn8-dev=8.5.0.96-1+cuda11.7 - "11.2": libcudnn8-dev=8.1.1.33-1+cuda11.2 diff --git a/infra/ansible/config/vars.yaml b/infra/ansible/config/vars.yaml index f34e2c3cb632..c336e7754f46 100644 --- a/infra/ansible/config/vars.yaml +++ b/infra/ansible/config/vars.yaml @@ -1,8 +1,3 @@ -# Used for fetching cuda from the right repo, see apt.yaml. -cuda_repo: debian11 -cuda_version: "11.8" -# Determines supported GPUs. See https://developer.nvidia.com/cuda-gpus -cuda_compute_capabilities: 5.2,7.0,7.5,8.0,9.0 # Used for fetching clang from the right repo, see apt.yaml. llvm_debian_repo: bullseye clang_version: 17 @@ -10,7 +5,7 @@ clang_version: 17 package_version: 2.9.0 # If set to true, wheels will be renamed to $WHEEL_NAME-nightly-cp38-cp38-linux_x86_64.whl. nightly_release: false -# Whether to preinstall libtpu in the PyTorch/XLA wheel. Ignored for GPU build. +# Whether to preinstall libtpu in the PyTorch/XLA wheel. bundle_libtpu: 1 # Suffix for bazel remote cache key cache_suffix: "" diff --git a/infra/ansible/playbook.yaml b/infra/ansible/playbook.yaml index 7626714e8d18..85153a43d3a2 100644 --- a/infra/ansible/playbook.yaml +++ b/infra/ansible/playbook.yaml @@ -6,7 +6,7 @@ # - stage: build or release. Different packages are installed depending on # the chosen stage. # - arch: aarch64 or amd64. Architecture of the built image and wheels. - # - accelerator: tpu or cuda. Available accelerator. + # - accelerator: tpu. pre_tasks: - name: "Validate required variables" ansible.builtin.assert: @@ -20,7 +20,7 @@ - name: arch pattern: ^(aarch64|amd64)$ - name: accelerator - pattern: ^(tpu|cuda)$ + pattern: ^tpu$ - name: "Include vars from config files" ansible.builtin.include_vars: @@ -28,8 +28,6 @@ loop: # vars.yaml should be the first as other config files depend on it. - vars.yaml - # cuda_deps should be loaded before apt, since apt depends on it. - - cuda_deps.yaml - apt.yaml - pip.yaml - env.yaml diff --git a/infra/tpu-pytorch-releases/README.md b/infra/tpu-pytorch-releases/README.md index f173b3ee8575..a70e0b064a6e 100644 --- a/infra/tpu-pytorch-releases/README.md +++ b/infra/tpu-pytorch-releases/README.md @@ -39,13 +39,11 @@ consists of the following fields. sources when building image and wheels. * `package_version` (string) - Version of the built wheels. Passed to the build steps. -* `accelerator` ("tpu"|"cuda") - Supported accelerator. Affects build +* `accelerator` ("tpu") - Supported accelerator. Affects build process and installed dependencies, see [apt.yaml](../ansible/config/apt.yaml) and [pip.yaml](../ansible/config/pip.yaml). * `python_version` (optional, string, default = "3.8") - Python version used for the docker image base and build process. -* `cuda_version` (optional, string, default = "11.8") - CUDA version to install. - Used only if `accelerator` is set to "cuda" * `arch` (optional, "amd64"|"aarch64", default = "amd64") - Architecture affects installed dependencies and build process, see [apt.yaml](../ansible/config/apt.yaml) and [pip.yaml](../ansible/config/pip.yaml). @@ -71,7 +69,6 @@ unset properties of existing triggers. git_tag = "v3.0.0" package_version = "3.0" accelerator = "tpu" - cuda_version = "11.8" # optional python_version = "3.8" # optional arch = "amd64" # optional }, @@ -95,12 +92,10 @@ at midnight (`America/Los_Angeles` time zone). Nightly builds in the `nightly_builds` variable in [artifacts.auto.tfvars](./artifacts.auto.tfvars) consists of the following fields. -* `accelerator` ("tpu"|"cuda") - Supported accelerator. Impacts build +* `accelerator` ("tpu") - Supported accelerator. Impacts build process and installed dependencies. * `python_version` (optional, string, default = "3.8") - Python version used for the docker images base and build process. -* `cuda_version` (optional, string, default = "11.8") - CUDA version to install. - Used only if `accelerator` is set to "cuda" * `arch` (optional, "amd64"|"aarch64", default = "amd64") - Architecture influences installed dependencies and build process. * `cxx11_abi` (optional, "0"|"1", default = "0") - Whether to use C++11 ABI or @@ -115,9 +110,8 @@ unset properties of existing triggers. #### Modify or add a new nightly release -1. Modify or add an entry with specific `accelerator`, `python_version` and (optionally) - `cuda_version` to the `nightly_builds` variable in the - [artifacts.auto.tfvars](./artifacts.auto.tfvars) file. +1. Modify or add an entry with specific `accelerator`, and `python_version` + to the `nightly_builds` variable in the [artifacts.auto.tfvars](./artifacts.auto.tfvars) file. See all variables in the section above. **Example** @@ -125,10 +119,13 @@ unset properties of existing triggers. ```hcl nightly_builds = [ { - accelerator = "cuda" - cuda_version = "11.8" # optional - python_version = "3.8" # optional - arch = "amd64" # optional + git_tag = "v2.8.0" + package_version = "2.8.0" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.10" + bundle_libtpu = "0" + cxx11_abi = "1" }, # ... ] diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 099a2402afe9..b4e469b617be 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -6,10 +6,6 @@ locals { release_package_version = "2.8.0-rc5" release_pytorch_git_rev = "v2.8.0-rc8" nightly_package_version = "2.9.0" - cuda_versions = { - "nightly": [], - "r2.8": [] # Note: PyTorch 2.8 release doesn't have CUDA support - } # Built once a day from master generated_nightly_builds = concat( @@ -22,16 +18,6 @@ locals { cxx11_abi = "1" } ], - # CUDA builds - [ - for pair in setproduct(local.tpu_python_versions, local.cuda_versions["nightly"]) : { - accelerator = "cuda" - cuda_version = pair[1] - python_version = pair[0] - bundle_libtpu = "0" - cxx11_abi = "1" - } - ] ) # Built on push to specific tag. @@ -59,19 +45,6 @@ locals { bundle_libtpu = "1" } ], - - # cuda build for latest release - [ - for pair in setproduct(local.tpu_python_versions, local.cuda_versions["r2.8"]) : { - git_tag = local.release_git_tag - package_version = local.release_package_version - pytorch_git_rev = local.release_pytorch_git_rev - accelerator = "cuda" - cuda_version = pair[1] - python_version = pair[0] - bundle_libtpu = "0" - } - ] ) versioned_builds = concat(local.generated_versioned_builds, var.manual_versioned_builds) nightly_builds = concat(local.generated_nightly_builds, var.manual_nightly_builds) diff --git a/infra/tpu-pytorch-releases/dev_images.auto.tfvars b/infra/tpu-pytorch-releases/dev_images.auto.tfvars index e1618f2a80c2..aee461990fd4 100644 --- a/infra/tpu-pytorch-releases/dev_images.auto.tfvars +++ b/infra/tpu-pytorch-releases/dev_images.auto.tfvars @@ -7,17 +7,5 @@ dev_images = [ accelerator = "tpu" extra_tags = ["tpu"] python_version = "3.12" - }, - { - accelerator = "cuda" - cuda_version = "12.1" - extra_tags = ["cuda"] - python_version = "3.10" - }, - { - accelerator = "cuda" - cuda_version = "12.3" - extra_tags = ["cuda"] - python_version = "3.10" } ] diff --git a/infra/tpu-pytorch-releases/dev_images.tf b/infra/tpu-pytorch-releases/dev_images.tf index 54c340809efb..03798c9dbefb 100644 --- a/infra/tpu-pytorch-releases/dev_images.tf +++ b/infra/tpu-pytorch-releases/dev_images.tf @@ -3,10 +3,9 @@ variable "dev_images" { accelerator = string arch = optional(string, "amd64") python_version = optional(string, "3.8") - cuda_version = optional(string, "11.8") # Additional tags on top of uniquely generated tag based on accelerator, - # python and cuda versions. + # python versions. extra_tags = optional(list(string), []) })) } @@ -16,7 +15,7 @@ locals { for di in var.dev_images : format("%s_%s", di.python_version, - di.accelerator == "tpu" ? "tpuvm" : format("cuda_%s", di.cuda_version) + "tpuvm" ) => di } } @@ -55,7 +54,6 @@ module "dev_images" { accelerator = each.value.accelerator arch = each.value.arch python_version = each.value.python_version - cuda_version = each.value.cuda_version } docker_repo_url = module.docker_registry.url From 77d85fb18d40410380f044d171cc6bfd83055d47 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 3 Sep 2025 23:37:41 -0300 Subject: [PATCH 092/133] Remove CUDA from remaining tests. (#9613) This PR removes CUDA specific logic and tests from the remaining tests (after #9612). This is in line with the CUDA deprecation that started on release 2.8. --- test/cpp/test_aten_xla_tensor_2.cpp | 26 +- test/cpp/test_aten_xla_tensor_6.cpp | 2 +- test/cpp/test_replication.cpp | 4 +- test/ds/test_dynamic_shapes.py | 2 +- test/dynamo/test_dynamo.py | 101 ++------ test/dynamo/test_traceable_collectives.py | 2 +- test/pjrt/test_runtime.py | 9 +- test/pytorch_test_base.py | 17 +- test/run_tests.sh | 1 - test/test_autocast.py | 76 ------ test/test_compilation_cache_utils.py | 2 +- test/test_operations.py | 277 +--------------------- test/test_ops.py | 7 +- test/test_persistent_cache.py | 6 +- test/test_profiler.py | 5 - test/test_python_ops.py | 1 - test/test_zero1.py | 5 +- 17 files changed, 42 insertions(+), 501 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp index 31491249f618..dc3d605da34a 100644 --- a/test/cpp/test_aten_xla_tensor_2.cpp +++ b/test/cpp/test_aten_xla_tensor_2.cpp @@ -1555,20 +1555,18 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) { /*cudnn_enabled=*/false); }; torch::Tensor undef; - ForEachDevice({XlaDeviceType::CUDA, XlaDeviceType::TPU}, - [&](const torch::Device& device) { - TestBackward({input, undef_weight ? undef : weight, - undef_weight ? undef : bias}, - device, testfn, - /*rtol=*/1e-3, /*atol=*/1e-3, - /*derivative_level=*/2); - ExpectCounterNotChanged("aten::.*", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::native_batch_norm", - cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::native_batch_norm_backward", - cpp_test::GetIgnoredCounters()); - }); + ForEachDevice({XlaDeviceType::TPU}, [&](const torch::Device& device) { + TestBackward( + {input, undef_weight ? undef : weight, undef_weight ? undef : bias}, + device, testfn, + /*rtol=*/1e-3, /*atol=*/1e-3, + /*derivative_level=*/2); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::native_batch_norm", + cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::native_batch_norm_backward", + cpp_test::GetIgnoredCounters()); + }); } } } diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp index df3e1280b6e2..b9a669760b1b 100644 --- a/test/cpp/test_aten_xla_tensor_6.cpp +++ b/test/cpp/test_aten_xla_tensor_6.cpp @@ -873,7 +873,7 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) { TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) { XlaDeviceType hw_type = static_cast(bridge::GetDefaultDevice()->type()); - if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) { + if (hw_type != XlaDeviceType::CPU) { return; } torch::Tensor growth_tracker = diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 00ec937bf8a1..73db4ab42392 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -98,9 +98,7 @@ void TestSingleReplication( class ReplicationTest : public AtenXlaTensorTestBase {}; -// Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU -// device per process instead of relying on threads so we will not run the test -// on GPU. +// Parallelism for DataParallel uses multi-threads. TEST_F(ReplicationTest, TestNSingleReplication) { WithAllDevices( {XlaDeviceType::TPU}, diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 46f329de4537..355a795bd14e 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -186,7 +186,7 @@ def test_masked_select_shape(self): def test_nonzero_cast(self): t1 = torch.ones(5, 2, device='xla') # Result of the nonzero should be the index type. Currently - # index type is s64 on cpu and gpu, but s32 on TPU. We should be + # index type is s64 on cpu, but s32 on TPU. We should be # able to cast it to any other type without error. t2 = torch.nonzero(t1.int()).float() torch_xla.sync() diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 2c05adf7716c..c106483e1a37 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -148,17 +148,6 @@ def fn_simple(self, x, y): b = torch.sin(y) return a + b - def _choose_proper_device(self, initialize_on_cuda): - if not initialize_on_cuda: - return torch_xla.device() - - assert initialize_on_cuda - if xr.device_type() != "CUDA" or not torch.cuda.is_available(): - self.skipTest( - "Skip this test because it requires xr.device_type()=='CUDA' and torch.cuda.is_available()." - ) - return "cuda:0" - @skipOnNeuron def test_simple_model(self): device = torch_xla.device() @@ -193,51 +182,7 @@ def test_simple_model(self): # Dynamo has to sync the input since they are intermedate IR(xla_xy and xla_y3) self.assertEqual(met.counter_value('DynamoSyncInputExecuteTime'), 1) - # Tests that the dynamo bridge automatically moves tensors to XLA device, - # then back to the original device. - @unittest.skipIf(xr.device_type() != "CUDA" or not torch.cuda.is_available(), - f"GPU tests should only run on GPU devices.") - @parameterized.parameters( - "0", - "1", - ) - def test_simple_model_automoves_tensors(self, zero_copy_enabled): - x = torch.tensor(100.0, requires_grad=True, device="cuda:0") - y = torch.tensor(200.0, requires_grad=True, device="cuda:0") - original_device = x.device - eager_result = self.fn_simple(x, y) - - # Since all tests run in the same process, have to reset the metrics report. - met.clear_all() - torch._dynamo.reset() - - fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla") - res_xla_dynamo = fn_simple_dynamo(x, y) - self.assertIn('xla::add', met.counter_names()) - self.assertTrue(res_xla_dynamo.device == original_device) - self.assertTrue(torch.allclose(eager_result, res_xla_dynamo)) - - # verify that tracing is skipped in following runs - met.clear_counters() - res_xla_dynamo_reused = fn_simple_dynamo(x, y) - self.assertNotIn('xla::add', met.counter_names()) - self.assertTrue(res_xla_dynamo_reused.device == original_device) - self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused)) - - # verify that dynamo can handle different inputs - res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3) - res_cpu_3 = self.fn_simple(x + y, y * 3) - self.assertTrue(res_xla_dynamo_different.device == original_device) - self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different)) - - # There should not be any fallbacks. - self.assertEqual(torch_xla._XLAC._get_executed_fallback_ops(), []) - - @parameterized.parameters( - True, - False, - ) - def test_fn_without_input(self, initialize_on_cuda): + def test_fn_without_input(self): def fn_without_input(device): constant = 0.835 @@ -245,19 +190,15 @@ def fn_without_input(device): arange = torch.arange(16, device=device).reshape(4, 4) return expanded + arange - device = self._choose_proper_device(initialize_on_cuda) + device = torch_xla.device() compiled_fn = torch.compile(fn_without_input, backend='openxla') res_cpu = fn_without_input('cpu') res_xla_dynamo = compiled_fn(device) self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu())) - @parameterized.parameters( - (True, 'openxla'), - (False, dynamo_backend2.dynamo_backend), - (False, 'openxla'), - ) - def test_simple_model_with_in_place_ops(self, initialize_on_cuda, backend): + @parameterized.parameters('openxla', dynamo_backend2.dynamo_backend) + def test_simple_model_with_in_place_ops(self, backend): class TestModel(nn.Module): @@ -279,7 +220,7 @@ def forward(self, index, copy_tensor, input_tensor, op_name): output = input_tensor + self.self_tensor return output - device = self._choose_proper_device(initialize_on_cuda) + device = torch_xla.device() torch._dynamo.reset() met.clear_all() @@ -306,18 +247,14 @@ def forward(self, index, copy_tensor, input_tensor, op_name): op_name=in_place_op) self.assertTrue(torch.allclose(res_cpu, res_device_dynamo.cpu())) - @parameterized.parameters( - (True, 'openxla'), - (False, dynamo_backend2.dynamo_backend), - (False, 'openxla'), - ) - def test_einsum(self, initialize_on_cuda, backend): + @parameterized.parameters('openxla', dynamo_backend2.dynamo_backend) + def test_einsum(self, backend): # einsum currently does not have meta function to compute the shape hence # will fallback to XLA with FakeTensor as input to infer the output shape. def einsum_mm(a, b): return torch.einsum('ijkl,ijlm->ijkm', a, b) - device = self._choose_proper_device(initialize_on_cuda) + device = torch_xla.device() a = torch.randn(4, 4, 4, 4).to(device) b = torch.randn(4, 4, 4, 4).to(device) torch_xla.sync() @@ -328,16 +265,10 @@ def einsum_mm(a, b): self.assertTrue( torch.allclose(res_device_non_dynamo.cpu(), res_device_dynamo.cpu())) - @parameterized.parameters( - True, - False, - ) - def test_simple_model_with_different_input_shape(self, initialize_on_cuda): + def test_simple_model_with_different_input_shape(self): met.clear_all() - device = self._choose_proper_device(initialize_on_cuda) - # We need to make `dim` depend on `initialize_on_cuda` because the XLA compilation cache - # does not clean itself between the parameterized tests. - dim = 5 + int(initialize_on_cuda) + device = torch_xla.device() + dim = 5 device_x = torch.randn(dim, dim).to(device) device_y = torch.randn(dim, dim).to(device) new_dim = 2 * dim @@ -369,13 +300,9 @@ def get_loader(self, device, sample_count, batch_size=4): @skipOnTpu @skipOnNeuron - @parameterized.parameters( - (True, 'openxla'), - (False, dynamo_backend2.dynamo_backend), - (False, 'openxla'), - ) - def test_resnet18(self, initialize_on_cuda, backend): - device = self._choose_proper_device(initialize_on_cuda) + @parameterized.parameters('openxla', dynamo_backend2.dynamo_backend) + def test_resnet18(self, backend): + device = torch_xla.device() sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = self.get_loader(device, sample_count, batch_size=4) resnet18 = torchvision.models.resnet18() diff --git a/test/dynamo/test_traceable_collectives.py b/test/dynamo/test_traceable_collectives.py index 45bd89266604..e9416dcb2e05 100644 --- a/test/dynamo/test_traceable_collectives.py +++ b/test/dynamo/test_traceable_collectives.py @@ -20,7 +20,7 @@ def collective_broadcast_and_cos(input, src): def _mp_fn(index): device = torch_xla.device() world_size = xr.world_size() - if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): + if xm.xla_device_hw(device) not in ('TPU', 'NEURON'): print(f'skip this test for hw {xm.xla_device_hw(device)}') return ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device) diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index 6529b5e826e1..31af6acd4e36 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -17,7 +17,7 @@ class TestExperimentalPjrt(parameterized.TestCase): def setUp(self): xr.set_device_type('CPU') - @parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU')) + @parameterized.parameters(('CPU', 'CPU'), ('TPU', 'TPU')) def test_device_type(self, pjrt_device, expected): with mock.patch.dict(os.environ, {'PJRT_DEVICE': pjrt_device}, clear=True): self.assertEqual(xr.device_type(), expected) @@ -69,11 +69,6 @@ def test_xla_device_error(self): }, True), ('pjrt_tpu_precedence', { 'PJRT_DEVICE': 'TPU', 'XRT_TPU_CONFIG': 'localservice;0;localhost:51011', - }, True), ('gpu_num_devives', { - 'GPU_NUM_DEVICES': '4' - }, True), ('pjrt_gpu', { - 'PJRT_DEVICE': 'CUDA', - 'GPU_NUM_DEVICES': '4' }, True)) def test_pjrt_default_device(self, env_vars, expect_using_pjrt): # Prevent flag checking during reinitialization of PJRT backend. @@ -86,7 +81,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt): reload(torch_xla) logs_context = contextlib.nullcontext() if expect_using_pjrt: - self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'NEURON']) + self.assertIn(xr.device_type(), ['CPU', 'TPU', 'NEURON']) else: self.assertIsNone(xr.device_type()) diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 3355f8efba99..7c426deca1e0 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -295,7 +295,7 @@ 'test_leaky_relu_inplace_with_neg_slope_xla', # expecting a specific error message 'test_upsamplingBicubic2d_correctness_xla', # FIXME! Got dtypes torch.float32 and torch.float64 'test_CTCLoss_no_batch_dim_xla', # Value out of range - 'test_upsamplingBilinear2d_xla', # precision on GPU/TPU, slow compilation on CPU + 'test_upsamplingBilinear2d_xla', # precision on TPU, slow compilation on CPU # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0 'test_GRU_grad_and_gradgrad_xla_float64', # grad check failure 'test_LSTM_grad_and_gradgrad_xla_float64', # grad check failure @@ -475,18 +475,6 @@ }, } -DISABLED_TORCH_TESTS_GPU_ONLY = { - # test_torch.py - 'TestTorchDeviceTypeXLA': { - 'test_maximum_minimum_float_nan_and_inf', # maximum(nan,inf) = inf on GPU - }, - - # test_indexing.py - 'TestIndexingXLA': { - 'test_index_put_accumulate_large_tensor_xla', # illegal memory access was encountered - }, -} - class MatchSet(object): @@ -526,15 +514,12 @@ def union_of_disabled_tests(sets): DISABLED_TORCH_TESTS_CPU = DISABLED_TORCH_TESTS_ANY -DISABLED_TORCH_TESTS_GPU = union_of_disabled_tests( - [DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_GPU_ONLY]) DISABLED_TORCH_TESTS_TPU = union_of_disabled_tests( [DISABLED_TORCH_TESTS_ANY, DISABLED_TORCH_TESTS_TPU_ONLY]) DISABLED_TORCH_TESTS = { 'TPU': prepare_match_set(DISABLED_TORCH_TESTS_TPU), 'CPU': prepare_match_set(DISABLED_TORCH_TESTS_CPU), - 'CUDA': prepare_match_set(DISABLED_TORCH_TESTS_GPU), } diff --git a/test/run_tests.sh b/test/run_tests.sh index 033089d651f5..bb03d7abe161 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -254,7 +254,6 @@ function run_xla_op_tests3 { run_test "$_TEST_DIR/test_devices.py" run_test "$_TEST_DIR/test_manual_xla_registration.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_placements.py" - # NOTE: this line below is testing export and don't care about GPU PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py" run_test "$_TEST_DIR/test_pallas.py" run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py" diff --git a/test/test_autocast.py b/test/test_autocast.py index 32f72ae9762a..c16fee30a725 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -152,82 +152,6 @@ def __init__(self, dev): self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)] -class AutocastCudaTestExtraLists(object): - - def __init__(self, dev): - super().__init__() - n = 8 - dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) - conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), - torch.randn(dimset, dtype=torch.float32, device=dev)) - for dimset in dimsets] - - mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) - mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) - mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) - mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) - - pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) - - element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) - - # This is currently not part of AutocastTestLists and excludes `relu`, `addbmm` - self.torch_bf16 = [ - ("conv1d", conv_args_fp32[0]), - ("conv2d", conv_args_fp32[1]), - ("conv3d", conv_args_fp32[2]), - ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), - torch.randn((n, n, n), device=dev, dtype=torch.float32))), - ("mm", mat0_fp32 + mat1_fp32), - ("matmul", - torch.matmul( - torch.ones([2, 3], device=dev, dtype=torch.float32), - torch.ones([3, 2], device=dev, dtype=torch.float32))), - ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), - torch.randn((n, n, n), device=dev, dtype=torch.float32), - torch.randn((n, n, n), device=dev, dtype=torch.float32))), - ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), - ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), - torch.randn((5, 3, 5), device=dev, dtype=torch.float32), - torch.randn(5, device=dev, dtype=torch.float32), 0)), - ("conv_transpose1d", conv_args_fp32[0]), - ("conv_transpose2d", conv_args_fp32[1]), - ("conv_transpose3d", conv_args_fp32[2]), - ("prelu", pointwise0_fp32 + element0_fp32), - ] - - -class AutocastCudaTestUnsupportedLists(object): - - def __init__(self): - super().__init__() - # Utility arguments, created as one-element tuples - self.torch_expect_builtin_promote = [ - "cat", # requires all input tensors to be the same type - "equal", # requires all input tensors to be the same type - "stack", # return f16 instead of f32 - ] - self.methods_expect_builtin_promote = [] - - # The remaining lists organize ops that autocast treats explicitly. - self.torch_fp16 = [ - "_convolution_nogroup", # need lowering - "addmv", # need lowering - ] - self.torch_fp32 = [ - "norm", # produce f16 instead of f32 - ] - self.torch_need_autocast_promote = [ - "scatter_add", # cat currently requires all input tensors to be the same type - ] - self.nn_fp16 = [] - self.nn_fp32 = [] - self.linalg_fp16 = [] - self.methods_fp16 = [] - self.methods_fp32 = [] - self.banned = [] - - class TestAutocastBase(unittest.TestCase): @classmethod diff --git a/test/test_compilation_cache_utils.py b/test/test_compilation_cache_utils.py index 0ac8a013d814..1113d3b67a83 100644 --- a/test/test_compilation_cache_utils.py +++ b/test/test_compilation_cache_utils.py @@ -73,7 +73,7 @@ def _test_num_graph_hash(self, use_dynamo, use_persistent): use_persistent=(True, False), ) def test_num_graph_hash(self, use_dynamo, use_persistent): - if use_persistent and (xr.device_type() not in {'TPU', 'CUDA', 'NEURON'}): + if use_persistent and (xr.device_type() not in {'TPU', 'NEURON'}): raise absltest.SkipTest('Device type does not support persistent caching') _test_spawn(self._test_num_graph_hash, (use_dynamo, use_persistent)) diff --git a/test/test_operations.py b/test/test_operations.py index 9d377083da54..08c2f39d63c3 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -90,12 +90,7 @@ def skipIfFunctionalizationDisabled(reason): def onlyOnCPU(fn): accelerator = os.environ.get("PJRT_DEVICE").lower() - return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CUDA required")(fn) - - -def onlyOnCUDA(fn): - accelerator = os.environ.get("PJRT_DEVICE").lower() - return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) + return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) def onlyIfXLAExperimentalContains(feat): @@ -108,79 +103,10 @@ def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) -def _gen_int_tensor(*args, **kwargs): - return torch.randint(*args, **kwargs) - - def _gen_mask(size): return torch.randint(0, 2, size, dtype=torch.bool) -def _get_device_support(devname): - devices = torch_xla._XLAC._xla_get_devices() - num_devices = 0 - for device in devices: - if re.match(devname + r':\d+$', device): - num_devices += 1 - return DeviceSupport(num_devices=num_devices) if num_devices > 0 else None - - -def _support_replicated(devname, num_devices): - devsup = _get_device_support(devname) - if not devsup: - return False - return devsup.num_devices >= num_devices - - -def _random_inputs(shapes, num_replicas=1): - random_tensors = [] - for _ in range(0, num_replicas): - replica_inputs = [] - for shape in shapes: - replica_inputs.append(_gen_tensor(*shape)) - random_tensors.append(tuple(replica_inputs)) - return tuple(random_tensors) - - -def _random_like(tensor_list): - random_tensors = [] - for o in tensor_list: - if o.dtype == torch.float32 or o.dtype == torch.float64: - random_tensors += [_gen_tensor(*o.shape, dtype=o.dtype)] - elif o.dtype == torch.int64: - # TODO remove this, we shouldn't be needing to pass random_tensor for long types - random_tensors += [torch.empty_like(o)] - else: - raise RuntimeError('Unsupported type: ', o.dtype) - return random_tensors - - -def _zeros_like(tensor_list): - zeros_tensors = [] - for o in tensor_list: - if o.dtype == torch.float32 or o.dtype == torch.float64: - zeros_tensors += [torch.zeros(*o.shape, dtype=o.dtype)] - elif o.dtype == torch.int64: - # TODO remove this, we shouldn't be needing to pass zeros_tensor for long types - zeros_tensors += [torch.zeros_like(o)] - else: - raise RuntimeError('Unsupported type: ', o.dtype) - return zeros_tensors - - -def onlyIfTorchSupportsCUDA(fn): - return unittest.skipIf( - not torch.cuda.is_available(), reason="requires PyTorch CUDA support")( - fn) - - -def onlyIfPJRTDeviceIsCUDA(fn): - return unittest.skipIf( - os.environ.get("PJRT_DEVICE") not in ("GPU", "CUDA"), - reason="requires CUDA as PJRT_DEVICE")( - fn) - - class TestToXlaTensorArena(test_utils.XlaTestCase): def test(self): @@ -274,10 +200,6 @@ def forward(self, x): return F.log_softmax(x, dim=1) -@unittest.skipIf( - xr.device_type() == 'CUDA', - 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.' -) class TestParallelTensorMNIST(test_utils.XlaTestCase): def test(self): @@ -436,7 +358,7 @@ def test_masked_select_shape(self): def test_nonzero_cast(self): t1 = torch.ones(5, 2, device='xla') # Result of the nonzero should be the index type. Currently - # index type is s64 on cpu and gpu, but s32 on TPU. We should be + # index type is s64 on cpu, but s32 on TPU. We should be # able to cast it to any other type without error. t2 = torch.nonzero(t1.int()).float() torch_xla.sync() @@ -3036,27 +2958,6 @@ def test_as_strided_input_larger(self): self.assertEqual(a, former_a) - def _test_move_tensor_cuda_to_xla(self, cpu_tensor): - # Assumes CPU-XLA data movement works. - cuda_tensor = cpu_tensor.to("cuda") - # Move tensor CUDA -> XLA. - xla_tensor = cuda_tensor.to('xla') - # Move the XLA tensor back to CPU, and check that it is the same as - # the original CPU tensor. - self.assertTrue(torch.equal(cpu_tensor, xla_tensor.cpu())) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_aten_move_cuda_to_xla(self): - self._test_move_tensor_cuda_to_xla(torch.arange(5)) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_aten_move_scalar_cuda_to_xla(self): - # 0-dimensional scalar-tensor - # Has a different execution path than other tensors. - self._test_move_tensor_cuda_to_xla(torch.tensor(42)) - def test_unsafe_buffer_pointer(self): xla_device = torch_xla.device() xla_tensor_0 = torch.tensor(42).to(xla_device) @@ -3083,178 +2984,6 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_3, 0) -class TestDLPack(parameterized.TestCase): - - def _test_dlpack_capsule_conversion_helper(self, xla_tensor): - dlpt = xdlpack.to_dlpack(xla_tensor) # dlpt1 has type PyCapsule - xla_tensor2 = xdlpack.from_dlpack(dlpt) - - self.assertEqual(xla_tensor.device, xla_tensor2.device) - self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu())) - self.assertRaisesRegex(RuntimeError, - "DLTensor capsule can be consumed only once", - lambda: xdlpack.from_dlpack(dlpt)) - - self.assertEqual( - torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor), - torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2)) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16)) - def test_dlpack_roundtrip_tensor(self, dtype): - xla_device = torch_xla.device() - # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr - # xla_tensor_2 uses XLANativeFunctions::_to_copy - xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) - self._test_dlpack_capsule_conversion_helper(xla_tensor_2) - - # xla_tensor_3 uses arange_out IR node. - xla_tensor_3 = torch.arange(5, dtype=dtype, device='xla') - torch_xla.sync() - self._test_dlpack_capsule_conversion_helper(xla_tensor_3) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - @parameterized.parameters( - *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, - torch.uint16, torch.uint32, torch.uint64)) - def test_dlpack_roundtrip_scalar(self, dtype): - xla_device = torch_xla.device() - xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) - # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr - torch_xla.sync() - self._test_dlpack_capsule_conversion_helper(xla_tensor_0) - - xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device) - # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr - self._test_dlpack_capsule_conversion_helper(xla_tensor_1) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_roundtrip_bool(self): - xla_tensor = torch.ones(1, dtype=torch.bool).to('xla') - self._test_dlpack_capsule_conversion_helper(xla_tensor) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_pytorch_cuda_to_xla(self): - t1_cuda = torch.arange(5).cuda() - dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda) - xla_t1 = xdlpack.from_dlpack(dlt1) - self.assertEqual(xla_t1.device.type, 'xla') - self.assertEqual(xla_t1.device.index, t1_cuda.device.index) - t1_cuda[0] = t1_cuda[0] + 20 - self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu())) - - t2_cuda = torch.tensor(5).cuda() - dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda) - xla_t2 = xdlpack.from_dlpack(dlt2) - self.assertEqual(xla_t2.device.type, 'xla') - self.assertEqual(xla_t2.device.index, t2_cuda.device.index) - t2_cuda.fill_(6) - self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu())) - - cuda1 = torch.device('cuda:1') - t3_cuda = torch.tensor(5, device=cuda1) - dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda) - xla_t3 = xdlpack.from_dlpack(dlt3) - self.assertEqual(xla_t3.device.type, 'xla') - self.assertEqual( - xla_t3.device.index, - t3_cuda.device.index, - msg='both value should 1. xla_t3.device should be xla:1.') - t3_cuda.fill_(6) - self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu())) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self): - # Unlike the test_dlpack_pytorch_cuda_to_xla, - # torch_cuda_tensor has attribute __dlpack__ and __dlpack_device__. - # From cuda tensors to xla tensors, the synchronization is handdled implicitly. - t1_cuda = torch.arange(5).cuda() - xla_t1 = xdlpack.from_dlpack(t1_cuda) - self.assertEqual(xla_t1.device.type, 'xla') - self.assertEqual(xla_t1.device.index, t1_cuda.device.index) - t1_cuda[0] = t1_cuda[0] + 20 - self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu())) - - t2_cuda = torch.tensor(5).cuda() - xla_t2 = xdlpack.from_dlpack(t2_cuda) - self.assertEqual(xla_t2.device.type, 'xla') - self.assertEqual(xla_t2.device.index, t2_cuda.device.index) - t2_cuda.fill_(6) - self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu())) - - cuda1 = torch.device('cuda:1') - t3_cuda = torch.tensor(5, device=cuda1) - xla_t3 = xdlpack.from_dlpack(t3_cuda) - self.assertEqual(xla_t3.device.type, 'xla') - self.assertEqual( - xla_t3.device.index, - t3_cuda.device.index, - msg='both value should 1. xla_t3.device should be xla:1.') - t3_cuda.fill_(6) - self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu())) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_xla_to_pytorch_cuda(self): - xla_t1 = torch.arange(5).to('xla') - dlt1 = xdlpack.to_dlpack(xla_t1) - cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) - self.assertEqual(cuda_t1.device.type, 'cuda') - self.assertEqual(cuda_t1.device.index, xla_t1.device.index) - cuda_t1[0] = cuda_t1[0] + 20 - self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu())) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self): - xla_t1 = torch.arange(5).to('xla') - cuda_t1 = torch.utils.dlpack.from_dlpack(xla_t1) - self.assertEqual(cuda_t1.device.type, 'cuda') - self.assertEqual(cuda_t1.device.index, xla_t1.device.index) - cuda_t1[0] = cuda_t1[0] + 20 - self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu())) - - @onlyIfTorchSupportsCUDA - @onlyIfPJRTDeviceIsCUDA - def test_dlpack_non_default_layout(self): - cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5) - - t1 = cuda_t.t() - xla_t1 = xdlpack.from_dlpack(t1.__dlpack__()) - self.assertEqual(xla_t1.device.type, 'xla') - self.assertEqual(xla_t1.device.index, t1.device.index) - self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu())) - - t2 = cuda_t[0] - xla_t2 = xdlpack.from_dlpack(t2.__dlpack__()) - self.assertEqual(xla_t2.device.type, 'xla') - self.assertEqual(xla_t2.device.index, t2.device.index) - self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu())) - - t3 = cuda_t[:, 0] - self.assertRaisesRegex( - RuntimeError, - r"Only DLPack tensors with trivial \(compact\) striding are supported", - lambda: xdlpack.from_dlpack(t3.__dlpack__())) - - t4 = cuda_t[1, :] - xla_t4 = xdlpack.from_dlpack(t4.__dlpack__()) - self.assertEqual(xla_t4.device.type, 'xla') - self.assertEqual(xla_t4.device.index, t4.device.index) - self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu())) - - t5 = cuda_t[1] - xla_t5 = xdlpack.from_dlpack(t5.__dlpack__()) - self.assertEqual(xla_t5.device.type, 'xla') - self.assertEqual(xla_t5.device.index, t5.device.index) - self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu())) - - class SimpleModelWithDropout(torch.nn.Module): def __init__(self): @@ -3308,7 +3037,7 @@ def test_opt_barrier(self): opt_barrier = line break - # Somehow the CPU/GPU CI will not have the opt-barrier. + # Somehow the CPU CI will not have the opt-barrier. if opt_barrier != "": self.assertEqual(opt_barrier.count("f32[128,128]"), 6) self.assertEqual(opt_barrier.count("f32[128]"), 2) diff --git a/test/test_ops.py b/test/test_ops.py index 167b2024d07f..2e71948b76a1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -126,6 +126,7 @@ def get_allowed_ops_map( AllowedOpInfoEntry('gt'), AllowedOpInfoEntry('imag'), AllowedOpInfoEntry('inverse'), + AllowedOpInfoEntry('index_put'), AllowedOpInfoEntry('isin'), AllowedOpInfoEntry('isneginf'), AllowedOpInfoEntry('le'), @@ -365,11 +366,7 @@ def get_allowed_ops_map( # AllowedOpInfoEntry('logdet'), xla::lodget does not handle empty input # AllowedOpInfoEntry('qr'), # Slice dim size 1 greater than dynamic slice dimension: 0 - # Failed on CUDA CI only (investigate) - # app.circleci.com/pipelines/github/pytorch/xla/9088/workflows/2d59c649-db2b-4384-921e-5e43eba1b51a/jobs/17875 - # AllowedOpInfoEntry('index_put'), - - # Worked locally (but failing on CI both CPU and CUDA) + # Worked locally (but failing on CI both CPU) # app.circleci.com/pipelines/github/pytorch/xla/9130/workflows/71c74f3d-1735-4328-81b5-784d6e6744da/jobs/17998 # AllowedOpInfoEntry('var_mean'), # AllowedOpInfoEntry('pow'), # for int64 don't work, likely rounding issue diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py index a7cfa7ab1fd7..21bee1472cae 100644 --- a/test/test_persistent_cache.py +++ b/test/test_persistent_cache.py @@ -94,8 +94,6 @@ def _spmd_sharded_test(tmpdir, metrics): _assert_correctness_and_metrics(t, xt, metrics) -# Skip CUDA, the on disk cache cannot be deserialized after XLA pin update in -# #8908 @absltest.skipUnless(xr.device_type() in {'TPU', 'NEURON'}, 'Device type does not support persistent caching') class PersistentCacheTest(parameterized.TestCase): @@ -133,9 +131,7 @@ def test_persistent_cache_mp(self): ('spmd_replicated', _spmd_replicated_test), ('spmd_sharded', _spmd_sharded_test), ) - @absltest.skipUnless( - xr.device_type() == 'TPU', - 'TPU required for SPMD; single-device GPU is pending #6023') + @absltest.skipUnless(xr.device_type() == 'TPU', 'TPU required for SPMD') def test_persistent_cache(self, test_fn): self._run_test(_test_spawn, test_fn) diff --git a/test/test_profiler.py b/test/test_profiler.py index 3f2a1482ffb2..2017de1d67e7 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -77,11 +77,6 @@ def _check_trace_namespace_exists(self, path): f'Expected "build_graph" trace in: {path}') def test_trace_and_metrics(self): - # Create a new context for forking processes with the spawn method. - # This is necessary so as to avoid CUDA initialization issues when - # both PyTorch and PyTorch/XLA were compiled with CUDA support. - context = multiprocessing.get_context("spawn") - port = xu.get_free_tcp_ports()[0] training_started = context.Event() p = context.Process( diff --git a/test/test_python_ops.py b/test/test_python_ops.py index 9dc145947f62..a9cdd816fe85 100644 --- a/test/test_python_ops.py +++ b/test/test_python_ops.py @@ -53,7 +53,6 @@ def ref_put(dst, idx, src, accumulate): src = make_arg(src_size, noncontiguous=not src_contig) # If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU - # On CUDA it may not be, but the test has enough tolerance to account for this if accumulate: idx = make_idx(src_size, high=dst.numel()) else: diff --git a/test/test_zero1.py b/test/test_zero1.py index 8bb2fbc3d822..62aa13b4ecc6 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -154,13 +154,12 @@ def test_zero1_load(self): def _mp_fn(index): device = torch_xla.device() - if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + if xm.xla_device_hw(device) in ('TPU',): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) else: print( - 'Default device {} is not a TPU or CUDA device'.format(device), - file=sys.stderr) + 'Default device {} is not a TPU device'.format(device), file=sys.stderr) if __name__ == '__main__': From bd9538295b7462d21840384a6491eb5ed7131124 Mon Sep 17 00:00:00 2001 From: qihqi Date: Thu, 4 Sep 2025 10:39:07 -0700 Subject: [PATCH 093/133] Miscelanous cleanup (#9619) --- torchax/README.md | 3 + torchax/dev-requirements.txt | 4 +- torchax/examples/_diffusion.py | 106 ---------------- torchax/examples/_grad_of_attention.py | 76 ----------- .../torchbench_models/BERT_pytorch.py | 52 -------- torchax/examples/train_gpt/requirements.txt | 4 - torchax/pyproject.toml | 3 - torchax/test-requirements.txt | 7 +- torchax/test/test_misc.py | 12 ++ torchax/test/test_tf_integration.py | 51 -------- torchax/test_dist/test_to_device.py | 27 ++++ torchax/torchax/CONTRIBUTING.md | 15 ++- torchax/torchax/tf_integration.py | 119 ------------------ 13 files changed, 57 insertions(+), 422 deletions(-) delete mode 100644 torchax/examples/_diffusion.py delete mode 100644 torchax/examples/_grad_of_attention.py delete mode 100644 torchax/examples/torchbench_models/BERT_pytorch.py delete mode 100644 torchax/examples/train_gpt/requirements.txt delete mode 100644 torchax/test/test_tf_integration.py create mode 100644 torchax/test_dist/test_to_device.py delete mode 100644 torchax/torchax/tf_integration.py diff --git a/torchax/README.md b/torchax/README.md index 06d9e26d7dcd..2b1fa8d58f33 100644 --- a/torchax/README.md +++ b/torchax/README.md @@ -97,11 +97,14 @@ inputs = torch.randn(3, 3, 28, 28, device='jax') m = MyModel().to('jax') res = m(inputs) print(type(res)) # outputs torchax.tensor.Tensor +print(res.jax()) # print the underlying Jax Array ``` `torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds a `jax.Array`. You can inspect that JAX array with `res.jax()`. +In other words, despite that the code above looks like PyTorch, it is actually running JAX! + ## What is happening behind the scene We took the approach detailed in the diff --git a/torchax/dev-requirements.txt b/torchax/dev-requirements.txt index 7c0020e5156e..2da02ae8599b 100644 --- a/torchax/dev-requirements.txt +++ b/torchax/dev-requirements.txt @@ -1,5 +1,5 @@ -f https://download.pytorch.org/whl/torch -torch==2.7.1 ; sys_platform == 'darwin' # macOS -torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU +torch==2.8.0 ; sys_platform == 'darwin' # macOS +torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml` flax==0.10.6 diff --git a/torchax/examples/_diffusion.py b/torchax/examples/_diffusion.py deleted file mode 100644 index 9f7578056b06..000000000000 --- a/torchax/examples/_diffusion.py +++ /dev/null @@ -1,106 +0,0 @@ -import functools - -import torch -from time import time -from diffusers import DiffusionPipeline -from torch.utils import _pytree as pytree - -import torchax -import torchax.functions -from torchax.extra import torch_view, jax_view - -import jax -import torch.func - - -class CompiledModule: - - def __init__(self, model): - weights = model.state_dict() - weights.update(model.named_parameters()) - self._weights = pytree.tree_map_only(torch.Tensor, - torchax.tensor.move_to_device, weights) - self._model = model - - self._func_jitted_torch = None #torch_view(func_mod_jitted) - - def _maybe_move_tensor(self, tensor): - if isinstance( - tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.Tensor): - return torchax.tensor.move_to_device(tensor) - return tensor - - def _make_jitted(self, args, kwargs): - static = [] - for i, a in enumerate(args): - if not isinstance(a, torch.Tensor): - static.append(i + 1) # weight is 0 - static_argnames = [] - for k, v in kwargs.items(): - if not isinstance(v, torch.Tensor): - static_argnames.append(k) - - def f(weights, *args, **kwargs): - weights, args, kwargs = torchax.tensor.wrap((weights, args, kwargs)) - with torchax.functions.XLAFunctionMode(), torchax.tensor.XLADispatchMode( - ): - res = torch.func.functional_call(self._model, weights, args, kwargs) - if isinstance(res, tuple) and len(res) == 1: - res = res[0] - return torchax.tensor.unwrap(res) - - fjit = jax.jit(f, static_argnames=tuple(static_argnames)) - return torch_view(fjit) - - def forward(self, *args, **kwargs): - (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs)) - if self._func_jitted_torch is None: - self._func_jitted_torch = self._make_jitted(args, kwargs) - return self._func_jitted_torch(self._weights, *args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def __getattr__(self, key): - return getattr(self._model, key) - - -def compile_pipe(pipe): - pipe.text_encoder = CompiledModule(pipe.text_encoder) - pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2) - pipe.unet = CompiledModule(pipe.unet) - pipe.vae = CompiledModule(pipe.vae) - - -def main(): - pipe = DiffusionPipeline.from_pretrained( - # "stabilityai/stable-diffusion-xl-base-0.9", - "stabilityai/stable-diffusion-xl-base-1.0", - use_safetensors=True, - ) - compile_pipe(pipe) - - global_bs = 10 - inference_steps = 20 - resol = 1024 - prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs - print( - f'global batch size {global_bs}', - f'inference steps {inference_steps}', - f'Image resolution {resol}', - flush=True) - - iters = 5 - for i in range(iters): - prompt = prompts - # print('per device prompts len',len(prompt)) - # prompt = prompts[rank] - start = time() - image = pipe( - prompt, num_inference_steps=inference_steps, height=resol, - width=resol).images[0] - print(f'Step {i} inference time {time()-start} sec', flush=True) - - -if __name__ == '__main__': - main() diff --git a/torchax/examples/_grad_of_attention.py b/torchax/examples/_grad_of_attention.py deleted file mode 100644 index 8a8882720837..000000000000 --- a/torchax/examples/_grad_of_attention.py +++ /dev/null @@ -1,76 +0,0 @@ -import jax.numpy as jnp -import jax -from jax.experimental.pallas.ops.tpu import flash_attention - -import torchax -from jax.experimental import mesh_utils -from torchax.ops.jtorch import _tpu_flash_attention - -env = torchax.default_env() -jax.config.update('jax_enable_x64', False) -env._mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh((4,)), - axis_names=("fsdp",), -) -env.use_flash_attention = True - -from torch.nn import functional as F - - -def attn(q, k, v): - q, k, v = env.j2t_iso((q, k, v)) - with env: - x = F.scaled_dot_product_attention(q, k, v, is_causal=True) - x = env.t2j_iso(x) - return jnp.sum(x) - - -import torch - - -class M(torch.nn.Module): - - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(10, 10) - - def forward(self, x): - return self.a(x) - - -m = M() -from torchax.interop import JittableModule - -mjit = JittableModule(m) - -from torch.nn.utils import stateless - - -def f(weights, x): - res = mjit.functional_call('forward', weights, {}, (x,)) - return torch.sum(res) - - -def crossent(x, y): - x, y = env.j2t_iso((x, y)) - res = torch.func.functional_call(m, x, (y,)) - return env.t2j_iso(res) - - -graded = jax.value_and_grad(attn) - -shape = (4, 32, 128, 32) -q = jnp.ones(shape, dtype='bfloat16') -v = jnp.ones(shape, dtype='bfloat16') -k = jnp.ones(shape, dtype='bfloat16') - -env = torchax.default_env() -weights = env.t2j_iso(env.to_xla(mjit.params)) - -from torchax.interop import jax_view - -#print(jax.jit(graded).lower(q, v, k).as_text()) -print( - jax.jit(jax.grad(jax_view(f))).lower(weights, - jax.ShapeDtypeStruct( - (10,), 'float32')).as_text()) diff --git a/torchax/examples/torchbench_models/BERT_pytorch.py b/torchax/examples/torchbench_models/BERT_pytorch.py deleted file mode 100644 index 79ba47b5eaa7..000000000000 --- a/torchax/examples/torchbench_models/BERT_pytorch.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import time -import torchax -import torchax.interop -import os -import importlib -import sys -import logging -import sys - -root = logging.getLogger() -root.setLevel(logging.DEBUG) - -handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.DEBUG) -formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') -handler.setFormatter(formatter) -root.addHandler(handler) - -# NOTE: replace this patch below with your installation -TORCH_BENCH_PATH = os.path.expanduser('~/git/qihqi/benchmark') -# If your directory looks like this_file.py, benchmark/ -sys.path.append(TORCH_BENCH_PATH) -model_name = "torchbenchmark.models.BERT_pytorch" # replace this by the name of the model you're working on -module = importlib.import_module(model_name) -benchmark_cls = getattr(module, "Model", None) -benchmark = benchmark_cls( - test="eval", device="cpu") # test = train or eval device = cuda or cpu - -model, example = benchmark.get_module() - -env = torchax.default_env() -env.config.debug_print_each_op = False -model = env.to_xla(model) -example = env.to_xla(example) -with env: - start = time.perf_counter() - print(model(*example)) - end = time.perf_counter() - print('Eager mode time', end - start) - - -def func_call(state, example): - return torch.func.functional_call(model, state, example, tie_weights=False) - - -jitted = torchax.interop.jax_jit(func_call) -start = time.perf_counter() -print(func_call(model.state_dict(), example)) -end = time.perf_counter() -print('Jitted mode time', end - start) diff --git a/torchax/examples/train_gpt/requirements.txt b/torchax/examples/train_gpt/requirements.txt deleted file mode 100644 index d302f474acfa..000000000000 --- a/torchax/examples/train_gpt/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -tqdm -git+https://github.com/karpathy/minGPT.git@master -datasets -tiktoken diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml index 9407829b76ed..2f30f30e7c68 100644 --- a/torchax/pyproject.toml +++ b/torchax/pyproject.toml @@ -48,6 +48,3 @@ odml = ["jax[cpu]>=0.6.2", "jax[cpu]"] [tool.hatch.build.targets.wheel] packages = ["torchax"] - -[tool.pytest.ini_options] -addopts="-n auto" diff --git a/torchax/test-requirements.txt b/torchax/test-requirements.txt index c64af1807b7b..677912bbd04d 100644 --- a/torchax/test-requirements.txt +++ b/torchax/test-requirements.txt @@ -2,9 +2,8 @@ absl-py==2.2.2 immutabledict==4.2.1 pytest==8.3.5 -pytest-xdist==3.6.1 -pytest-forked==1.6.0 -sentencepiece==0.2.0 +sentencepiece expecttest==0.3.0 optax==0.2.4 -tensorflow==2.19.0 +pytest +pytest-xdist diff --git a/torchax/test/test_misc.py b/torchax/test/test_misc.py index b93877a7fd64..9214c5b1eac6 100644 --- a/torchax/test/test_misc.py +++ b/torchax/test/test_misc.py @@ -31,6 +31,17 @@ def forward(self, a, b): jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5])))) def test_to_device(self): + env = torchax.default_env() + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + self.assertEqual(step3.device.type, 'jax') + + def test_to_device_twice(self): env = torchax.default_env() env.config.debug_print_each_op = True with env: @@ -40,6 +51,7 @@ def test_to_device(self): ) step2 = torch.triu(step1, diagonal=1) step3 = step2.to(dtype=torch.bool, device='jax') + step3.to('jax') self.assertEqual(step3.device.type, 'jax') diff --git a/torchax/test/test_tf_integration.py b/torchax/test/test_tf_integration.py deleted file mode 100644 index 35e58a6c5b0f..000000000000 --- a/torchax/test/test_tf_integration.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -import tempfile -import numpy as np -import tensorflow as tf -import torch -import torch.nn.functional as F -import torchax - -from torchax import tf_integration -from . import base_test_util - - -class Interpolate(torch.nn.Module): - - def forward(self, masks: torch.Tensor) -> torch.Tensor: - masks = F.interpolate( - masks, - size=(500, 500), - mode="bilinear", - align_corners=False, - ) - return masks - - -class TfIntegrationTest(base_test_util.TestCase): - - def setUp(self): - torch.manual_seed(0) - torchax.enable_accuracy_mode() - - def test_interpolate(self): - """Simple model roundtripped through TF savedmodel""" - - # Create model - arg = (torch.randn(3, 3, 200, 200),) - pt_model = Interpolate() - - # Export to SavedModel - with tempfile.TemporaryDirectory() as tempdir: - sm_path = os.path.join(tempdir, "interpolate.savedmodel") - tf_integration.save_torch_module_as_tf_saved_model(pt_model, arg, sm_path) - - # Reload SM and compare results with PT results - loaded_model = tf.saved_model.load(sm_path) - pt_res = pt_model(*arg) - tf_res = torch.tensor(loaded_model.f(*arg)[0].numpy()) - self.assertTrue(torch.allclose(pt_res, tf_res, atol=1e-4)) - - -if __name__ == "__main__": - base_test_util.main() diff --git a/torchax/test_dist/test_to_device.py b/torchax/test_dist/test_to_device.py new file mode 100644 index 000000000000..78794fad704e --- /dev/null +++ b/torchax/test_dist/test_to_device.py @@ -0,0 +1,27 @@ +import jax +import torch +import torchax +import unittest + +from jax.sharding import NamedSharding, PartitionSpec as P + + +class ToDeviceTest(unittest.TestCase): + + def test_to_device_twice(self): + env = torchax.default_env() + mesh = jax.make_mesh((jax.device_count(),), ('axis',)) + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + step3.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + print(step3.to('jax')) + self.assertEqual(step3.device.type, 'jax') + + +if __name__ == '__main__': + unittest.main() diff --git a/torchax/torchax/CONTRIBUTING.md b/torchax/torchax/CONTRIBUTING.md index c61462850652..f908cd2e59bb 100644 --- a/torchax/torchax/CONTRIBUTING.md +++ b/torchax/torchax/CONTRIBUTING.md @@ -1,9 +1,7 @@ -# Contributing to TorchXLA2 +# Contributing to torchax We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels. -If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. - # Developer setup @@ -19,9 +17,17 @@ conda activate pip install --upgrade "jax[cpu]" torch pip install -r test_requirements.txt pip install -e . -pytest test +pip install pytest-xdist # recommended for running test faster +pytest -n auto test ``` +## Setup on GPU or TPU + +Same as Mac setup, except, if you run test using pytest, please also +add `JAX_PLATFORMS=cpu`. The reason is because pytest usually runs +test in multiple threads. CPU device can be accessed concurrently where +TPU devices usually only allow one accesor per process; so it could deadlock. + ### VSCode I use vscode on my Mac. I loosely followed instruction in @@ -35,4 +41,3 @@ The plugins I installed (a subset of the ones listed above) are: I also changed Python interpreter to point at the one in my conda env. That is all the changes I have. - diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py deleted file mode 100644 index c9842089bfcf..000000000000 --- a/torchax/torchax/tf_integration.py +++ /dev/null @@ -1,119 +0,0 @@ -# pylint: disable -import os -from typing import Any, Tuple - -from jax.experimental import jax2tf -import tensorflow as tf -import torch -from torchax import export - - -def exported_program_to_tf_function(ep, enable_xla=True): - weights, jax_program = export.exported_program_to_jax(ep) - wrapped = lambda *args: jax_program(weights, (args,)) - avals = export.extract_avals(ep) - input_signature = [ - tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") - for i, t in enumerate(avals) - ] - tf_f = tf.function( - jax2tf.convert( - wrapped, - with_gradient=False, - enable_xla=enable_xla, - ), - autograph=False, - input_signature=input_signature, - ) - return tf_f - - -def exported_program_to_tf_module(ep: torch.export.ExportedProgram, - enable_xla=True) -> tf.Module: - tfm = tf.Module() - tfm.f = exported_program_to_tf_function(ep, enable_xla) - return tfm - - -def save_exported_program_as_tf_saved_model( - ep: torch.export.ExportedProgram, - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, -): - """This function will export and save a pytorch ExportedProgram to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) - signatures = { - serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) - } - save_options = tf.saved_model.SaveOptions(function_aliases={ - function_alias: tfm.f, - }) - tf.saved_model.save( - tfm, - saved_model_dir, - signatures=signatures, - options=save_options, - ) - - -def save_torch_module_as_tf_saved_model( - torch_model: torch.nn.Module, - args: Tuple[Any], - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, -): - """This function will export and save a pytorch nn.Module to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - ep = torch.export.export(torch_model, args) - save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key, - function_alias, enable_xla) - - -def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): - tfm = exported_program_to_tf_module(ep) - tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [tf_concrete_func], tfm) - tflite_model = converter.convert() - return tflite_model - - -def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module, - args: Tuple[Any]): - ep = torch.export.export(torch_model, args) - return exported_program_to_tflite_flatbuffer(ep) From f6ff30d3c2cd837e940aaa70b61faf948aa805f7 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Thu, 4 Sep 2025 11:26:07 -0700 Subject: [PATCH 094/133] Do not skip fetching sources. Since that step is removed, we need to run this step while building This used to happen while building plugin before --- infra/ansible/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infra/ansible/Dockerfile b/infra/ansible/Dockerfile index 3875442e3747..f4c1021f463d 100644 --- a/infra/ansible/Dockerfile +++ b/infra/ansible/Dockerfile @@ -10,7 +10,7 @@ COPY . /ansible ARG ansible_vars # HACK: install build dependencies only, but skip build step RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}" --tags "bazel,configure_env,install_deps" -RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}" --skip-tags=fetch_srcs,install_deps +RUN ansible-playbook -vvv playbook.yaml -e "stage=build" -e "${ansible_vars}" --skip-tags=install_deps FROM python:${python_version}-${debian_version} AS release From 251838164442a6ba2693431a485d7392258069de Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 5 Sep 2025 12:52:04 -0700 Subject: [PATCH 095/133] Update build_and_test.yml to match r2.8 and r2.8.1 --- .github/workflows/build_and_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index b990f43f6971..ae5ddc32d3ae 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -3,11 +3,11 @@ on: pull_request: branches: - master - - r[0-9]+.[0-9]+ + - r[0-9]+.[0-9]+(\.[0-9]+)? push: branches: - master - - r[0-9]+.[0-9]+ + - r[0-9]+.[0-9]+(\.[0-9]+)? workflow_dispatch: concurrency: From 6ee7627a7bd5a741bb35f24f3124a1cd84f70ed6 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 5 Sep 2025 12:55:55 -0700 Subject: [PATCH 096/133] Update build_and_test.yml --- .github/workflows/build_and_test.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index ae5ddc32d3ae..27fca6a00446 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -3,11 +3,13 @@ on: pull_request: branches: - master - - r[0-9]+.[0-9]+(\.[0-9]+)? + - r[0-9]+.[0-9]+ + - r[0-9]+.[0-9]+.[0-9]+ push: branches: - master - - r[0-9]+.[0-9]+(\.[0-9]+)? + - r[0-9]+.[0-9]+ + - r[0-9]+.[0-9]+.[0-9]+ workflow_dispatch: concurrency: From 8274f945248cc49fda63ca715db1c7cdc7bf22ea Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 5 Sep 2025 17:37:28 -0300 Subject: [PATCH 097/133] Replace `GetComputationClientOrDie()` with `GetComputationClient()` (part 1). (#9617) This PR replaces calls of the deprecated function `GetComputationClientOrDie()` with calls to the `GetComputationClient()` function. The difference between them is that the former throws an exception on error, while the latter returns an status object. _Note: this is the part 1 out of 2 PRs. Together, they will phase out `GetComputationClientOrDie()` function_ --- test/cpp/cpp_test_util.cpp | 40 ++++----- test/cpp/test_replication.cpp | 26 +++--- test/cpp/test_runtime.cpp | 5 +- test/cpp/test_xla_sharding.cpp | 26 +++--- torch_xla/csrc/tensor_util.cpp | 30 ++++--- torch_xla/csrc/xla_backend_impl.cpp | 27 ++++--- torch_xla/csrc/xla_graph_executor.cpp | 112 ++++++++++++++------------ torch_xla/csrc/xla_sharding_util.cpp | 43 +++++----- 8 files changed, 167 insertions(+), 142 deletions(-) diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 7db1934c37d2..d79bcba70fac 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -222,18 +222,19 @@ void WithAllDevices( const std::function&, const std::vector&)>& devfn) { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (auto device_type : device_types) { std::vector devices; std::vector all_devices; - for (const auto& device_str : - torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) { + + for (const auto& device_str : client->GetLocalDevices()) { torch::lazy::BackendDevice device = ParseDeviceString(device_str); if (device.type() == device_type.type) { devices.push_back(device); } } - for (const auto& device_str : - torch_xla::runtime::GetComputationClientOrDie()->GetAllDevices()) { + for (const auto& device_str : client->GetAllDevices()) { torch::lazy::BackendDevice device = ParseDeviceString(device_str); if (device.type() == device_type.type) { all_devices.push_back(device); @@ -279,37 +280,36 @@ std::vector Execute( XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla()); XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, computation.GetProgramShape()); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(device.type())); std::vector instances; - instances.push_back( - {std::move(computation), device.toString(), - torch_xla::runtime::GetComputationClientOrDie()->GetCompilationDevices( - device.toString(), {}), - &shape}); + instances.push_back({std::move(computation), device.toString(), + client->GetCompilationDevices(device.toString(), {}), + &shape}); std::vector< std::shared_ptr> - computations = torch_xla::runtime::GetComputationClientOrDie()->Compile( - std::move(instances)); + computations = client->Compile(std::move(instances)); torch_xla::runtime::ComputationClient::ExecuteComputationOptions options; - XLA_ASSIGN_OR_THROW( - std::vector outputs, - torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( - *computations.front(), - UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(), - options)); + XLA_ASSIGN_OR_THROW(std::vector outputs, + client->ExecuteComputation( + *computations.front(), + UnwrapXlaData(lowering_ctx.GetParametersData()), + device.toString(), options)); return outputs; } std::vector Fetch( absl::Span device_data) { - XLA_ASSIGN_OR_THROW( - std::vector literals, - runtime::GetComputationClientOrDie()->TransferFromDevice(device_data)); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_ASSIGN_OR_THROW(std::vector literals, + client->TransferFromDevice(device_data)); std::vector tensors; for (auto& literal : literals) { tensors.push_back(MakeTensorFromXlaLiteral( diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index 73db4ab42392..8ac622c6eace 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -48,10 +48,10 @@ void TestSingleReplication( instances.emplace_back(CreateCrsComputation(shape), device_str, all_device_strings, &shape); } + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); std::vector - compiled_computations = - torch_xla::runtime::GetComputationClientOrDie()->Compile( - std::move(instances)); + compiled_computations = client->Compile(std::move(instances)); std::vector tensors; for (size_t i = 0; i < device_strings.size(); ++i) { @@ -66,14 +66,13 @@ void TestSingleReplication( torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options; for (size_t i = 0; i < device_strings.size(); ++i) { auto executor = [&, i]() { - XLA_ASSIGN_OR_THROW( - results[i], - torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation( - *compiled_computations[i], - {std::dynamic_pointer_cast< - torch_xla::runtime::ComputationClient::Data>( - tensors_data[i])}, - device_strings[i], exec_options)); + XLA_ASSIGN_OR_THROW(results[i], + client->ExecuteComputation( + *compiled_computations[i], + {std::dynamic_pointer_cast< + torch_xla::runtime::ComputationClient::Data>( + tensors_data[i])}, + device_strings[i], exec_options)); counter.DecrementCount(); }; torch_xla::thread::Schedule(std::move(executor)); @@ -81,9 +80,8 @@ void TestSingleReplication( counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { - XLA_ASSIGN_OR_THROW( - std::vector literals, - runtime::GetComputationClientOrDie()->TransferFromDevice(results[i])); + XLA_ASSIGN_OR_THROW(std::vector literals, + client->TransferFromDevice(results[i])); ASSERT_EQ(literals.size(), 1); // The result must be the original tensor value, multiplied by the number of diff --git a/test/cpp/test_runtime.cpp b/test/cpp/test_runtime.cpp index fda473b7985c..d5ec976d21d2 100644 --- a/test/cpp/test_runtime.cpp +++ b/test/cpp/test_runtime.cpp @@ -13,13 +13,10 @@ TEST(RuntimeTest, ComputationClientInitialization) { // Initialize the ComputationClient. // Check all the APIs return the same valid ComputationClient. - client = GetComputationClientOrDie(); - ASSERT_NE(client, nullptr); - auto status = GetComputationClient(); ASSERT_TRUE(status.ok()); - EXPECT_EQ(status.value(), client); + client = status.value(); EXPECT_EQ(GetComputationClientIfInitialized(), client); } diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index f3b7541f6273..4dd5e965720b 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -332,16 +332,16 @@ TEST_F(XLAShardingTest, CreateTensorsData) { std::vector tensors_data = CreateTensorsData(tensors, shardings, devices); - int64_t n_devices = - torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + int64_t n_devices = client->GetLocalDevices().size(); if (n_devices > 1) { // null sharding is treated as replicated. auto xla_data = std::dynamic_pointer_cast( tensors_data[0]); std::vector shards = - torch_xla::runtime::GetComputationClientOrDie()->GetDataShards( - xla_data); + client->GetDataShards(xla_data); EXPECT_EQ(shards.size(), n_devices); EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(), shards[0]->shape())); @@ -351,8 +351,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) { auto sharded_xla_data = std::dynamic_pointer_cast( tensors_data[1]); - shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards( - sharded_xla_data); + shards = client->GetDataShards(sharded_xla_data); EXPECT_EQ(shards.size(), n_devices); EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(), shards[0]->shape())); @@ -362,8 +361,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) { sharded_xla_data = std::dynamic_pointer_cast( tensors_data[2]); - shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards( - sharded_xla_data); + shards = client->GetDataShards(sharded_xla_data); EXPECT_EQ(shards.size(), n_devices); EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(), shards[0]->shape())); @@ -373,8 +371,9 @@ TEST_F(XLAShardingTest, CreateTensorsData) { TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4}); - int64_t n_devices = - torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + int64_t n_devices = client->GetLocalDevices().size(); xla::Array tile_assignment({1, n_devices}); tile_assignment.FillIota(0); xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto(); @@ -397,15 +396,14 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) { std::vector< std::shared_ptr> - computations = torch_xla::runtime::GetComputationClientOrDie()->Compile( - std::move(instances)); + computations = client->Compile(std::move(instances)); torch_xla::runtime::ComputationClient::ComputationPtr computation = std::make_shared( "add", std::move(computations[0]->move_computation())); // Prepare output sharding propagation, expect a sharded output placeholder. - std::vector tensors{XLATensor::Create( - torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder( + std::vector tensors{ + XLATensor::Create(client->CreateDataPlaceholder( bridge::GetDefaultDevice()->toString(), std::move(shape)))}; std::vector data_placeholders; std::vector sharding_specs; diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 688097e188f1..fb88e21cc721 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -550,11 +550,14 @@ torch::lazy::BackendDataPtr TensorToXlaData( const at::Tensor& tensor, const xla::Shape& shape, const torch::lazy::BackendDevice& device) { TORCH_LAZY_TIMED("TensorToData"); + + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + if (static_cast(device.type()) == XlaDeviceType::SPMD) { // The tensor is bypassing the virtual device, so it should be replicated // to all devices. - std::vector local_devices = - runtime::GetComputationClientOrDie()->GetLocalDevices(); + std::vector local_devices = client->GetLocalDevices(); auto replicated_data = std::vector(local_devices.size(), tensor); return ShardingUtil::CreateShardedData(replicated_data, local_devices, @@ -565,8 +568,7 @@ torch::lazy::BackendDataPtr TensorToXlaData( source_tensors.push_back( std::make_shared(tensor, shape, device.toString())); - auto handles = - runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors); + auto handles = client->TransferToDevice(source_tensors); XLA_CHECK_EQ(handles.size(), 1); return handles.front(); } @@ -806,6 +808,9 @@ std::vector CreateTensorsData( return {}; } + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + // CreateTensorsData should be implicitly replicated to all devices. if (IsVirtualDevice(devices[0])) { XLA_CHECK( @@ -813,8 +818,7 @@ std::vector CreateTensorsData( [&](const std::string& s) { return s == devices[0]; })) << "can't mix virtual device and real device."; - std::vector local_devices = - runtime::GetComputationClientOrDie()->GetLocalDevices(); + std::vector local_devices = client->GetLocalDevices(); std::vector handles; for (size_t i = 0; i < tensors.size(); ++i) { auto device = ParseDeviceString(devices[i]); @@ -834,8 +838,7 @@ std::vector CreateTensorsData( source_tensors.push_back(std::make_shared( tensors[i], std::move(shape), devices[i])); } - return WrapXlaData( - runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors)); + return WrapXlaData(client->TransferToDevice(source_tensors)); } std::vector CreateTensorsData( @@ -846,6 +849,9 @@ std::vector CreateTensorsData( XLA_CHECK_EQ(tensors.size(), shardings.size()); XLA_CHECK_EQ(tensors.size(), devices.size()); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::vector handles; for (size_t i = 0; i < tensors.size(); ++i) { torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); @@ -858,8 +864,7 @@ std::vector CreateTensorsData( // GetLocalDevices returns the list of local devices specified by their // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]). - std::vector local_devices = - runtime::GetComputationClientOrDie()->GetLocalDevices(); + std::vector local_devices = client->GetLocalDevices(); // Shards the input tensors with padding, to split evenly. // The execution requires consistent shard sizes, and the zero-padded // values should be ignored. @@ -871,8 +876,7 @@ std::vector CreateTensorsData( } else { source_tensors.push_back(std::make_shared( tensors[i], std::move(shape), devices[i])); - new_handles = runtime::GetComputationClientOrDie()->TransferToDevice( - source_tensors); + new_handles = client->TransferToDevice(source_tensors); } handles.insert(handles.end(), new_handles.begin(), new_handles.end()); } @@ -910,7 +914,7 @@ absl::StatusOr> ReleaseGilAndTransferData( save = PyEval_SaveThread(); } - XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client, + XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * absl_nonnull const client, runtime::GetComputationClient()); XLA_ASSIGN_OR_RETURN(std::vector literals, client->TransferFromDevice(UnwrapXlaData(xla_data))); diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp index 39e488307619..7d7acb735efa 100644 --- a/torch_xla/csrc/xla_backend_impl.cpp +++ b/torch_xla/csrc/xla_backend_impl.cpp @@ -28,8 +28,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { if (!default_device_type_inited_) { // bridge::GetDefaultDevice will trigger the runtime device init, should // not do it during class init time. - default_device_type_ = std::make_shared( - runtime::GetComputationClientOrDie()->GetDeviceType()); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + default_device_type_ = + std::make_shared(client->GetDeviceType()); default_device_type_inited_ = true; } return true; @@ -77,8 +80,10 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { const torch::lazy::BackendDevice& device, const torch::lazy::Shape& shape) const override { xla::Shape xla_shape = MakeXlaShapeFromLazyShape(shape, device); - return runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - device.toString(), std::move(xla_shape)); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->CreateDataPlaceholder(device.toString(), + std::move(xla_shape)); } torch::lazy::BackendDataPtr GetComputationDataFromNode( @@ -121,8 +126,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { std::vector GetCompilationDevices( const std::string& device, c10::ArrayRef devices) const override { - return runtime::GetComputationClientOrDie()->GetCompilationDevices(device, - devices); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetCompilationDevices(device, devices); } std::vector Compile( @@ -155,9 +161,10 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { torch_xla_computation->get_device_string(), {current_device.toString()}, &output_shapes.back())); } + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); std::vector> - client_computations = runtime::GetComputationClientOrDie()->Compile( - std::move(compile_instances)); + client_computations = client->Compile(std::move(compile_instances)); return {client_computations.begin(), client_computations.end()}; } @@ -165,9 +172,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const torch::lazy::BackendDevice& device) const override { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); XLA_ASSIGN_OR_THROW( std::vector results, - runtime::GetComputationClientOrDie()->ExecuteComputation( + client->ExecuteComputation( *std::dynamic_pointer_cast( computation), UnwrapXlaData(arguments), device.toString())); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index cf6a5a4105d2..42e38b921849 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -90,14 +90,18 @@ XLAGraphExecutor::ComputationCache* CreateComputationCache() { auto serialize_fn = [](XLAGraphExecutor::ComputationCache::TypePtr computation) -> std::string { - return runtime::GetComputationClientOrDie()->SerializeComputation( - computation->computation); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->SerializeComputation(computation->computation); }; auto deserialize_fn = [](std::string serialization) -> XLAGraphExecutor::ComputationCache::TypePtr { + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); runtime::ComputationClient::ComputationPtr computation = - runtime::GetComputationClientOrDie()->DeserializeComputation( - serialization); + client->DeserializeComputation(serialization); if (!computation) return nullptr; return std::make_shared( computation, /*is_sharded=*/UseVirtualDevice()); @@ -469,8 +473,10 @@ void XLAGraphExecutor::WaitDeviceOps(absl::Span devices) { if (UseVirtualDevice()) { wait_devices.insert(ParseDeviceString("SPMD:0")); } else { - for (auto& device_str : - runtime::GetComputationClientOrDie()->GetLocalDevices()) { + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + for (auto& device_str : client->GetLocalDevices()) { wait_devices.insert(ParseDeviceString(device_str)); } } @@ -584,6 +590,8 @@ XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() { void XLAGraphExecutor::ClearPendingIrs( std::vector tensors, const torch::lazy::BackendDevice& device) { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); std::unordered_set tensor_ids; for (size_t i = 0; i < tensors.size(); ++i) { if (tensor_ids.insert(tensors[i]->GetUniqueId()).second && @@ -598,9 +606,8 @@ void XLAGraphExecutor::ClearPendingIrs( } else { xla::Shape shape = MakeShapeWithDeviceLayout( tensors[i]->shape(), static_cast(device.type())); - torch::lazy::BackendDataPtr handle = - runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - device.toString(), std::move(shape)); + torch::lazy::BackendDataPtr handle = client->CreateDataPlaceholder( + device.toString(), std::move(shape)); tensors[i]->data()->handle = handle; TF_VLOG(4) << "Replacing the IR " << ir_value.node.get()->ToString() << " of Tensor with ID " << tensors[i]->GetUniqueId() @@ -638,10 +645,12 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( // Ensure that the compilation environment and git revisions are reflected // in the hash, so that different versions of the code can produce different // hashes for the same graph. - MergeHash({runtime::GetComputationClientOrDie()->HashCompilationEnv(), - torch::lazy::StringHash(TORCH_GITREV), - torch::lazy::StringHash(XLA_GITREV)}, - &coll.hash); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + MergeHash( + {client->HashCompilationEnv(), torch::lazy::StringHash(TORCH_GITREV), + torch::lazy::StringHash(XLA_GITREV)}, + &coll.hash); coll.config = config; coll.device = *unique_device; coll.indices.reserve(tensors.size()); @@ -779,10 +788,11 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( placeholders = ShardingUtil::CreateShardedPlaceholder(output_sharding_hash[hash]); } else { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (const xla::Shape& shape : *output_shapes) { torch::lazy::BackendDataPtr handle = - runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - device.toString(), std::move(shape)); + client->CreateDataPlaceholder(device.toString(), std::move(shape)); placeholders.push_back(handle); } } @@ -840,18 +850,19 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( std::vector results; if (async->cached_computation->is_sharded) { // TODO(JackCaoG): handle eager mode - std::vector devices = - runtime::GetComputationClientOrDie()->GetLocalDevices(); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::vector devices = client->GetLocalDevices(); runtime::ComputationClient::ExecuteReplicatedOptions execute_options; // OutputHandler creates sharded data for sharded // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. XLA_ASSIGN_OR_THROW( std::vector outputs, - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options)); + client->ExecuteReplicated(*async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), + devices, execute_options)); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -918,16 +929,15 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(device.type())); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); std::vector instances; - instances.emplace_back( - std::move(computation), device.toString(), - runtime::GetComputationClientOrDie()->GetCompilationDevices( - device.toString(), - runtime::GetComputationClientOrDie()->GetLocalDevices()), - &shape); + instances.emplace_back(std::move(computation), device.toString(), + client->GetCompilationDevices( + device.toString(), client->GetLocalDevices()), + &shape); std::vector> - computations = - runtime::GetComputationClientOrDie()->Compile(std::move(instances)); + computations = client->Compile(std::move(instances)); std::vector arguments; { @@ -948,8 +958,8 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( XLA_ASSIGN_OR_THROW( std::vector result_data, - runtime::GetComputationClientOrDie()->ExecuteComputation( - *computations[0], UnwrapXlaData(arguments), device.toString())); + client->ExecuteComputation(*computations[0], UnwrapXlaData(arguments), + device.toString())); return WrapXlaData(result_data); } @@ -1044,6 +1054,8 @@ void XLAGraphExecutor::ExtractIRAndPrepareXlaData_( tsl::profiler::TraceMeLevel::kInfo); ir_values.reserve(indices.size()); tensor_data_vec.reserve(indices.size()); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (auto index : indices) { XLATensorPtr& tensor = (*tensors)[index]; torch::lazy::Value ir_value = tensor->CurrentIrValue(); @@ -1051,9 +1063,8 @@ void XLAGraphExecutor::ExtractIRAndPrepareXlaData_( const torch::lazy::BackendDevice& tensor_device = tensor->GetDevice(); xla::Shape shape = MakeShapeWithDeviceLayout( tensor->shape(), static_cast(tensor_device.type())); - torch::lazy::BackendDataPtr handle = - runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - tensor_device.toString(), std::move(shape)); + torch::lazy::BackendDataPtr handle = client->CreateDataPlaceholder( + tensor_device.toString(), std::move(shape)); tensor_data_vec.push_back(handle); if (tensor->CurrentDataHandle() == nullptr && config.force_ltc_data) { @@ -1114,9 +1125,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( try { std::vector results; // Execute replicated if the compiled computation is partitioned. + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); if (async->cached_computation->is_sharded) { - std::vector devices = - runtime::GetComputationClientOrDie()->GetLocalDevices(); + std::vector devices = client->GetLocalDevices(); runtime::ComputationClient::ExecuteReplicatedOptions execute_options; TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) @@ -1126,10 +1139,9 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // "Assign"ed to the corresponding data placeholders. XLA_ASSIGN_OR_THROW( std::vector outputs, - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options)); + client->ExecuteReplicated(*async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), + devices, execute_options)); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteReplicated", 1); TF_VLOG(3) << "Executing IR graph hash " @@ -1142,11 +1154,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( << async->device << " ..."; XLA_ASSIGN_OR_THROW( std::vector outputs, - runtime::GetComputationClientOrDie()->ExecuteComputation( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), async->device.toString(), - {/*explode_tuple=*/true, - /*eager_mode=*/use_eager_mode})); + client->ExecuteComputation(*async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), + async->device.toString(), + {/*explode_tuple=*/true, + /*eager_mode=*/use_eager_mode})); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteComputation", 1); TF_VLOG(3) << "Executing IR graph hash " @@ -1444,12 +1456,13 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(coll.device.type())); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); std::vector instances; instances.push_back( {std::move(computation), coll.device.toString(), - runtime::GetComputationClientOrDie()->GetCompilationDevices( - coll.device.toString(), devices), - &shape, should_wrap_parameter, is_sharded}); + client->GetCompilationDevices(coll.device.toString(), devices), &shape, + should_wrap_parameter, is_sharded}); instances.front().eager_mode = UseEagerMode(); if (use_autosharding) { TF_VLOG(5) << "use_auto_spmd_partitioning is set."; @@ -1480,8 +1493,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( << torch::lazy::HashToString(coll.hash) << " on device " << coll.device << " ..."; std::vector> - computations = - runtime::GetComputationClientOrDie()->Compile(std::move(instances)); + computations = client->Compile(std::move(instances)); DebugUtil::post_compilation_analysis(computations[0]); TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) << " on device " diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 55c6ebf186f8..b3f6346020d8 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -516,14 +516,15 @@ std::vector ShardingUtil::CreateShardedPlaceholder( const std::vector& sharding_specs) { std::vector placeholders; placeholders.reserve(sharding_specs.size()); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (int i = 0; i < sharding_specs.size(); ++i) { // Create sharded data placeholder, this will be used to // hold the corresponding computation results for both sharding & // replication. - auto sharded_data_placeholder = - runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - GetVirtualDevice().toString(), sharding_specs[i]->shape, - sharding_specs[i]->sharding); + auto sharded_data_placeholder = client->CreateDataPlaceholder( + GetVirtualDevice().toString(), sharding_specs[i]->shape, + sharding_specs[i]->sharding); // Register the sharded data placeholder to the tensor and its node. placeholders.push_back(sharded_data_placeholder); @@ -551,6 +552,8 @@ void ShardingUtil::PrepareOutputShardingPropagation( << "Expected size: " << indices.size() << ", actual size: " << new_sharding_specs.size(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (int i = 0; i < indices.size(); ++i) { auto xtensor = (*tensors)[indices[i]]; (*sharding_specs)[i] = new_sharding_specs[i]; @@ -562,10 +565,9 @@ void ShardingUtil::PrepareOutputShardingPropagation( // Create sharded data placeholder, this will be used to // hold the corresponding computation results for both sharding & // replication. - auto sharded_data_placeholder = - runtime::GetComputationClientOrDie()->CreateDataPlaceholder( - GetVirtualDevice().toString(), (*sharding_specs)[i]->shape, - (*sharding_specs)[i]->sharding); + auto sharded_data_placeholder = client->CreateDataPlaceholder( + GetVirtualDevice().toString(), (*sharding_specs)[i]->shape, + (*sharding_specs)[i]->sharding); // Register the sharded data placeholder to the tensor and its node. (*data_placeholders)[i] = sharded_data_placeholder; @@ -609,7 +611,9 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors.push_back(std::make_shared( local_shards[j], shard_shape, devices[j])); } - return runtime::GetComputationClientOrDie()->TransferShardsToDevice( + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->TransferShardsToDevice( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } @@ -624,8 +628,9 @@ std::vector ShardingUtil::GetAutoShardingMesh() { for (auto i : mesh_shape) { total_devices *= i; } - XLA_CHECK_EQ(total_devices, - runtime::GetComputationClientOrDie()->GetAllDevices().size()) + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK_EQ(total_devices, client->GetAllDevices().size()) << "Invalid auto-sharding mesh_shape: " << absl::StrJoin(mesh_shape, ","); } @@ -640,8 +645,9 @@ std::vector ShardingUtil::GetAutoShardingMeshIds( // as the auto-sharding pass takes only one arrangement for now. // TODO(yeounoh) this was not necessary before; replace if this can be done // during the auto-sharding pass. - int64_t n_devices = - runtime::GetComputationClientOrDie()->GetAllDevices().size(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + int64_t n_devices = client->GetAllDevices().size(); std::vector device_mesh_ids = std::vector(n_devices); std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); @@ -736,14 +742,15 @@ void ShardingUtil::ReshardParameters( // more-granular control over the peak memory consumption. bool group_sharding = runtime::sys_util::GetEnvBool("XLA_AUTO_USE_GROUP_SHARDING", true); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); if (group_sharding) { - outputs = WrapXlaData(runtime::GetComputationClientOrDie()->ReshardData( - data_to_reshard, shardings_to_reshard)); + outputs = + WrapXlaData(client->ReshardData(data_to_reshard, shardings_to_reshard)); } else { for (int i = 0; i < data_to_reshard.size(); ++i) { - auto output = - WrapXlaData(runtime::GetComputationClientOrDie()->ReshardData( - {data_to_reshard[i]}, {shardings_to_reshard[i]})); + auto output = WrapXlaData( + client->ReshardData({data_to_reshard[i]}, {shardings_to_reshard[i]})); outputs.insert(outputs.end(), output.begin(), output.end()); } } From 92dcabc0571bc0e5df635bb4752e98bfacd5d2cd Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 5 Sep 2025 17:41:29 -0300 Subject: [PATCH 098/133] `mm`: improve error handling and error messages. (#9621) This PR refactors the `mm` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::mm` return `Status` - Refactor `XLANativeFunctions::mm` overloads to handle the status values - Improve error messages and error handling --- test/test_operations.py | 28 ++++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 4 ++- torch_xla/csrc/tensor_methods.cpp | 48 +++++++++++++++++++++++++------ torch_xla/csrc/tensor_methods.h | 3 +- 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 08c2f39d63c3..635def9634bc 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2503,6 +2503,34 @@ def test_random__raises_error_on_value_out_of_type_value_range(self): "than the upper bound.") self.assertEqual(str(e), expected_error) + def test_mm_raises_error_on_non_matrix_input(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + b = torch.rand(2, 2, device=device) + + try: + torch.mm(a, b) + except RuntimeError as e: + expected_error = ( + "mm(): expected the first input tensor f32[2,2,2] to be a " + "matrix (i.e. a 2D tensor).") + self.assertEqual(str(e), expected_error) + + def test_mm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + a = torch.rand(2, 5, device=device) + b = torch.rand(8, 2, device=device) + + try: + torch.mm(a, b) + except RuntimeError as e: + expected_error = ( + "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. " + "Expected the size of dimension 1 of the first input tensor (5) " + "to be equal the size of dimension 0 of the second input " + "tensor (8).") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 86b4cb84707e..c042d703aa36 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2495,7 +2495,9 @@ at::Tensor XLANativeFunctions::mm(const at::Tensor& self, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2)); - return bridge::AtenFromXlaTensor(tensor_methods::mm(xla_self, xla_mat2)); + XLA_ASSIGN_OR_THROW(XLATensorPtr output, + tensor_methods::mm(xla_self, xla_mat2)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index e7814ce517d5..21f4db597133 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -11,6 +11,7 @@ #include #include "absl/log/absl_check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "torch_xla/csrc/LazyIr.h" @@ -453,14 +454,14 @@ absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input, return absl::OkStatus(); } -// Checks that all index dimensions are smaller or equal to those of input, -// except on dimension canonical_dim. -absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input, - const XLATensorPtr& index, - int64_t canonical_dim) { +// Checks that all index dimension sizes are smaller or equal to those of +// input, except on dimension canonical_dim. +absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input, + const XLATensorPtr& index, + int64_t canonical_dim) { // Dimensions that fail the "smaller or equal" condition. std::vector bad_dims; - for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) { + for (int64_t dim = 0; dim < input->shape().get().dimensions().size(); dim++) { if (dim != canonical_dim && input->size(dim) < index->size(dim)) { bad_dims.push_back(dim); } @@ -478,6 +479,33 @@ absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input, return absl::OkStatus(); } +absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat, + const std::string_view arg) { + xla::Shape shape = mat->shape(); + if (shape.dimensions().size() != 2) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("mm(): expected the ", arg, " input tensor ", + shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); + } + return absl::OkStatus(); +} + +absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1, + const XLATensorPtr& mat2) { + xla::Shape shape1 = mat1->shape(); + xla::Shape shape2 = mat2->shape(); + if (shape1.dimensions(1) != shape2.dimensions(0)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "mm(): cannot matrix-multiply tensors ", shape1.ToString(), " and ", + shape2.ToString(), + ". Expected the size of dimension 1 of the first input tensor (", + shape1.dimensions(1), + ") to be equal the size of dimension 0 of the second input tensor (", + shape2.dimensions(0), ")."))); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1844,7 +1872,7 @@ absl::StatusOr gather(const XLATensorPtr& input, dim, input->shape().get().dimensions_size()); XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index)); XLA_RETURN_IF_ERROR( - CheckGatherDimensionsAreCompatible(input, index, canonical_dim)); + CheckGatherSizesAreCompatible(input, index, canonical_dim)); return input->CreateFrom(torch_xla::MakeNode( input->GetIrValue(), canonical_dim, index->GetIrValue())); } @@ -2349,7 +2377,11 @@ XLATensorPtr mish(const XLATensorPtr& input) { tensor_ops::Softplus(input, 1, 20)->GetIrValue())); } -XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight) { +absl::StatusOr mm(const XLATensorPtr& input, + const XLATensorPtr& weight) { + XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first")); + XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second")); + XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight)); return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue())); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 597640bf4c49..b25b423d49c2 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -646,7 +646,8 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices, XLATensorPtr mish(const XLATensorPtr& input); -XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight); +absl::StatusOr mm(const XLATensorPtr& input, + const XLATensorPtr& weight); XLATensorPtr mse_loss(const XLATensorPtr& input, const XLATensorPtr& target, int64_t reduction); From 6c5478ff7c3d50dd1e3047d72ec5909bea474073 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 5 Sep 2025 13:55:16 -0700 Subject: [PATCH 099/133] Add triggers for v2.8.1 version --- .../artifacts.auto.tfvars | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index fd47e4b63f79..acbe6e6beedf 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -3,6 +3,42 @@ manual_nightly_builds = [ ] manual_versioned_builds = [ + { + git_tag = "v2.8.1" + package_version = "2.8.1" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.10" + bundle_libtpu = "0" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.1" + package_version = "2.8.1" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.11" + bundle_libtpu = "0" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.1" + package_version = "2.8.1" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.12" + bundle_libtpu = "0" + cxx11_abi = "1" + }, + { + git_tag = "v2.8.1" + package_version = "2.8.1" + pytorch_git_rev = "v2.8.0" + accelerator = "tpu" + python_version = "3.13" + bundle_libtpu = "0" + cxx11_abi = "1" + }, { git_tag = "v2.8.0" package_version = "2.8.0" From aba96d81c8bc2725f837337e8cd1ada3f2f475ae Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 8 Sep 2025 22:08:13 -0300 Subject: [PATCH 100/133] Replace `GetComputationClientOrDie()` with `GetComputationClient()` (part 2). (#9620) This PR replaces calls of the deprecated function `GetComputationClientOrDie()` with calls to the `GetComputationClient()` function. The difference between them is that the former throws an exception on error, while the latter returns an status object. **Key Changes:** - Remove `GetComputationClientOrDie()` function In general, this PR applies the following replacement pattern: - Create a new `ComputationClient*` variable using `XLA_ASSIGN_OR_THROW()` macro - Replaces all `GetComputationClientOrDie()` with the new variable ```c++ /* Before */ runtime::ComputationClient::ComputationPtr computation = runtime::GetComputationClientOrDie()->DeserializeComputation( serialization); /* After */ XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull client, runtime::GetComputationClient()); runtime::ComputationClient::ComputationPtr computation = client->DeserializeComputation(serialization); ``` _Note: this is the part 2 out of 2 PRs. Together, they will phase out `GetComputationClientOrDie()` function_ --- torch_xla/csrc/aten_xla_bridge.cpp | 20 ++- torch_xla/csrc/cross_replica_reduces.cpp | 5 +- torch_xla/csrc/dl_convertor.cpp | 22 +-- torch_xla/csrc/init_python_bindings.cpp | 195 +++++++++++++++-------- torch_xla/csrc/ir_dump_util.cpp | 15 +- torch_xla/csrc/ops/device_data.cpp | 7 +- torch_xla/csrc/runtime/runtime.cpp | 14 +- torch_xla/csrc/runtime/runtime.h | 7 - 8 files changed, 172 insertions(+), 113 deletions(-) diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 8af6f5816756..42d396e9ac2c 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -57,8 +57,10 @@ class AtenXlaDeviceMapper { devices_.emplace_back(ParseDeviceString("SPMD:0")); devices_ordinals_[devices_.back()] = 0; } else { - for (auto& device_str : - torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) { + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + for (auto& device_str : client->GetLocalDevices()) { devices_.emplace_back(ParseDeviceString(device_str)); devices_ordinals_[devices_.back()] = devices_.size() - 1; } @@ -398,11 +400,15 @@ std::string ToXlaString(const c10::Device& device) { } const torch::lazy::BackendDevice* GetDefaultDevice() { - static std::string default_device_spec = - UseVirtualDevice() - ? "SPMD:0" - : runtime::GetComputationClientOrDie()->GetDefaultDevice(); - XLA_CHECK(!default_device_spec.empty()); + static std::string default_device_spec = []() -> std::string { + if (UseVirtualDevice()) { + return "SPMD:0"; + } + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetDefaultDevice(); + }(); + ABSL_CHECK(!default_device_spec.empty()); static const torch::lazy::BackendDevice default_device = ParseDeviceString(default_device_spec); return &default_device; diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 77519c03cfc2..6d8abd33ad6f 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -333,8 +333,9 @@ at::Tensor all_to_all_single(const at::Tensor& input, bool pin_layout = false; const torch::lazy::Value& token = GetAllReduceToken(bridge::GetCurrentDevice()); - int64_t split_count = - runtime::GetComputationClientOrDie()->GetAllDevices().size(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + int64_t split_count = client->GetAllDevices().size(); std::vector all_groups(split_count); std::iota(all_groups.begin(), all_groups.end(), 0); diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index c6a68a65f609..274e30f017d5 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -125,8 +125,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { ABSL_CHECK(handle != nullptr) << "Could not extract a valid data handle from the input tensor"; - std::shared_ptr pjrt_buffer = - runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::shared_ptr pjrt_buffer = client->GetPjRtBuffer(handle); ABSL_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer"; ABSL_CHECK(!pjrt_buffer->IsTuple()) @@ -169,11 +170,13 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { // Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc absl::StatusOr DeviceForDLDevice(const DLDevice& context) { switch (context.device_type) { - case DLDeviceType::kDLCPU: - XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(), - xla::CpuId()); - return runtime::GetComputationClientOrDie()->LookupAddressableDevice( - context.device_id); + case DLDeviceType::kDLCPU: { + XLA_ASSIGN_OR_RETURN( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK_EQ(client->GetPlatformID(), xla::CpuId()); + return client->LookupAddressableDevice(context.device_id); + } default: return tsl::errors::InvalidArgument( "Unknown/unsupported DLPack device type %d", context.device_type); @@ -330,10 +333,11 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) { shape, *device->default_memory_space(), on_delete_callback)); ABSL_CHECK(pjrt_buffer.get() != nullptr) << "pjrt buffer is null."; + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); runtime::ComputationClient::DataPtr data = runtime::PjRtComputationClient::CreateData( - runtime::GetComputationClientOrDie()->PjRtDeviceToString(device), - shape, std::move(pjrt_buffer)); + client->PjRtDeviceToString(device), shape, std::move(pjrt_buffer)); at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype); XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1d409850b808..1d205dd86a7e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -261,7 +261,9 @@ torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) { void WaitDeviceOps(absl::Span devices = {}) { XLAGraphExecutor::Get()->WaitDeviceOps(devices); - runtime::GetComputationClientOrDie()->WaitDeviceOps(devices); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + client->WaitDeviceOps(devices); } void PrepareToExit() { @@ -721,8 +723,10 @@ void StepMarker(const std::string& device_str, XLAGraphExecutor::Get()->MarkStep(device, reset_scope); bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false); if (TF_PREDICT_FALSE(debug_mode)) { - std::string report = runtime::metrics::CreatePerformanceReport( - runtime::GetComputationClientOrDie()->GetMetrics()); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::string report = + runtime::metrics::CreatePerformanceReport(client->GetMetrics()); if (!report.empty()) { std::string fout = runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", ""); @@ -972,8 +976,9 @@ py::dict GetMemoryInfo(const std::string& device_str) { { NoGilSection nogil; torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str); - mem_info = - runtime::GetComputationClientOrDie()->GetMemoryInfo(device.toString()); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + mem_info = client->GetMemoryInfo(device.toString()); } auto py_dict = py::dict(); py_dict["bytes_used"] = mem_info.bytes_used; @@ -1283,10 +1288,10 @@ class PyLoweringContext { lowering_ctx.GetParametersData(); // Fetch this parameter data - XLA_ASSIGN_OR_THROW( - std::vector literals, - runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(device_data))); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_ASSIGN_OR_THROW(std::vector literals, + client->TransferFromDevice(UnwrapXlaData(device_data))); // Create a mapping from paramater id to the tensor data std::unordered_map results; @@ -1527,10 +1532,11 @@ void InitXlaModuleBindings(py::module m) { xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr); if (minibatch) { - int num_local_devices = - runtime::GetComputationClientOrDie()->GetLocalDevices().size(); - int num_global_devices = - runtime::GetComputationClientOrDie()->GetAllDevices().size(); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + int num_local_devices = client->GetLocalDevices().size(); + int num_global_devices = client->GetAllDevices().size(); XLA_CHECK(tile_assignment.size() == num_global_devices) << "Minibatch sharding only supports sharding along the batch " "dimension"; @@ -1751,37 +1757,45 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_get_devices", []() { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); if (UseVirtualDevice()) { // Under SPMD context, there is only one virtual devices from // user perspective. - std::vector all_devices = - runtime::GetComputationClientOrDie()->GetAllDevices(); + std::vector all_devices = client->GetAllDevices(); all_devices.resize(1); return all_devices; } else { - return runtime::GetComputationClientOrDie()->GetLocalDevices(); + return client->GetLocalDevices(); } }) .def("_xla_get_platform_version", []() { - return runtime::GetComputationClientOrDie()->GetPlatformVersion(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetPlatformVersion(); }) .def("_xla_num_devices", []() -> int64_t { if (UseVirtualDevice()) { return 1; } else { - return runtime::GetComputationClientOrDie()->GetNumLocalDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetNumLocalDevices(); } }) .def("_xla_num_global_devices", []() -> int64_t { - return runtime::GetComputationClientOrDie()->GetNumDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetNumDevices(); }) .def("_xla_get_all_devices", []() { - std::vector all_devices = - runtime::GetComputationClientOrDie()->GetAllDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::vector all_devices = client->GetAllDevices(); if (UseVirtualDevice()) { // Under SPMD context, there is only one virtual devices from // user perspective. @@ -1792,22 +1806,31 @@ void InitXlaModuleBindings(py::module m) { } }) .def("_xla_get_runtime_devices", - []() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); }) + []() { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetLocalDevices(); + }) .def("_xla_num_runtime_devices", []() -> int64_t { - return runtime::GetComputationClientOrDie()->GetNumLocalDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetNumLocalDevices(); }) .def("_xla_get_all_runtime_devices", []() { - std::vector all_devices = - runtime::GetComputationClientOrDie()->GetAllDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::vector all_devices = client->GetAllDevices(); return all_devices; }) .def( "_xla_real_devices", [](const std::optional> devices) { if (!devices) { - return runtime::GetComputationClientOrDie()->GetLocalDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetLocalDevices(); } std::vector xla_devices; @@ -1822,27 +1845,33 @@ void InitXlaModuleBindings(py::module m) { "_xla_device_kind", [](const std::string& device) { auto xla_device = bridge::AtenDeviceToXlaDevice(device).toString(); - return runtime::GetComputationClientOrDie()->GetDeviceKind(xla_device); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetDeviceKind(xla_device); }, py::arg("device") = "") .def("_xla_set_replication_devices", [](const std::vector& devices) { auto replication_devices = std::make_shared>(devices); - runtime::GetComputationClientOrDie()->SetReplicationDevices( + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + client->SetReplicationDevices( std::move(replication_devices)); }) .def("_xla_get_replication_devices", []() { - auto replication_devices = - runtime::GetComputationClientOrDie()->GetReplicationDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + auto replication_devices = client->GetReplicationDevices(); return replication_devices != nullptr ? *replication_devices : std::vector(); }) .def("_xla_get_replication_devices_count", []() { - auto replication_devices = - runtime::GetComputationClientOrDie()->GetReplicationDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + auto replication_devices = client->GetReplicationDevices(); return replication_devices != nullptr ? replication_devices->size() : 0; }) @@ -2191,9 +2220,10 @@ void InitXlaModuleBindings(py::module m) { "_xla_create_placeholder_tensor", [](py::object py_shape) { xla::Shape shape = op_builder::PyShapeToShape(py_shape); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); auto xla_tensor = - XLATensor::Create(torch_xla::runtime::GetComputationClientOrDie() - ->CreateDataPlaceholder( + XLATensor::Create(client->CreateDataPlaceholder( bridge::GetCurrentDevice().toString(), std::move(shape))); return bridge::AtenFromXlaTensor(xla_tensor); @@ -2212,9 +2242,17 @@ void InitXlaModuleBindings(py::module m) { return device.ordinal(); }) .def("_xla_get_process_index", - []() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); }) + []() { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetProcessIndex(); + }) .def("_xla_get_num_processes", - []() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); }) + []() { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + return client->GetNumProcesses(); + }) .def("_xla_get_num_cached_compilation_graph", []() -> int64_t { return XLAGraphExecutor::Get()->GetNumGraphHash(); @@ -2225,10 +2263,12 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_get_device_attributes", [](const std::string& device_str) { + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); const absl::flat_hash_map< std::string, runtime::ComputationClient::DeviceAttribute> attributes = - runtime::GetComputationClientOrDie()->GetDeviceAttributes( + client->GetDeviceAttributes( bridge::AtenDeviceToXlaDevice(device_str).toString()); py::dict dict; @@ -2239,14 +2279,15 @@ void InitXlaModuleBindings(py::module m) { }) .def("_xla_get_all_device_attributes", []() { - std::vector global_devices = - runtime::GetComputationClientOrDie()->GetAllDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::vector global_devices = client->GetAllDevices(); std::vector list; for (auto const& device : global_devices) { const absl::flat_hash_map< std::string, runtime::ComputationClient::DeviceAttribute>& attributes = - runtime::GetComputationClientOrDie()->GetDeviceAttributes(device); + client->GetDeviceAttributes(device); py::dict dict; for (auto const& [name, value] : attributes) { dict[py::str(name)] = py::cast(value); @@ -2419,9 +2460,11 @@ void InitXlaModuleBindings(py::module m) { // cannot depend on PyTorch (as part of TensorFlow). // TODO(jwtan): Unify them once ComputationClient becomes a // standalone library. + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); return torch::lazy::CreateMetricReport() + runtime::metrics_reader::CreateMetricReport( - runtime::GetComputationClientOrDie()->GetMetrics()); + client->GetMetrics()); }) .def("_short_xla_metrics_report", [](const py::list& counter_names, const py::list& metric_names) { @@ -2689,8 +2732,9 @@ void InitXlaModuleBindings(py::module m) { std::optional>& global_shape) -> at::Tensor { XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - auto local_devices = - runtime::GetComputationClientOrDie()->GetLocalDevices(); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + auto local_devices = client->GetLocalDevices(); XLA_CHECK(local_devices.size() == shards.size()) << "Must specify a shard for each local device"; XLA_CHECK(!global_shape.has_value() || @@ -2764,6 +2808,8 @@ void InitXlaModuleBindings(py::module m) { std::vector handles; std::vector element_types; // Find all shard handles for transfer + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (auto& tensor : input) { XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); @@ -2775,7 +2821,7 @@ void InitXlaModuleBindings(py::module m) { std::dynamic_pointer_cast( xtensor->GetXlaData()); std::vector shard_handles = - runtime::GetComputationClientOrDie()->GetDataShards(handle); + client->GetDataShards(handle); handles.insert(handles.end(), shard_handles.begin(), shard_handles.end()); element_types.insert( @@ -2788,8 +2834,7 @@ void InitXlaModuleBindings(py::module m) { XlaDataToTensors(WrapXlaData(handles), element_types)); // Populate the resulting vector of shards and device strings std::vector>> result; - int shards_per_tensor = - runtime::GetComputationClientOrDie()->GetLocalDevices().size(); + int shards_per_tensor = client->GetLocalDevices().size(); result.reserve(cpu_shards.size() / shards_per_tensor); for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) { std::vector> shard_devices; @@ -2818,6 +2863,8 @@ void InitXlaModuleBindings(py::module m) { [](const std::vector& input_tensors) -> std::vector>> { std::vector>> result; + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); for (auto& tensor : input_tensors) { XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor)); @@ -2827,7 +2874,7 @@ void InitXlaModuleBindings(py::module m) { std::dynamic_pointer_cast( xtensor->GetXlaData()); auto shards = - runtime::GetComputationClientOrDie()->GetDataShards(handle); + client->GetDataShards(handle); std::vector shard_devices; for (auto& shard : shards) { shard_devices.push_back(shard->device()); @@ -2881,8 +2928,9 @@ void InitXlaModuleBindings(py::module m) { bridge::GetXlaTensor(tensor)); XLA_CHECK(xtensor->sharding_spec() != nullptr) << "Cannot load local shards into a non sharded tensor"; - XLA_CHECK(devices.size() == - runtime::GetComputationClientOrDie()->GetLocalDevices().size()) + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK(devices.size() == client->GetLocalDevices().size()) << "Shards must be provided for all local devices"; auto sharding = xtensor->sharding_spec()->sharding; auto sharding_spec = xtensor->sharding_spec(); @@ -2907,10 +2955,10 @@ void InitXlaModuleBindings(py::module m) { "_ensure_xla_coordinator_initialized", [](int global_rank, int world_size, std::string master_addr, std::string master_port) { - auto comp_client = runtime::GetComputationClientOrDie(); - if (!comp_client->CoordinatorInitialized()) { - runtime::GetComputationClientOrDie()->InitializeCoordinator( - global_rank, world_size, master_addr, master_port); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + if (!client->CoordinatorInitialized()) { + client->InitializeCoordinator(global_rank, world_size, master_addr, master_port); } }, py::arg("global_rank"), // @@ -2924,10 +2972,11 @@ void InitXlaModuleBindings(py::module m) { // effect. "_activate_preemption_sync_manager", []() { - auto comp_client = runtime::GetComputationClientOrDie(); - XLA_CHECK(comp_client->CoordinatorInitialized()) + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK(client->CoordinatorInitialized()) << "Coordinator must be initialized"; - auto& coordinator = comp_client->GetCoordinator(); + auto& coordinator = client->GetCoordinator(); coordinator.ActivatePreemptionSyncManager(); }) .def( @@ -2935,10 +2984,11 @@ void InitXlaModuleBindings(py::module m) { // is active "_deactivate_preemption_sync_manager", []() { - auto comp_client = runtime::GetComputationClientOrDie(); - XLA_CHECK(comp_client->CoordinatorInitialized()) + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK(client->CoordinatorInitialized()) << "Coordinator must be initialized"; - auto& coordinator = comp_client->GetCoordinator(); + auto& coordinator = client->GetCoordinator(); coordinator.DeactivatePreemptionSyncManager(); }) .def( @@ -2947,10 +2997,11 @@ void InitXlaModuleBindings(py::module m) { // PreemptionSyncManager activated. "_sync_point_reached", [](int step) { - auto comp_client = runtime::GetComputationClientOrDie(); - XLA_CHECK(comp_client->CoordinatorInitialized()) + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK(client->CoordinatorInitialized()) << "Coordinator must be initialized"; - auto& coordinator = comp_client->GetCoordinator(); + auto& coordinator = client->GetCoordinator(); return coordinator.ReachedSyncPoint(step); }) .def("_is_placecholder", @@ -3058,8 +3109,9 @@ void InitXlaModuleBindings(py::module m) { .def("_xla_register_custom_call_target", [](const std::string& fn_name, const py::capsule& function_ptr, const std::string& platform) { - runtime::GetComputationClientOrDie()->RegisterCustomCall( - fn_name, function_ptr.get_pointer(), platform); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + client->RegisterCustomCall(fn_name, function_ptr.get_pointer(), platform); }) .def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, @@ -3218,26 +3270,29 @@ void InitXlaModuleBindings(py::module m) { } XLA_ERROR() << "Could not get buffer for tensor"; } - runtime::GetComputationClientOrDie()->OnReadyCallback(data, - callback); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull client, + runtime::GetComputationClient()); + client->OnReadyCallback(data, callback); }) .def("_unsafe_buffer_pointer", [](const at::Tensor& input) -> std::uintptr_t { XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull client, + runtime::GetComputationClient()); if (xtensor->CurrentDataHandle() != nullptr) { std::shared_ptr data = std::dynamic_pointer_cast( xtensor->CurrentDataHandle()); - return runtime::GetComputationClientOrDie()->UnsafeBufferPointer( - data); + return client->UnsafeBufferPointer(data); } else if (xtensor->CurrentIrValue().node != nullptr) { DeviceData* device_data = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); if (device_data != nullptr) { torch::lazy::BackendDataPtr data = device_data->data(); - return runtime::GetComputationClientOrDie() - ->UnsafeBufferPointer(UnwrapXlaData(data)); + return client->UnsafeBufferPointer(UnwrapXlaData(data)); } else { XLA_ERROR() << "Could not get the buffer pointer for XLATensor " diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index 3453bd642c38..8a8fcd73dd44 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -274,15 +274,14 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(device.type())); std::vector instances; - instances.push_back( - {std::move(computation), device.toString(), - runtime::GetComputationClientOrDie()->GetCompilationDevices( - device.toString(), {}), - &shape, - /*parameter_is_tupled_arguments=*/false, is_sharded}); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + instances.push_back({std::move(computation), device.toString(), + client->GetCompilationDevices(device.toString(), {}), + &shape, + /*parameter_is_tupled_arguments=*/false, is_sharded}); std::vector> - computations = - runtime::GetComputationClientOrDie()->Compile(std::move(instances)); + computations = client->Compile(std::move(instances)); computation = std::move(computations[0]->move_computation()); } diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index a5f5536b5b67..bc04820a8033 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -16,9 +16,10 @@ DeviceData::DeviceData(std::shared_ptr data) /*num_outputs=*/1, /*hash_seed=*/(uint32_t)101), data_(std::move(data)) { - std::optional op_sharding = - torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding( - std::dynamic_pointer_cast(data_)); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::optional op_sharding = client->GetDataSharding( + std::dynamic_pointer_cast(data_)); if (op_sharding.has_value()) { // DeviceData Node only has 1 output. SetSharding(op_sharding.value(), 0); diff --git a/torch_xla/csrc/runtime/runtime.cpp b/torch_xla/csrc/runtime/runtime.cpp index 3836f6975719..4dced2531bd5 100644 --- a/torch_xla/csrc/runtime/runtime.cpp +++ b/torch_xla/csrc/runtime/runtime.cpp @@ -60,14 +60,14 @@ const absl::StatusOr& GetComputationClient() { return maybe_client; } -ComputationClient* absl_nonnull GetComputationClientOrDie() { - XLA_ASSIGN_OR_THROW(ComputationClient * client, GetComputationClient()); - return client; -} - ComputationClient* GetComputationClientIfInitialized() { - return g_computation_client_initialized ? GetComputationClientOrDie() - : nullptr; + if (!g_computation_client_initialized) { + return nullptr; + } + const absl::StatusOr& client = + GetComputationClient(); + XLA_CHECK_OK(client); + return client.value(); } } // namespace torch_xla::runtime diff --git a/torch_xla/csrc/runtime/runtime.h b/torch_xla/csrc/runtime/runtime.h index 6a1588935e6f..9d9c3f59f29e 100644 --- a/torch_xla/csrc/runtime/runtime.h +++ b/torch_xla/csrc/runtime/runtime.h @@ -10,13 +10,6 @@ namespace torch_xla::runtime { // Returns the ComputationClient singleton. const absl::StatusOr& GetComputationClient(); -ABSL_DEPRECATED( - "Use GetComputationClient(), instead. " - "This function throws an exception on error, instead of " - "actually handling the StatusOr return value, which is " - "safer.") -ComputationClient* absl_nonnull GetComputationClientOrDie(); - // Returns the ComputationClient singleton if it was successfully initialized. // Returns a nullptr if the ComputationClient wasn't initialized yet. // Throws an exception if the ComputationClient was initialized but the From 49dec2ee1d90e8a861de4d24f0837e5f23d9a05a Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Wed, 10 Sep 2025 01:55:15 -0700 Subject: [PATCH 101/133] Upgrade build infra to use debian-12 and gcc-11 (#9631) Similar to https://github.com/pytorch/xla/pull/5451 --- infra/ansible/config/apt.yaml | 4 ++-- infra/ansible/config/env.yaml | 8 ++++---- infra/ansible/config/vars.yaml | 2 +- infra/ansible/development.Dockerfile | 2 +- infra/ansible/e2e_tests.Dockerfile | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/infra/ansible/config/apt.yaml b/infra/ansible/config/apt.yaml index ae3d95468344..dcb63cfa2b71 100644 --- a/infra/ansible/config/apt.yaml +++ b/infra/ansible/config/apt.yaml @@ -15,8 +15,8 @@ apt: - wget - clang-format - clang-{{ clang_version }} - - gcc-10 - - g++-10 + - gcc-11 + - g++-11 - lcov - less diff --git a/infra/ansible/config/env.yaml b/infra/ansible/config/env.yaml index 4fd733b494f7..7271375b2b5e 100644 --- a/infra/ansible/config/env.yaml +++ b/infra/ansible/config/env.yaml @@ -3,8 +3,8 @@ release_env: common: # Force GCC because clang/bazel has issues. - CC: gcc-10 - CXX: g++-10 + CC: gcc-11 + CXX: g++-11 # CC: "clang-{{ clang_version }}" # CXX: "clang++-{{ clang_version }}" LD_LIBRARY_PATH: "$LD_LIBRARY_PATH:/usr/local/lib" @@ -20,8 +20,8 @@ build_env: # Set explicitly to 0 as setup.py defaults this flag to true if unset. BUILD_CPP_TESTS: "{{ build_cpp_tests }}" # Force GCC because clang/bazel has issues. - CC: gcc-10 - CXX: g++-10 + CC: gcc-11 + CXX: g++-11 PYTORCH_BUILD_NUMBER: 1 TORCH_XLA_VERSION: "{{ package_version }}" PYTORCH_BUILD_VERSION: "{{ package_version }}" diff --git a/infra/ansible/config/vars.yaml b/infra/ansible/config/vars.yaml index c336e7754f46..9307e2c01289 100644 --- a/infra/ansible/config/vars.yaml +++ b/infra/ansible/config/vars.yaml @@ -1,5 +1,5 @@ # Used for fetching clang from the right repo, see apt.yaml. -llvm_debian_repo: bullseye +llvm_debian_repo: bookworm clang_version: 17 # PyTorch and PyTorch/XLA wheel versions. package_version: 2.9.0 diff --git a/infra/ansible/development.Dockerfile b/infra/ansible/development.Dockerfile index 2b0bc4ad5323..5efe81cd43e6 100644 --- a/infra/ansible/development.Dockerfile +++ b/infra/ansible/development.Dockerfile @@ -2,7 +2,7 @@ # The built image contains all required pip and apt packages for building and # running PyTorch and PyTorch/XLA. The image doesn't contain any source code. ARG python_version=3.8 -ARG debian_version=bullseye +ARG debian_version=bookworm FROM python:${python_version}-${debian_version} diff --git a/infra/ansible/e2e_tests.Dockerfile b/infra/ansible/e2e_tests.Dockerfile index 2a097e803f0c..96f5ecd0700d 100644 --- a/infra/ansible/e2e_tests.Dockerfile +++ b/infra/ansible/e2e_tests.Dockerfile @@ -1,5 +1,5 @@ ARG python_version=3.8 -ARG debian_version=bullseye +ARG debian_version=bookworm FROM python:${python_version}-${debian_version} AS build From caa809f74fcc1fddd9f036a26a130c40ae3e34ec Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Wed, 10 Sep 2025 02:36:47 -0700 Subject: [PATCH 102/133] Remove libopenblas-dev from ansible dependencies (#9632) libopenblas-dev fails the docker build as it's not available in debian-12 --- infra/ansible/Dockerfile | 2 +- infra/ansible/config/apt.yaml | 1 - infra/tpu-pytorch-releases/README.md | 8 ++++++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/infra/ansible/Dockerfile b/infra/ansible/Dockerfile index f4c1021f463d..537c2e990035 100644 --- a/infra/ansible/Dockerfile +++ b/infra/ansible/Dockerfile @@ -1,5 +1,5 @@ ARG python_version=3.8 -ARG debian_version=bullseye +ARG debian_version=bookworm FROM python:${python_version}-${debian_version} AS build diff --git a/infra/ansible/config/apt.yaml b/infra/ansible/config/apt.yaml index dcb63cfa2b71..29b32946fd47 100644 --- a/infra/ansible/config/apt.yaml +++ b/infra/ansible/config/apt.yaml @@ -28,7 +28,6 @@ apt: - git - gnupg - libgomp1 - - libopenblas-base - patch - vim diff --git a/infra/tpu-pytorch-releases/README.md b/infra/tpu-pytorch-releases/README.md index a70e0b064a6e..ad564d671697 100644 --- a/infra/tpu-pytorch-releases/README.md +++ b/infra/tpu-pytorch-releases/README.md @@ -83,6 +83,14 @@ unset properties of existing triggers. 5. See section [Manually trigger a Cloud Build](#manually-trigger-a-cloud-build) to manually trigger the created build and produce all the artifacts. +### Build development docker locally for testing + +Sample command to build the development docker container locally for testing: +``` +cd infra/ansible +docker build -f development.Dockerfile --build-arg=python_version=3.12 --build-arg=ansible_vars='{"xla_git_rev":"master", "pytorch_git_rev":"main", "accelerator":"tpu", "arch": "amd64", "python_version":"3.12"}' +``` + ### Nightly releases From 7aba922cde76d6140bf33f2eb60f107ae4a6aa30 Mon Sep 17 00:00:00 2001 From: Junjie Qian Date: Wed, 10 Sep 2025 10:46:34 -0700 Subject: [PATCH 103/133] support load and save checkpoint in torchax (#9616) This PR supports checkpointing with torchax: 1. load a checkpoint file in torch tensors and convert to Jax arrays; Or load a checkpoint file in Jax arrays 2. save a checkpoint file in Jax arrays. This support single worker now. --- .github/workflows/_test.yml | 1 + .github/workflows/_tpu_ci.yml | 1 + torchax/README.md | 36 +++++++++++ torchax/test/test_checkpoint.py | 102 ++++++++++++++++++++++++++++++++ torchax/torchax/__init__.py | 4 +- torchax/torchax/checkpoint.py | 60 +++++++++++++++++++ 6 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 torchax/test/test_checkpoint.py create mode 100644 torchax/torchax/checkpoint.py diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 23ffe34f8a46..6c2175117e57 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -111,6 +111,7 @@ jobs: # TODO: Add these in setup.py pip install fsspec pip install rich + pip install flax - name: Checkout PyTorch Repo if: inputs.has_code_changes == 'true' uses: actions/checkout@v4 diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index dc766c53a897..b67f695f81ed 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -55,6 +55,7 @@ jobs: pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html' pip install --pre 'torch_xla[tpu]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html' pip install --upgrade protobuf + pip install flax - name: Run Tests (${{ matrix.test_script }}) if: inputs.has_code_changes == 'true' env: diff --git a/torchax/README.md b/torchax/README.md index 2b1fa8d58f33..57a212b4838d 100644 --- a/torchax/README.md +++ b/torchax/README.md @@ -182,6 +182,42 @@ The first time `m_jitted` is called, it will trigger `jax.jit` to compile the compile for the given input shapes. Subsequent calls with the same input shapes will be fast as the compilation is cached. +## Saving and Loading Checkpoints + +You can use `torchax.save_checkpoint` and `torchax.load_checkpoint` to save and load your training state. The state can be a dictionary containing the model's weights, optimizer state, and any other information you want to save. + +```python +import torchax +import torch +import optax + +# Assume model, optimizer, and other states are defined +model = MyModel() +optimizer = optax.adam(1e-3) +opt_state = optimizer.init(model.parameters()) +weights = model.parameters() +buffers = model.buffers() +epoch = 10 + +state = { + 'weights': weights, + 'buffers': buffers, + 'opt_state': opt_state, + 'epoch': epoch, +} + +# Save checkpoint +torchax.save_checkpoint(state, '/path/to/checkpoint.pt') + +# Load checkpoint +loaded_state = torchax.load_checkpoint('/path/to/checkpoint.pt') + +# Restore state +model.load_state_dict(loaded_state['weights']) +opt_state = loaded_state['opt_state'] +epoch = loaded_state['epoch'] +``` + ## Citation ``` diff --git a/torchax/test/test_checkpoint.py b/torchax/test/test_checkpoint.py new file mode 100644 index 000000000000..4867d44b1eb8 --- /dev/null +++ b/torchax/test/test_checkpoint.py @@ -0,0 +1,102 @@ +import unittest +import torch +import torch.nn as nn +import torchax +from torchax.checkpoint import _to_torch, _to_jax +import optax +import tempfile +import os +import jax +import jax.numpy as jnp +import shutil + + +class CheckpointTest(unittest.TestCase): + + def test_save_and_load_jax_style_checkpoint(self): + model = torch.nn.Linear(10, 20) + optimizer = optax.adam(1e-3) + + torchax.enable_globally() + params_jax, _ = torchax.extract_jax(model) + opt_state = optimizer.init(params_jax) + torchax.disable_globally() + + epoch = 1 + state = { + 'model': model.state_dict(), + 'opt_state': opt_state, + 'epoch': epoch, + } + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'checkpoint') + torchax.save_checkpoint(state, path, step=epoch) + loaded_state_jax = torchax.load_checkpoint(path) + loaded_state = _to_torch(loaded_state_jax) + + self.assertEqual(state['epoch'], loaded_state['epoch']) + + # Compare model state_dict + for key in state['model']: + self.assertTrue( + torch.allclose(state['model'][key], loaded_state['model'][key])) + + # Compare optimizer state + original_leaves = jax.tree_util.tree_leaves(state['opt_state']) + loaded_leaves = jax.tree_util.tree_leaves(loaded_state['opt_state']) + for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves): + if isinstance(original_leaf, (jnp.ndarray, jax.Array)): + # Convert loaded leaf to numpy array for comparison if it is a DeviceArray + self.assertTrue(jnp.allclose(original_leaf, jnp.asarray(loaded_leaf))) + else: + self.assertEqual(original_leaf, loaded_leaf) + + def test_load_pytorch_style_checkpoint(self): + model = torch.nn.Linear(10, 20) + optimizer = optax.adam(1e-3) + + torchax.enable_globally() + params_jax, _ = torchax.extract_jax(model) + opt_state = optimizer.init(params_jax) + torchax.disable_globally() + + epoch = 1 + state = { + 'model': model.state_dict(), + 'opt_state': opt_state, + 'epoch': epoch, + } + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'checkpoint.pt') + torch.save(state, path) + loaded_state_jax = torchax.load_checkpoint(path) + + # convert original state to jax for comparison + state_jax = _to_jax(state) + + self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch']) + + # Compare model state_dict + for key in state_jax['model']: + self.assertTrue( + jnp.allclose(state_jax['model'][key], + loaded_state_jax['model'][key])) + + # Compare optimizer state + original_leaves = jax.tree_util.tree_leaves(state_jax['opt_state']) + loaded_leaves = jax.tree_util.tree_leaves(loaded_state_jax['opt_state']) + for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves): + if isinstance(original_leaf, (jnp.ndarray, jax.Array)): + self.assertTrue(jnp.allclose(original_leaf, loaded_leaf)) + else: + self.assertEqual(original_leaf, loaded_leaf) + + def test_load_non_existent_checkpoint(self): + with self.assertRaises(FileNotFoundError): + torchax.load_checkpoint('/path/to/non_existent_checkpoint') + + +if __name__ == '__main__': + unittest.main() diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index 240cd70175a7..d5e964416dcd 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -15,9 +15,11 @@ 'default_env', 'extract_jax', 'enable_globally', + 'save_checkpoint', + 'load_checkpoint', ] -from jax._src import xla_bridge +from .checkpoint import save_checkpoint, load_checkpoint os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') diff --git a/torchax/torchax/checkpoint.py b/torchax/torchax/checkpoint.py new file mode 100644 index 000000000000..daded1c3afad --- /dev/null +++ b/torchax/torchax/checkpoint.py @@ -0,0 +1,60 @@ +import torch +import os +from typing import Any, Dict +from flax.training import checkpoints +import jax +import jax.numpy as jnp +import numpy as np + + +def _to_jax(pytree): + return jax.tree_util.tree_map( + lambda x: jnp.asarray(x.cpu().numpy()) + if isinstance(x, torch.Tensor) else x, pytree) + + +def _to_torch(pytree): + return jax.tree_util.tree_map( + lambda x: torch.from_numpy(np.asarray(x)) + if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree) + + +def save_checkpoint(state: Dict[str, Any], path: str, step: int): + """Saves a checkpoint to a file in JAX style. + + Args: + state: A dictionary containing the state to save. torch.Tensors will be + converted to jax.Array. + path: The path to save the checkpoint to. This is a directory. + step: The training step. + """ + state = _to_jax(state) + checkpoints.save_checkpoint(path, state, step=step, overwrite=True) + + +def load_checkpoint(path: str) -> Dict[str, Any]: + """Loads a checkpoint and returns it in JAX format. + + This function can load both PyTorch-style (single file) and JAX-style + (directory) checkpoints. + + If the checkpoint is in PyTorch format, it will be converted to JAX format. + + Args: + path: The path to the checkpoint. + + Returns: + The loaded state in JAX format (pytree with jax.Array leaves). + """ + if os.path.isdir(path): + # JAX-style checkpoint + state = checkpoints.restore_checkpoint(path, target=None) + if state is None: + raise FileNotFoundError(f"No checkpoint found at {path}") + return state + elif os.path.isfile(path): + # PyTorch-style checkpoint + state = torch.load(path, weights_only=False) + return _to_jax(state) + else: + raise FileNotFoundError(f"No such file or directory: {path}") From 8efa5682e605d5a5d967adb84688970658855e86 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 11 Sep 2025 15:16:50 -0300 Subject: [PATCH 104/133] Set `allow_broken_conditionals` configuration variable at `ansible.cfg`. (#9634) Currently, CI is breaking with: ``` Wednesday 10 September 2025 14:27:33 +0000 (0:00:00.273) 0:16:18.634 *** Error: : Task failed: Conditional result (True) was derived from value of type 'str' at "". Conditionals must have a boolean result. ``` This is likely because the installed `ansible` is now on version `2.19.2`. The error is documented [here](https://docs.ansible.com/ansible/latest/porting_guides/porting_guide_core_2.19.html#example-implicit-boolean-conversion). ~In order to quickly fix it, I have introduced `ALLOW_BROKEN_CONDITIONALS=1` environment variable before `ansible-playbook` executions.~ **Key Changes:** - Add `allow_broken_conditionals = true` to ansible configuration file `ansible.cfg` --- infra/ansible/ansible.cfg | 5 ++++- test/stablehlo/test_unbounded_dynamism.py | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/infra/ansible/ansible.cfg b/infra/ansible/ansible.cfg index 490c16aea505..fba3ad26351e 100644 --- a/infra/ansible/ansible.cfg +++ b/infra/ansible/ansible.cfg @@ -9,8 +9,11 @@ callbacks_enabled = profile_tasks localhost_warning = False # Make output human-readable. stdout_callback = yaml +# Ansible 2.19 requires this environment variable being set, so that we can use +# string variables as boolean. +allow_broken_conditionals = true [inventory] # Silence warning about no inventory. # This option is available since Ansible 2.14 (available only with Python 3.9+). -inventory_unparsed_warning = False \ No newline at end of file +inventory_unparsed_warning = False diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index 88fce368b668..4bbdd4989702 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -25,6 +25,7 @@ class UnboundedDynamismExportTest(unittest.TestCase): + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_add(self): args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) @@ -78,6 +79,7 @@ def test_addmm(self): # Hit stablehlo.dot shape refinement error when inferencing saved_model in TF. compare_exported_program_and_saved_model_result(ep, tempdir, args) + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_bmm(self): args = ( torch.rand((24, 197, 64)), @@ -120,6 +122,7 @@ def test_bmm_dynamic_out_dim(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_bmm_dynamic_reduction_dim(self): args = ( torch.rand((8, 128, 3)), @@ -141,6 +144,7 @@ def test_bmm_dynamic_reduction_dim(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_cat(self): args = (torch.rand((10, 1, 768)), torch.rand((10, 196, 768))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) @@ -240,6 +244,7 @@ def test_cumsum(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_div(self): args = (torch.rand((10, 12, 197)), torch.rand((10, 12, 197))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) @@ -340,6 +345,7 @@ def forward(self, x): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_mul(self): args = (torch.rand((10, 2, 768)), torch.rand((10, 2, 768))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) @@ -571,6 +577,7 @@ def test_softmax(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) + @unittest.skip("https://github.com/pytorch/xla/issues/9637") def test_sub(self): args = (torch.rand((10, 1, 1, 10)), torch.rand((10, 1, 1, 10))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) From c77852e117bdf056c8e9a087e51d6f65cf6ba53d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 11 Sep 2025 16:45:45 -0300 Subject: [PATCH 105/133] Move torch ops error message tests into a new file. (#9622) In summary, this PR: - Moves tests that checked error message of PyTorch operations into `test_ops_error_message.py` - Introduces `expecttest` as a dependency in `requirements.in` file - Introduces `expecttest` to those tests, so as to avoid copy-and-pasting error messages The introduction of `expecttest` Python package was, in fact, implicit to our tests because of the following PyTorch testing library import: https://github.com/pytorch/xla/blob/f6ff30d3c2cd837e940aaa70b61faf948aa805f7/test/test_operations.py#L33 --- requirements.in | 1 + requirements_lock_3_10.txt | 4 + requirements_lock_3_11.txt | 4 + requirements_lock_3_12.txt | 4 + requirements_lock_3_13.txt | 4 + requirements_lock_3_8.txt | 4 + requirements_lock_3_9.txt | 4 + test/test_operations.py | 164 ----------------------------- test/test_ops_error_message.py | 181 +++++++++++++++++++++++++++++++++ 9 files changed, 206 insertions(+), 164 deletions(-) create mode 100644 test/test_ops_error_message.py diff --git a/requirements.in b/requirements.in index 6a41033d3ed2..acb0a55ab83a 100644 --- a/requirements.in +++ b/requirements.in @@ -1,3 +1,4 @@ +expecttest filelock fsspec jinja2 diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 225f30d14432..9a4f8f8471c3 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -4,6 +4,10 @@ # # bazel run //:requirements.update # +expecttest==0.3.0 \ + --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \ + --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd + # via -r requirements.in filelock==3.14.0 \ --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 78862541e948..350e692590c0 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -4,6 +4,10 @@ # # bazel run //:requirements.update # +expecttest==0.3.0 \ + --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \ + --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd + # via -r requirements.in filelock==3.14.0 \ --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 9f6d690140ee..4507e4b1baff 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -4,6 +4,10 @@ # # bazel run //:requirements.update # +expecttest==0.3.0 \ + --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \ + --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd + # via -r requirements.in filelock==3.18.0 \ --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \ --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt index 5f288e816a59..3b91c897a93d 100644 --- a/requirements_lock_3_13.txt +++ b/requirements_lock_3_13.txt @@ -4,6 +4,10 @@ # # bazel run //:requirements.update # +expecttest==0.3.0 \ + --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \ + --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd + # via -r requirements.in filelock==3.18.0 \ --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \ --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de diff --git a/requirements_lock_3_8.txt b/requirements_lock_3_8.txt index 022d1e07f3ea..79352c5f3fa7 100644 --- a/requirements_lock_3_8.txt +++ b/requirements_lock_3_8.txt @@ -4,6 +4,10 @@ # # bazel run //:requirements.update # +expecttest==0.3.0 \ + --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \ + --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd + # via -r requirements.in filelock==3.14.0 \ --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index a01cb47146da..c02de8e8ef1a 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -4,6 +4,10 @@ # # bazel run //:requirements.update # +expecttest==0.3.0 \ + --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \ + --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd + # via -r requirements.in filelock==3.14.0 \ --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a diff --git a/test/test_operations.py b/test/test_operations.py index 635def9634bc..a02bb7466660 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -88,11 +88,6 @@ def skipIfFunctionalizationDisabled(reason): return _skipIfFunctionalization(value=True, reason=reason) -def onlyOnCPU(fn): - accelerator = os.environ.get("PJRT_DEVICE").lower() - return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) - - def onlyIfXLAExperimentalContains(feat): experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":") return unittest.skipIf(feat not in experimental, @@ -2372,165 +2367,6 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) - def test_add_broadcast_error(self): - a = torch.rand(2, 2, 4, 4, device="xla") - b = torch.rand(2, 2, device="xla") - - expected_regex = ( - r"Shapes are not compatible for broadcasting: f32\[2,2,4,4\] vs. f32\[2,2\]. " - r"Expected dimension 2 of shape f32\[2,2,4,4\] \(4\) to match dimension " - r"0 of shape f32\[2,2\] \(2\). .*") - - with self.assertRaisesRegex(RuntimeError, expected_regex): - torch.add(a, b) - torch_xla.sync() - - @onlyOnCPU - def test_construct_large_tensor_raises_error(self): - with self.assertRaisesRegex(RuntimeError, - r"Out of memory allocating \d+ bytes"): - # When eager-mode is enabled, OOM is triggered here. - a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device()) - b = a.sum() - # OOM is raised when we try to bring data from the device. - b.cpu() - - def test_cat_raises_error_on_incompatible_shapes(self): - a = torch.rand(2, 2, device=torch_xla.device()) - b = torch.rand(5, 1, device=torch_xla.device()) - - try: - torch.cat([a, b]) - except RuntimeError as e: - expected_error = ( - "cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] " - "at dimension 0. Expected shapes to be equal (except at dimension 0) " - "or that either of them was a 1D empty tensor of size (0,).") - self.assertEqual(str(e), expected_error) - - def test_div_raises_error_on_invalid_rounding_mode(self): - a = torch.rand(2, 2, device=torch_xla.device()) - - try: - torch.div(a, 2, rounding_mode="bad") - except RuntimeError as e: - expected_error = ( - "div(): invalid rounding mode `bad`. Expected it to be either " - "'trunc', 'floor', or be left unspecified.") - self.assertEqual(str(e), expected_error) - - def test_flip_raises_error_on_duplicated_dims(self): - a = torch.rand(2, 2, 2, 2, device=torch_xla.device()) - dims = [0, 0, 0, 1, 2, 3, -1] - dims_suggestion = [0, 1, 2, 3] - - try: - torch.flip(a, dims=dims) - except RuntimeError as e: - expected_error = ( - "flip(): expected each dimension to appear at most once. Found " - "dimensions: 0 (3 times), 3 (2 times). Consider changing dims " - f"from {dims} to {dims_suggestion}.") - self.assertEqual(str(e), expected_error) - - def test_full_raises_error_on_negative_size(self): - shape = [2, -2, 2] - try: - torch.full(shape, 1.5, device="xla") - except RuntimeError as e: - expected_error = ( - "full(): expected concrete sizes (i.e. non-symbolic) to be " - f"positive values. However found negative ones: {shape}.") - self.assertEqual(str(e), expected_error) - - def test_gather_raises_error_on_rank_mismatch(self): - S = 2 - - input = torch.arange(4, device=torch_xla.device()).view(S, S) - index = torch.randint(0, S, (S, S, S), device=torch_xla.device()) - dim = 1 - - try: - torch.gather(input, dim, index) - except RuntimeError as e: - expected_error = ( - "gather(): expected rank of input (2) and index (3) tensors " - "to be the same.") - self.assertEqual(str(e), expected_error) - - def test_gather_raises_error_on_invalid_index_size(self): - S = 2 - X = S + 2 - - input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S) - index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device()) - dim = 1 - - try: - torch.gather(input, dim, index) - except RuntimeError as e: - expected_error = ( - f"gather(): expected sizes of index [{X}, {S}, {X}, {S}] to be " - f"smaller or equal those of input [{S}, {S}, {S}, {S}] on all " - f"dimensions, except on dimension {dim}. " - "However, that's not true on dimensions [0, 2].") - self.assertEqual(str(e), expected_error) - - def test_random__raises_error_on_empty_interval(self): - a = torch.empty(10, device=torch_xla.device()) - from_ = 3 - to_ = 1 - - try: - a.random_(from_, to_) - except RuntimeError as e: - expected_error = ( - f"random_(): expected `from` ({from_}) to be smaller than " - f"`to` ({to_}).") - self.assertEqual(str(e), expected_error) - - def test_random__raises_error_on_value_out_of_type_value_range(self): - a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16) - from_ = 3 - to_ = 65504 + 1 - - try: - a.random_(from_, to_) - except RuntimeError as e: - expected_error = ( - f"random_(): expected `to` to be within the range " - f"[-65504, 65504]. However got value {to_}, which is greater " - "than the upper bound.") - self.assertEqual(str(e), expected_error) - - def test_mm_raises_error_on_non_matrix_input(self): - device = torch_xla.device() - a = torch.rand(2, 2, 2, device=device) - b = torch.rand(2, 2, device=device) - - try: - torch.mm(a, b) - except RuntimeError as e: - expected_error = ( - "mm(): expected the first input tensor f32[2,2,2] to be a " - "matrix (i.e. a 2D tensor).") - self.assertEqual(str(e), expected_error) - - def test_mm_raises_error_on_incompatible_shapes(self): - device = torch_xla.device() - a = torch.rand(2, 5, device=device) - b = torch.rand(8, 2, device=device) - - try: - torch.mm(a, b) - except RuntimeError as e: - expected_error = ( - "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. " - "Expected the size of dimension 1 of the first input tensor (5) " - "to be equal the size of dimension 0 of the second input " - "tensor (8).") - self.assertEqual(str(e), expected_error) - class MNISTComparator(nn.Module): diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py new file mode 100644 index 000000000000..bbb9f4b95b77 --- /dev/null +++ b/test/test_ops_error_message.py @@ -0,0 +1,181 @@ +import expecttest +import os +import torch +import torch_xla +import unittest + + +def onlyOnCPU(fn): + accelerator = os.environ.get("PJRT_DEVICE").lower() + return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) + + +class TestOpsErrorMessage(expecttest.TestCase): + + def test_add_broadcast_error(self): + a = torch.rand(2, 2, 4, 4, device="xla") + b = torch.rand(2, 2, device="xla") + + def test(): + return torch.add(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""Shapes are not compatible for broadcasting: f32[2,2,4,4] vs. f32[2,2]. Expected dimension 2 of shape f32[2,2,4,4] (4) to match dimension 0 of shape f32[2,2] (2). Either that or that any of them is either 1 or unbounded. Try reshaping one of the tensors to match the other.""" + ) + + @onlyOnCPU + def test_construct_large_tensor_raises_error(self): + + def test(): + # When eager-mode is enabled, OOM is triggered here. + a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device()) + b = a.sum() + # OOM is raised when we try to bring data from the device. + return b.cpu() + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""Error preparing computation: Out of memory allocating 4503599761588224 bytes.""" + ) + + def test_cat_raises_error_on_incompatible_shapes(self): + a = torch.rand(2, 2, device=torch_xla.device()) + b = torch.rand(5, 1, device=torch_xla.device()) + + def test(): + return torch.cat([a, b]) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,).""" + ) + + def test_div_raises_error_on_invalid_rounding_mode(self): + a = torch.rand(2, 2, device=torch_xla.device()) + + def test(): + return torch.div(a, 2, rounding_mode="bad") + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""div(): invalid rounding mode `bad`. Expected it to be either 'trunc', 'floor', or be left unspecified.""" + ) + + def test_flip_raises_error_on_duplicated_dims(self): + a = torch.rand(2, 2, 2, 2, device=torch_xla.device()) + dims = [0, 0, 0, 1, 2, 3, -1] + + def test(): + return torch.flip(a, dims=dims) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""flip(): expected each dimension to appear at most once. Found dimensions: 0 (3 times), 3 (2 times). Consider changing dims from [0, 0, 0, 1, 2, 3, -1] to [0, 1, 2, 3].""" + ) + + def test_full_raises_error_on_negative_size(self): + shape = [2, -2, 2] + + def test(): + return torch.full(shape, 1.5, device="xla") + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""full(): expected concrete sizes (i.e. non-symbolic) to be positive values. However found negative ones: [2, -2, 2].""" + ) + + def test_gather_raises_error_on_rank_mismatch(self): + S = 2 + + input = torch.arange(4, device=torch_xla.device()).view(S, S) + index = torch.randint(0, S, (S, S, S), device=torch_xla.device()) + dim = 1 + + def test(): + return torch.gather(input, dim, index) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""gather(): expected rank of input (2) and index (3) tensors to be the same.""" + ) + + def test_gather_raises_error_on_invalid_index_size(self): + S = 2 + X = S + 2 + + input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S) + index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device()) + dim = 1 + + def test(): + return torch.gather(input, dim, index) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""gather(): expected sizes of index [4, 2, 4, 2] to be smaller or equal those of input [2, 2, 2, 2] on all dimensions, except on dimension 1. However, that's not true on dimensions [0, 2].""" + ) + + def test_random__raises_error_on_empty_interval(self): + a = torch.empty(10, device=torch_xla.device()) + from_ = 3 + to_ = 1 + + def test(): + return a.random_(from_, to_) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""random_(): expected `from` (3) to be smaller than `to` (1).""" + ) + + def test_random__raises_error_on_value_out_of_type_value_range(self): + a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16) + from_ = 3 + to_ = 65_504 + 2 + + def test(): + return a.random_(from_, to_) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""random_(): expected `to` to be within the range [-65504, 65504]. However got value 65505, which is greater than the upper bound.""" + ) + + def test_mm_raises_error_on_non_matrix_input(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + b = torch.rand(2, 2, device=device) + + def test(): + torch.mm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""mm(): expected the first input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor).""" + ) + + def test_mm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + a = torch.rand(2, 5, device=device) + b = torch.rand(8, 2, device=device) + + def test(): + torch.mm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8).""" + ) From 23297469abe4dda82925a79554fe3540954987e8 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 15 Sep 2025 09:23:21 -0300 Subject: [PATCH 106/133] Fix `test_ops_error_message.py` and run it on CI. (#9640) This PR fixes #9622, which extracted error messages checks out of `test_operations.py` into `test_ops_error_message.py`. There were a few problems with that PR, namely: - `unittest.main()` wasn't being run: although calling `python -m pytest` runs it automatically, running it without the `pytest` module did nothing - `onlyOnCPU()` would error if `PJRT_DEVICE` environment variable wasn't set - `test_ops_error_message.py` wasn't being run on CI This PR fixes all the aforementioned PRs. --- test/run_tests.sh | 1 + test/test_ops_error_message.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index bb03d7abe161..3982b531a49c 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -150,6 +150,7 @@ function run_xla_op_tests1 { run_dynamic "$_TEST_DIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY run_eager_debug "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY + run_test "$_TEST_DIR/test_ops_error_message.py" run_test "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug_level2 "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index bbb9f4b95b77..8fbbf66d69ed 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -6,7 +6,7 @@ def onlyOnCPU(fn): - accelerator = os.environ.get("PJRT_DEVICE").lower() + accelerator = os.environ.get("PJRT_DEVICE", "").lower() return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) @@ -158,7 +158,7 @@ def test_mm_raises_error_on_non_matrix_input(self): b = torch.rand(2, 2, device=device) def test(): - torch.mm(a, b) + return torch.mm(a, b) self.assertExpectedRaisesInline( exc_type=RuntimeError, @@ -172,10 +172,14 @@ def test_mm_raises_error_on_incompatible_shapes(self): b = torch.rand(8, 2, device=device) def test(): - torch.mm(a, b) + return torch.mm(a, b) self.assertExpectedRaisesInline( exc_type=RuntimeError, callable=test, expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8).""" ) + + +if __name__ == "__main__": + unittest.main() From efe20ab8ff0026a17aa1a533a536cbfee5c67ecd Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Mon, 15 Sep 2025 11:03:07 -0700 Subject: [PATCH 107/133] Do not warn on jax usage when workarounds are available (#9624) This prevent excessive logging when using xp.Trace or get_op_sharding. --- torch_xla/_internal/jax_workarounds.py | 10 ++++++---- torch_xla/debug/profiler.py | 2 +- torch_xla/distributed/spmd/xla_sharding.py | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index d2d665704184..d16bc8fa0340 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -58,7 +58,7 @@ def maybe_get_torchax(): return None -def maybe_get_jax(): +def maybe_get_jax(log=True): try: jax_import_guard() with jax_env_context(): @@ -67,6 +67,8 @@ def maybe_get_jax(): jax.config.update('jax_use_shardy_partitioner', False) return jax except (ModuleNotFoundError, ImportError): - logging.warn('You are trying to use a feature that requires jax/pallas.' - 'You can install Jax/Pallas via pip install torch_xla[pallas]') - return None \ No newline at end of file + if log: + logging.warning( + 'You are trying to use a feature that requires jax/pallas.' + 'You can install Jax/Pallas via pip install torch_xla[pallas]') + return None diff --git a/torch_xla/debug/profiler.py b/torch_xla/debug/profiler.py index ffbd754b1c76..37dfa501e4f5 100644 --- a/torch_xla/debug/profiler.py +++ b/torch_xla/debug/profiler.py @@ -131,7 +131,7 @@ def __enter__(self): self._jax_scope = None # Also enter the JAX named scope, to support torchax lowering. - if jax := maybe_get_jax(): + if jax := maybe_get_jax(log=False): self._jax_scope = jax.named_scope(self.name) self._jax_scope.__enter__() diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index c010fd4c3523..be6daca582ed 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -646,7 +646,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." tx = maybe_get_torchax() - jax = maybe_get_jax() + # Do not log jax warnings when workarounds are available. + jax = maybe_get_jax(log=False) if (jax is not None) and (tx is not None) and isinstance(t, tx.tensor.Tensor): from jax.sharding import PartitionSpec as P, NamedSharding jmesh = mesh.get_jax_mesh() From a66cfc381e6b28a4b6f732de7db4e08046cbea47 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 16 Sep 2025 11:47:35 -0300 Subject: [PATCH 108/133] `roll`: improve error handling and error messages. (#9628) This PR refactors the `roll` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::roll` return `StatusOr` - Improve error messages and error handling - Create `CheckRollShiftsRequired` and `CheckRollDimsAndShiftsAreCompatible` functions --- test/test_ops_error_message.py | 42 +++++++++++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 12 ++++--- torch_xla/csrc/tensor_methods.cpp | 52 ++++++++++++++++++++++++------- torch_xla/csrc/tensor_methods.h | 5 +-- 4 files changed, 94 insertions(+), 17 deletions(-) diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index 8fbbf66d69ed..bb23810f2349 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -180,6 +180,48 @@ def test(): expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8).""" ) + def test_roll_raises_error_on_empty_shifts(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + shifts = [] + + def test(): + return torch.roll(a, shifts) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""roll(): expected `shifts` to have at least 1 element.""") + + def test_roll_raises_error_on_shifts_with_empty_dims(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + shifts = [2, 2] + + def test(): + return torch.roll(a, shifts) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""roll(): expected `shifts` [2, 2] (size=2) to have exactly 1 element when `dims` is empty.""" + ) + + def test_roll_raises_error_on_mismatched_dims_and_shifts(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + shifts = [2, 2] + dims = [0] + + def test(): + return torch.roll(a, shifts, dims) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2).""" + ) + if __name__ == "__main__": unittest.main() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c042d703aa36..106407d70be5 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -17,8 +17,8 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" -#include "status.h" #include "torch/csrc/lazy/core/helpers.h" #include "torch/csrc/lazy/core/shape_inference.h" #include "torch/csrc/lazy/core/tensor_util.h" @@ -3317,9 +3317,13 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::roll( - xla_self, XlaHelpers::I64List(shifts), XlaHelpers::I64List(dims))); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW( + absl_nonnull XLATensorPtr output, + tensor_methods::roll(xla_self, XlaHelpers::I64List(shifts), + XlaHelpers::I64List(dims))); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::rrelu_with_noise( diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 21f4db597133..1fc9d66a8295 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -13,6 +13,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "torch_xla/csrc/LazyIr.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -506,6 +507,37 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1, return absl::OkStatus(); } +absl::Status CheckRollShiftsRequired(absl::Span shifts) { + if (shifts.empty()) { + return absl::InvalidArgumentError( + "roll(): expected `shifts` to have at least 1 element."); + } + return absl::OkStatus(); +} + +absl::Status CheckRollDimsAndShiftsAreCompatible( + absl::Span dims, absl::Span shifts) { + if (dims.empty()) { + // If `dims` is empty, then return an error status if `shifts` is not + // of size one. Otherwise, `dims` and `shifts` are valid. + if (shifts.size() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "roll(): expected `shifts` [", absl::StrJoin(shifts, /* sep= */ ", "), + "] (size=", shifts.size(), + ") to have exactly 1 element when `dims` is empty.")); + } + } else if (dims.size() != shifts.size()) { + // If `dims` is not empty, then return an error status if its size + // does not match with `shifts` size. + return absl::InvalidArgumentError(absl::StrCat( + "roll(): expected `dims` [", absl::StrJoin(dims, /* sep= */ ", "), + "] (size=", dims.size(), ") to match the size of `shifts` [", + absl::StrJoin(shifts, /* sep= */ ", "), "] (size=", shifts.size(), + ").")); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -3052,17 +3084,15 @@ void resize_(XLATensorPtr& input, std::vector size) { } } -XLATensorPtr roll(const XLATensorPtr& input, absl::Span shifts, - absl::Span dims) { - XLA_CHECK_GT(shifts.size(), 0) << "`shifts` required"; - if (dims.size() != 0) { - XLA_CHECK_EQ(shifts.size(), dims.size()) - << "shifts and dimensions must align. shifts: " << shifts.size() - << ", dims:" << dims.size(); - } - auto canonical_dims = torch::lazy::GetCanonicalDimensionIndices( - torch::lazy::ToVector(dims), - input->shape().get().dimensions_size()); +absl::StatusOr roll( + const absl_nonnull XLATensorPtr& input, absl::Span shifts, + absl::Span dims) { + XLA_RETURN_IF_ERROR(CheckRollShiftsRequired(shifts)); + XLA_RETURN_IF_ERROR(CheckRollDimsAndShiftsAreCompatible(dims, shifts)); + const std::vector canonical_dims = + torch::lazy::GetCanonicalDimensionIndices( + torch::lazy::ToVector(dims), + input->shape().get().dimensions().size()); return input->CreateFrom(torch_xla::MakeNode( input->GetIrValue(), torch::lazy::ToVector(shifts), canonical_dims)); diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index b25b423d49c2..3b8fd0518ab9 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -821,8 +821,9 @@ XLATensorPtr replication_pad3d_backward(const XLATensorPtr& grad_output, void resize_(XLATensorPtr& input, std::vector size); -XLATensorPtr roll(const XLATensorPtr& input, absl::Span shifts, - absl::Span dims); +absl::StatusOr roll( + const absl_nonnull XLATensorPtr& input, absl::Span shifts, + absl::Span dims); XLATensorPtr rrelu_with_noise(const XLATensorPtr& input, XLATensorPtr& noise, const at::Scalar& lower, const at::Scalar& upper, From 6d755eec17b8b074a85a52dde8204258f6eeac33 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 16 Sep 2025 17:02:31 -0300 Subject: [PATCH 109/133] `stack`: improve error handling and error messages. (#9629) This PR refactors the `stack` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - `tensor_methods::stack` returns `StatusOr` - Improve error messages and error handling - Create `CheckStackAtLeastOneTensor` function --- torch_xla/csrc/aten_xla_type.cpp | 17 +++++++++++------ torch_xla/csrc/ops/index_ops.cpp | 10 +++++----- torch_xla/csrc/tensor_methods.cpp | 25 +++++++++++++++++++------ torch_xla/csrc/tensor_methods.h | 3 ++- torch_xla/csrc/tensor_ops.cpp | 4 +++- 5 files changed, 40 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 106407d70be5..c2514b410976 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -3696,12 +3697,16 @@ at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self, at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); at::ScalarType result_type = at::native::result_type(tensors); - std::vector c_tensors(tensors.size()); - std::transform(tensors.begin(), tensors.end(), c_tensors.begin(), - [=](const at::Tensor& t) { return t.to(result_type); }); - XLA_ASSIGN_OR_THROW(std::vector xla_c_tensors, - bridge::GetXlaTensors(c_tensors)); - return bridge::AtenFromXlaTensor(tensor_methods::stack(xla_c_tensors, dim)); + std::vector xla_tensors; + std::transform(tensors.begin(), tensors.end(), + std::back_inserter(xla_tensors), [=](const at::Tensor& t) { + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_t, + bridge::GetXlaTensor(t.to(result_type))); + return xla_t; + }); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::stack(xla_tensors, dim)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::std(const at::Tensor& self, bool unbiased) { diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 9dddc9424e03..2b6038bce0a2 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -339,8 +339,8 @@ XLATensorPtr IndexByTensors(const XLATensorPtr& base, canonical_indices.front()->shape().get().dimensions_size(); // Stack the indices to allow the whole multi-indexing to be dispatched with a // single gather. - XLATensorPtr indices_nd = - tensor_methods::stack(canonical_indices, indices_rank); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr indices_nd, + tensor_methods::stack(canonical_indices, indices_rank)); return XLATensor::Create( torch_xla::MakeNode(base->GetIrValue(), indices_nd->GetIrValue(), start_dim), @@ -356,11 +356,11 @@ torch::lazy::Value IndexPutByTensors( } auto canonical_indices = WrapIndicesOnce(base, indices, start_dim); int64_t indices_rank = - canonical_indices.front()->shape().get().dimensions_size(); + canonical_indices.front()->shape().get().dimensions().size(); // Stack the indices to allow the whole multi-indexing to be dispatched with a // single scatter. - XLATensorPtr indices_nd = - tensor_methods::stack(canonical_indices, indices_rank); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr indices_nd, + tensor_methods::stack(canonical_indices, indices_rank)); return torch_xla::MakeNode( torch_xla::MakeNode(base->GetIrValue(), indices_nd->GetIrValue(), start_dim, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 1fc9d66a8295..11545f72ff43 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -538,6 +538,15 @@ absl::Status CheckRollDimsAndShiftsAreCompatible( return absl::OkStatus(); } +absl::Status CheckStackAtLeastOneTensor( + absl::Span tensors) { + if (tensors.size() == 0) { + return XLA_ERROR_WITH_LOCATION( + absl::InvalidArgumentError("stack(): expected at least one tensor.")); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -3422,14 +3431,18 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector dims) { return view(input, output_dimensions); } -XLATensorPtr stack(absl::Span tensors, int64_t dim) { - XLA_CHECK_GT(tensors.size(), 0); +absl::StatusOr stack( + absl::Span tensors, int64_t dim) { + XLA_RETURN_IF_ERROR(CheckStackAtLeastOneTensor(tensors)); + std::vector values; - for (auto& tensor : tensors) { - values.push_back(tensor->GetIrValue()); - } + std::transform( + tensors.begin(), tensors.end(), std::back_inserter(values), + [](const absl_nonnull XLATensorPtr t) { return t->GetIrValue(); }); + int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex( - dim, tensors.front()->shape().get().dimensions_size() + 1); + dim, tensors.front()->shape().get().dimensions().size() + 1); + return tensors[0]->CreateFrom( torch_xla::MakeNode(values, canonical_dim)); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 3b8fd0518ab9..71a522d4d06a 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -926,7 +926,8 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector dims); void squeeze_(XLATensorPtr& input); void squeeze_(XLATensorPtr& input, int64_t dim); -XLATensorPtr stack(absl::Span tensors, int64_t dim); +absl::StatusOr stack( + absl::Span tensors, int64_t dim); XLATensorPtr std(const XLATensorPtr& input, std::vector dimensions, bool keep_reduced_dimensions, double correction); diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index 84d788d624fb..ad267ac2d693 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -62,7 +62,9 @@ XLATensorPtr Cross(const XLATensorPtr& input, const XLATensorPtr& other, XLATensorPtr s3 = tensor_methods::sub(tensor_methods::mul(u1, v2), tensor_methods::mul(u2, v1), one); // Stack the terms into one result tensor. - return tensor_methods::stack({s1, s2, s3}, canonical_dim); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::stack({s1, s2, s3}, canonical_dim)); + return output; } XLATensorPtr MakeMatrixWithDiagonal(const XLATensorPtr& input, From 0c0ae2d22d5eec52b9c69da831185a92bbb37d1c Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 17 Sep 2025 16:19:49 -0300 Subject: [PATCH 110/133] `expand`: improve error handling and error messages. (#9645) This PR refactors the `expand` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::expand` return `StatusOr` - Improve error messages and error handling - Create new `CheckExpandValidRank` for checking the input and the given sizes' rank - Modified `GetExpandDimensions`, calling the check function above, and also checking whether input and sizes corresponding dimensions are valid --- test/run_tests.sh | 1 + ...rror_message_functionalization_disabled.py | 43 +++++++++++ torch_xla/csrc/aten_xla_type.cpp | 32 ++++---- torch_xla/csrc/init_python_bindings.cpp | 16 ++-- torch_xla/csrc/tensor_methods.cpp | 73 +++++++++++++------ torch_xla/csrc/tensor_methods.h | 11 +-- torch_xla/csrc/tensor_ops.cpp | 4 +- 7 files changed, 132 insertions(+), 48 deletions(-) create mode 100644 test/test_ops_error_message_functionalization_disabled.py diff --git a/test/run_tests.sh b/test/run_tests.sh index 3982b531a49c..9cb711f49215 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -151,6 +151,7 @@ function run_xla_op_tests1 { run_eager_debug "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$_TEST_DIR/test_ops_error_message.py" + run_test "$_TEST_DIR/test_ops_error_message_functionalization_disabled.py" run_test "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug_level2 "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY diff --git a/test/test_ops_error_message_functionalization_disabled.py b/test/test_ops_error_message_functionalization_disabled.py new file mode 100644 index 000000000000..f14ee4f76518 --- /dev/null +++ b/test/test_ops_error_message_functionalization_disabled.py @@ -0,0 +1,43 @@ +import os + +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1" + +import expecttest +import torch +import torch_xla +import unittest + + +class TestOpsErrorMessageFunctionalizationDisabled(expecttest.TestCase): + + def test_expand_raises_error_on_higher_rank_tensor(self): + device = torch_xla.device() + a = torch.rand(1, 1, 2, 3, device=device) + sizes = [-1, 3] + + def test(): + return a.expand(sizes) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""expand(): expected the `input` tensor f32[1,1,2,3] (rank: 4) to have a rank smaller or equal to the given `sizes` [-1, 3] (rank: 2).""" + ) + + def test_expand_raises_error_on_size_mismatch(self): + device = torch_xla.device() + a = torch.rand(1, 1, 2, 3, device=device) + sizes = [1, 1, 1, 3] + + def test(): + return a.expand(sizes) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""expand(): expected dimension 2 of the given `sizes` [1, 1, 1, 3] (1) to be -1, or equal to the size of the `input` tensor f32[1,1,2,3] at dimension 2 (2).""" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c2514b410976..91101778d7aa 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1786,19 +1786,21 @@ at::Tensor XLANativeFunctions::empty_strided_symint( } at::Tensor XLANativeFunctions::expand_copy_symint(const at::Tensor& self, - at::SymIntArrayRef sym_size, + at::SymIntArrayRef sym_sizes, bool implicit) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - std::optional size = c10::asIntArrayRefSlowOpt(sym_size); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - if (size.has_value()) { - return bridge::AtenFromXlaTensor(tensor_methods::expand( - xla_self, torch::lazy::ToVector(*size))); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + std::optional sizes = c10::asIntArrayRefSlowOpt(sym_sizes); + if (sizes.has_value()) { + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::expand(xla_self, *sizes)); + return bridge::AtenFromXlaTensor(std::move(output)); } else { // at least one of the dimension is symbolic, use the sym_int version of the // node return bridge::AtenFromXlaTensor( - tensor_methods::expand_symint(xla_self, sym_size)); + tensor_methods::expand_symint(xla_self, sym_sizes)); } } @@ -4563,19 +4565,21 @@ at::Tensor XLANativeFunctions::diagonal(const at::Tensor& self, int64_t offset, } at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self, - at::SymIntArrayRef sym_size, + at::SymIntArrayRef sym_sizes, bool implicit) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - std::optional size = c10::asIntArrayRefSlowOpt(sym_size); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - if (size.has_value()) { - return bridge::AtenFromXlaTensor(tensor_methods::expand( - xla_self, torch::lazy::ToVector(*size))); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + std::optional sizes = c10::asIntArrayRefSlowOpt(sym_sizes); + if (sizes.has_value()) { + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::expand(xla_self, *sizes)); + return bridge::AtenFromXlaTensor(std::move(output)); } else { // at least one of the dimension is symbolic, use the sym_int version of the // node return bridge::AtenFromXlaTensor( - tensor_methods::expand_symint(xla_self, sym_size)); + tensor_methods::expand_symint(xla_self, sym_sizes)); } } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1d205dd86a7e..a52ecc8124e7 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -24,6 +24,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" @@ -451,15 +452,18 @@ at::Tensor AllReduce(const std::string& reduce_type, const at::Tensor& input, } at::Tensor DynamicExpand(const at::Tensor& input, - const std::vector& size, + const std::vector& sizes, const at::Tensor& src_tensor, int src_dim, int target_dim) { - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input, bridge::GetXlaTensor(input)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_src_tensor, + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_input, + bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_src_tensor, bridge::GetXlaTensor(src_tensor)); - XLATensorPtr result = tensor_methods::dynamic_expand( - xla_input, size, xla_src_tensor, src_dim, target_dim); - return bridge::AtenFromXlaTensor(std::move(result)); + XLA_ASSIGN_OR_THROW( + absl_nonnull XLATensorPtr output, + tensor_methods::dynamic_expand(xla_input, sizes, xla_src_tensor, src_dim, + target_dim)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor DynamicView(const at::Tensor& input, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 11545f72ff43..0ec204b0fff1 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -10,11 +10,13 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/span.h" #include "torch_xla/csrc/LazyIr.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/data_ops.h" @@ -234,16 +236,45 @@ void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1, "batch2", 2); } -std::vector GetExpandDimensions(const xla::Shape& shape, - std::vector dimensions) { - XLA_CHECK_GE(dimensions.size(), shape.dimensions_size()) << shape; - int64_t base = dimensions.size() - shape.dimensions_size(); - for (size_t i = 0; i < shape.dimensions_size(); ++i) { - if (dimensions[base + i] == -1) { - dimensions[base + i] = shape.dimensions(i); +absl::Status CheckExpandValidRank(const XLATensorPtr& input, + const absl::Span sizes) { + xla::Shape shape = input->shape(); + int64_t rank = shape.dimensions().size(); + if (rank > sizes.size()) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "expand(): expected the `input` tensor ", shape.ToString(), " (rank: ", + rank, ") to have a rank smaller or equal to the given `sizes` [", + absl::StrJoin(sizes, /* sep= */ ", "), "] (rank: ", sizes.size(), + ")."))); + } + return absl::OkStatus(); +} + +absl::StatusOr> GetExpandDimensions( + const XLATensorPtr& input, const absl::Span sizes) { + XLA_RETURN_IF_ERROR(CheckExpandValidRank(input, sizes)); + + xla::Shape shape = input->shape(); + const int64_t rank = shape.dimensions().size(); + const int64_t base = sizes.size() - rank; + + std::vector expanded_dimensions(sizes.begin(), sizes.end()); + for (size_t i = 0; i < shape.dimensions().size(); ++i) { + const int64_t dim = base + i; + const int64_t size = sizes[dim]; + if (size == -1) { + expanded_dimensions[dim] = shape.dimensions(i); + } else if (shape.dimensions(i) != 1 && size != shape.dimensions(i)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + "expand(): expected dimension ", dim, " of the given `sizes` [", + absl::StrJoin(sizes, /* sep= */ ", "), "] (", size, + ") to be -1, or equal to the size of the `input` tensor ", + shape.ToString(), " at dimension ", i, " (", shape.dimensions(i), + ")."))); } } - return dimensions; + + return expanded_dimensions; } // Resizes and / or checks whether a list is of the given size. The list is only @@ -1791,11 +1822,12 @@ XLATensorPtr exp(const XLATensorPtr& input) { return input->CreateFrom(Exp(input->GetIrValue())); } -XLATensorPtr expand(const XLATensorPtr& input, std::vector size) { - auto input_shape = input->shape(); - auto output = input->CreateFrom(torch_xla::MakeNode( - input->GetIrValue(), - GetExpandDimensions(input_shape.get(), std::move(size)))); +absl::StatusOr expand( + const XLATensorPtr& input, const absl::Span sizes) { + XLA_ASSIGN_OR_RETURN(std::vector expanded_dimensions, + GetExpandDimensions(input, sizes)); + auto output = input->CreateFrom( + torch_xla::MakeNode(input->GetIrValue(), expanded_dimensions)); output->SetStorage(input->Storage()); return output; } @@ -2927,15 +2959,14 @@ XLATensorPtr cast_int4(const XLATensorPtr& weight, // Dynamic Reshape ops here. ////////////////////////////////////////////////////////////////////////////// -XLATensorPtr dynamic_expand(const XLATensorPtr& input, - const std::vector& size, - const XLATensorPtr& src_tensor, int src_dim, - int target_dim) { - std::vector expanded_size = - GetExpandDimensions(input->shape().get(), size); +absl::StatusOr dynamic_expand( + const XLATensorPtr& input, const absl::Span sizes, + const XLATensorPtr& src_tensor, int src_dim, int target_dim) { + XLA_ASSIGN_OR_RETURN(std::vector expanded_dimensions, + GetExpandDimensions(input, sizes)); torch::lazy::NodePtr node = torch_xla::MakeNode( - input->GetIrValue(), expanded_size, src_tensor->GetIrValue(), src_dim, - target_dim); + input->GetIrValue(), expanded_dimensions, src_tensor->GetIrValue(), + src_dim, target_dim); return input->CreateFrom(torch::lazy::Value(node)); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 71a522d4d06a..e91e92ad96e5 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -2,6 +2,7 @@ #define XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_ #include "absl/base/nullability.h" +#include "absl/types/span.h" #include "torch_xla/csrc/cross_replica_reduces.h" #include "torch_xla/csrc/ops/custom_sharding.h" #include "torch_xla/csrc/runtime/computation_client.h" @@ -158,10 +159,9 @@ XLATensorPtr cast_int4(const XLATensorPtr& weight, // Dynamic Reshape ops here. ////////////////////////////////////////////////////////////////////////////// -XLATensorPtr dynamic_expand(const XLATensorPtr& input, - const std::vector& size, - const XLATensorPtr& src_tensor, int src_dim, - int target_dim); +absl::StatusOr dynamic_expand( + const XLATensorPtr& input, const absl::Span sizes, + const XLATensorPtr& src_tensor, int src_dim, int target_dim); XLATensorPtr dynamic_view(const XLATensorPtr& input, const std::vector& size, @@ -427,7 +427,8 @@ XLATensorPtr eq(const XLATensorPtr& input, const XLATensorPtr& other); XLATensorPtr exp(const XLATensorPtr& input); -XLATensorPtr expand(const XLATensorPtr& input, std::vector size); +absl::StatusOr expand( + const XLATensorPtr& input, const absl::Span sizes); XLATensorPtr expand_symint(const XLATensorPtr& input, c10::SymIntArrayRef sym_size); diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index ad267ac2d693..ef74063540c2 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -240,9 +240,9 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output, // padding_idx. XLATensorPtr skip_padding = tensor_methods::unsqueeze( tensor_methods::ne(indices_rank1, padding_idx), 1); - skip_padding = tensor_methods::expand( + XLA_ASSIGN_OR_THROW( skip_padding, - torch::lazy::ToVector(grad->shape().get().dimensions())); + tensor_methods::expand(skip_padding, grad->shape().get().dimensions())); XLATensorPtr zero_grad = tensor_methods::full_like(grad, 0, grad->GetDevice(), grad->dtype()); return tensor_methods::index_put( From 0fc62aa26a30ed7ca419d285f285cb5ba02c4394 Mon Sep 17 00:00:00 2001 From: qihqi Date: Wed, 24 Sep 2025 14:15:25 -0700 Subject: [PATCH 111/133] update gcc (#9650) --- .github/workflows/_build_torch_xla.yml | 153 ++++++----- .github/workflows/_docs.yml | 134 +++++----- .github/workflows/_test.yml | 342 ++++++++++++------------- .github/workflows/_torchprime_ci.yml | 222 ++++++++-------- .github/workflows/setup/action.yml | 98 +++---- WORKSPACE | 8 +- bazel/fmt.BUILD | 9 + bazel/torch.BUILD | 3 + scripts/build_torch_wheels.sh | 6 +- 9 files changed, 496 insertions(+), 479 deletions(-) create mode 100644 bazel/fmt.BUILD diff --git a/.github/workflows/_build_torch_xla.yml b/.github/workflows/_build_torch_xla.yml index 34b0df460c83..aaff9c2bf9dc 100644 --- a/.github/workflows/_build_torch_xla.yml +++ b/.github/workflows/_build_torch_xla.yml @@ -1,80 +1,79 @@ name: build-torch-xla on: - workflow_call: - inputs: - dev-image: - required: true - type: string - description: Base image for builds - torch-commit: - required: true - type: string - description: torch-commit - runner: - required: false - type: string - description: Runner type for the test - default: linux.12xlarge - timeout-minutes: - required: false - type: number - description: Timeout in minutes for the build job - default: 45 # Takes ~20m as of 2025/5/30. - has_code_changes: - required: false - type: string - description: Whether to run full workflow or not - default: 'true' - secrets: - gcloud-service-key: - required: true - description: Secret to access Bazel build cache + workflow_call: + inputs: + dev-image: + required: true + type: string + description: Base image for builds + torch-commit: + required: true + type: string + description: torch-commit + runner: + required: false + type: string + description: Runner type for the test + default: linux.12xlarge + timeout-minutes: + required: false + type: number + description: Timeout in minutes for the build job + default: 45 # Takes ~20m as of 2025/5/30. + has_code_changes: + required: false + type: string + description: Whether to run full workflow or not + default: "true" + secrets: + gcloud-service-key: + required: true + description: Secret to access Bazel build cache jobs: - build: - runs-on: ${{ inputs.runner }} - timeout-minutes: ${{ inputs.timeout-minutes }} - container: - image: ${{ inputs.dev-image }} - env: - GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} - BAZEL_REMOTE_CACHE: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} - BAZEL_JOBS: '' # Let bazel decide the parallelism based on the number of CPUs. - BUILD_CPP_TESTS: 1 - steps: - # Need to check out local composite actions before using them - # https://github.com/orgs/community/discussions/11771 - - name: Checkout actions - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - sparse-checkout: | - .github/workflows/setup - path: .actions - - name: Setup - if: inputs.has_code_changes == 'true' - uses: ./.actions/.github/workflows/setup - with: - torch-commit: ${{ inputs.torch-commit }} - - name: Build - if: inputs.has_code_changes == 'true' - shell: bash - run: | - cd pytorch/xla/infra/ansible - ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0 build_cpp_tests=1 git_versioned_xla_build=1 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps - - name: Upload wheel - if: inputs.has_code_changes == 'true' - uses: actions/upload-artifact@v4 - with: - name: torch-xla-wheels - path: /dist/*.whl - - name: Upload CPP test binaries - if: inputs.has_code_changes == 'true' - uses: actions/upload-artifact@v4 - with: - name: cpp-test-bin - path: /tmp/test/bin - - name: Report no code changes - if: inputs.has_code_changes == 'false' - run: | - echo "No code changes were detected that require running the full test suite." - + build: + runs-on: ${{ inputs.runner }} + timeout-minutes: ${{ inputs.timeout-minutes }} + container: + image: ${{ inputs.dev-image }} + env: + GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} + BAZEL_REMOTE_CACHE: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} + BAZEL_JOBS: "" # Let bazel decide the parallelism based on the number of CPUs. + BUILD_CPP_TESTS: 1 + steps: + # Need to check out local composite actions before using them + # https://github.com/orgs/community/discussions/11771 + - name: Checkout actions + if: inputs.has_code_changes == 'true' + uses: actions/checkout@v4 + with: + sparse-checkout: | + .github/workflows/setup + path: .actions + - name: Setup + if: inputs.has_code_changes == 'true' + uses: ./.actions/.github/workflows/setup + with: + torch-commit: ${{ inputs.torch-commit }} + - name: Build + if: inputs.has_code_changes == 'true' + shell: bash + run: | + cd pytorch/xla/infra/ansible + ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=tpu src_root=${GITHUB_WORKSPACE} bundle_libtpu=0 build_cpp_tests=1 git_versioned_xla_build=1 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps + - name: Upload wheel + if: inputs.has_code_changes == 'true' + uses: actions/upload-artifact@v4 + with: + name: torch-xla-wheels + path: /dist/*.whl + - name: Upload CPP test binaries + if: inputs.has_code_changes == 'true' + uses: actions/upload-artifact@v4 + with: + name: cpp-test-bin + path: /tmp/test/bin + - name: Report no code changes + if: inputs.has_code_changes == 'false' + run: | + echo "No code changes were detected that require running the full test suite." diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index ee2650834e82..f35e0adbca65 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -1,70 +1,70 @@ name: xla-docs-build on: - workflow_call: - inputs: - dev-image: - required: true - type: string - description: Base image for builds - runner: - required: false - type: string - description: Runner type for the test - default: linux.4xlarge - secrets: - torchxla-bot-token: - required: true + workflow_call: + inputs: + dev-image: + required: true + type: string + description: Base image for builds + runner: + required: false + type: string + description: Runner type for the test + default: linux.4xlarge + secrets: + torchxla-bot-token: + required: true jobs: - build-docs: - runs-on: ubuntu-24.04 - timeout-minutes: 45 - container: - image: ${{ inputs.dev-image }} - env: - BRANCH_NAME: ${{ github.ref_name }} - steps: - - name: Fetch wheels - uses: actions/download-artifact@v4 - with: - name: torch-xla-wheels - path: /tmp/wheels/ - - name: Install wheels - shell: bash - run: | - pip install /tmp/wheels/*.whl - - name: Checkout PyTorch/XLA Repo - uses: actions/checkout@v4 - with: - path: pytorch/xla - - name: Build docs - shell: bash - run: | - cd pytorch/xla/docs - pip install -r requirements.txt - sphinx-build -b html source build - - name: Checkout GitHub Pages - uses: actions/checkout@v4 - with: - path: gh-pages - ref: gh-pages - token: ${{ github.event_name == 'push' && secrets.torchxla-bot-token || github.token }} - - name: Merge changes - shell: bash - run: | - subdir=${{ env.BRANCH_NAME == 'master' && 'master' || format('{0}/{1}', 'release', env.BRANCH_NAME) }} - mkdir -p gh-pages/$subdir - cp -fR pytorch/xla/docs/build/* gh-pages/$subdir - - name: Upload preview as artifact - uses: actions/upload-artifact@v4 - with: - name: github-pages - path: pytorch/xla/docs/build/ - - name: Deploy - shell: bash - run: | - cd gh-pages - git config user.email "pytorchxla@gmail.com" - git config user.name "torchxlabot2" - git add . -v - git diff --cached --exit-code || git commit -m "Update doc from commit ${{ github.sha }}" - git push origin gh-pages + build-docs: + runs-on: ubuntu-24.04 + timeout-minutes: 45 + container: + image: ${{ inputs.dev-image }} + env: + BRANCH_NAME: ${{ github.ref_name }} + steps: + - name: Fetch wheels + uses: actions/download-artifact@v5 + with: + name: torch-xla-wheels + path: /tmp/wheels/ + - name: Install wheels + shell: bash + run: | + pip install /tmp/wheels/*.whl + - name: Checkout PyTorch/XLA Repo + uses: actions/checkout@v4 + with: + path: pytorch/xla + - name: Build docs + shell: bash + run: | + cd pytorch/xla/docs + pip install -r requirements.txt + sphinx-build -b html source build + - name: Checkout GitHub Pages + uses: actions/checkout@v4 + with: + path: gh-pages + ref: gh-pages + token: ${{ github.event_name == 'push' && secrets.torchxla-bot-token || github.token }} + - name: Merge changes + shell: bash + run: | + subdir=${{ env.BRANCH_NAME == 'master' && 'master' || format('{0}/{1}', 'release', env.BRANCH_NAME) }} + mkdir -p gh-pages/$subdir + cp -fR pytorch/xla/docs/build/* gh-pages/$subdir + - name: Upload preview as artifact + uses: actions/upload-artifact@v4 + with: + name: github-pages + path: pytorch/xla/docs/build/ + - name: Deploy + shell: bash + run: | + cd gh-pages + git config user.email "pytorchxla@gmail.com" + git config user.name "torchxlabot2" + git add . -v + git diff --cached --exit-code || git commit -m "Update doc from commit ${{ github.sha }}" + git push origin gh-pages diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 6c2175117e57..10bc92327dee 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -1,180 +1,180 @@ name: xla-test on: - workflow_call: - inputs: - dev-image: - required: true - type: string - description: Base image for builds - runner: - required: false - type: string - description: Runner type for the test - default: linux.12xlarge - collect-coverage: - required: false - type: boolean - description: Set to true to collect coverage information - default: false - timeout-minutes: - required: false - type: number - default: 180 # Takes ~105m as of 2025/5/30. - description: | - Set the maximum (in minutes) how long the workflow should take to finish + workflow_call: + inputs: + dev-image: + required: true + type: string + description: Base image for builds + runner: + required: false + type: string + description: Runner type for the test + default: linux.12xlarge + collect-coverage: + required: false + type: boolean + description: Set to true to collect coverage information + default: false timeout-minutes: - torch-commit: - required: true - type: string - description: torch-commit - has_code_changes: - required: false - type: string - description: Whether to run full workflow or not - default: 'true' - secrets: - gcloud-service-key: - required: true - description: Secret to access Bazel build cache + required: false + type: number + default: 180 # Takes ~105m as of 2025/5/30. + description: | + Set the maximum (in minutes) how long the workflow should take to finish + timeout-minutes: + torch-commit: + required: true + type: string + description: torch-commit + has_code_changes: + required: false + type: string + description: Whether to run full workflow or not + default: "true" + secrets: + gcloud-service-key: + required: true + description: Secret to access Bazel build cache jobs: - test: - runs-on: ${{ inputs.runner }} - container: - image: ${{ inputs.dev-image }} - options: "--shm-size 16g" - strategy: - fail-fast: false - matrix: - include: - # Use readable strings as they define the workflow titles. - - run_benchmark_tests: 'benchmark_tests' - - run_python_tests: 'python_tests' - run_xla_op_tests1: 'xla_op1' - - run_python_tests: 'python_tests' - run_xla_op_tests2: 'xla_op2' - - run_python_tests: 'python_tests' - run_xla_op_tests3: 'xla_op3' - - run_python_tests: 'python_tests' - run_xla_op_tests4: 'xla_op4' - - run_python_tests: 'python_tests' - run_xla_op_tests5: 'xla_op5' - - run_python_tests: 'python_tests' - run_torch_mp_op_tests: 'torch_mp_op' - - run_cpp_tests: 'cpp_tests' - timeout-minutes: ${{ inputs.timeout-minutes }} - env: - GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} - GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json - USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }} - RUN_BENCHMARK_TESTS: ${{ matrix.run_benchmark_tests }} - RUN_PYTHON_TESTS: ${{ matrix.run_python_tests }} - RUN_XLA_OP_TESTS1: ${{ matrix.run_xla_op_tests1 }} - RUN_XLA_OP_TESTS2: ${{ matrix.run_xla_op_tests2 }} - RUN_XLA_OP_TESTS3: ${{ matrix.run_xla_op_tests3 }} - RUN_XLA_OP_TESTS4: ${{ matrix.run_xla_op_tests4 }} - RUN_XLA_OP_TESTS5: ${{ matrix.run_xla_op_tests5 }} - RUN_TORCH_MP_OP_TESTS: ${{ matrix.run_torch_mp_op_tests }} - RUN_CPP_TESTS: ${{ matrix.run_cpp_tests }} - BAZEL_JOBS: '' # Let bazel decide the parallelism based on the number of CPUs. - BAZEL_REMOTE_CACHE: 1 - steps: - - name: Checkout actions - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - sparse-checkout: | - .github/workflows/setup - path: .actions - - name: Setup - if: inputs.has_code_changes == 'true' - uses: ./.actions/.github/workflows/setup - with: - torch-commit: ${{ inputs.torch-commit }} - wheels-artifact: torch-xla-wheels - - name: Fetch CPP test binaries - if: inputs.has_code_changes == 'true' && matrix.run_cpp_tests - uses: actions/download-artifact@v4 - with: - name: cpp-test-bin - path: /tmp/test/bin - # GitHub Actions doesn't preserve executable permissions - # https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss - - name: Set CPP test permissions - if: inputs.has_code_changes == 'true' && matrix.run_cpp_tests - run: | - chmod +x /tmp/test/bin/* - ls -l /tmp/test/bin - - name: Install test deps - if: inputs.has_code_changes == 'true' - shell: bash - run: | - # TODO: Add these in setup.py - pip install fsspec - pip install rich - pip install flax - - name: Checkout PyTorch Repo - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - repository: pytorch/pytorch - path: pytorch - ref: ${{ inputs.torch-commit }} - - name: Checkout PyTorch/XLA Repo - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - path: pytorch/xla - - name: Extra CI deps - if: inputs.has_code_changes == 'true' - shell: bash - run: | - set -x + test: + runs-on: ${{ inputs.runner }} + container: + image: ${{ inputs.dev-image }} + options: "--shm-size 16g" + strategy: + fail-fast: false + matrix: + include: + # Use readable strings as they define the workflow titles. + - run_benchmark_tests: "benchmark_tests" + - run_python_tests: "python_tests" + run_xla_op_tests1: "xla_op1" + - run_python_tests: "python_tests" + run_xla_op_tests2: "xla_op2" + - run_python_tests: "python_tests" + run_xla_op_tests3: "xla_op3" + - run_python_tests: "python_tests" + run_xla_op_tests4: "xla_op4" + - run_python_tests: "python_tests" + run_xla_op_tests5: "xla_op5" + - run_python_tests: "python_tests" + run_torch_mp_op_tests: "torch_mp_op" + - run_cpp_tests: "cpp_tests" + timeout-minutes: ${{ inputs.timeout-minutes }} + env: + GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }} + GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json + USE_COVERAGE: ${{ inputs.collect-coverage && '1' || '0' }} + RUN_BENCHMARK_TESTS: ${{ matrix.run_benchmark_tests }} + RUN_PYTHON_TESTS: ${{ matrix.run_python_tests }} + RUN_XLA_OP_TESTS1: ${{ matrix.run_xla_op_tests1 }} + RUN_XLA_OP_TESTS2: ${{ matrix.run_xla_op_tests2 }} + RUN_XLA_OP_TESTS3: ${{ matrix.run_xla_op_tests3 }} + RUN_XLA_OP_TESTS4: ${{ matrix.run_xla_op_tests4 }} + RUN_XLA_OP_TESTS5: ${{ matrix.run_xla_op_tests5 }} + RUN_TORCH_MP_OP_TESTS: ${{ matrix.run_torch_mp_op_tests }} + RUN_CPP_TESTS: ${{ matrix.run_cpp_tests }} + BAZEL_JOBS: "" # Let bazel decide the parallelism based on the number of CPUs. + BAZEL_REMOTE_CACHE: 1 + steps: + - name: Checkout actions + if: inputs.has_code_changes == 'true' + uses: actions/checkout@v4 + with: + sparse-checkout: | + .github/workflows/setup + path: .actions + - name: Setup + if: inputs.has_code_changes == 'true' + uses: ./.actions/.github/workflows/setup + with: + torch-commit: ${{ inputs.torch-commit }} + wheels-artifact: torch-xla-wheels + - name: Fetch CPP test binaries + if: inputs.has_code_changes == 'true' && matrix.run_cpp_tests + uses: actions/download-artifact@v5 + with: + name: cpp-test-bin + path: /tmp/test/bin + # GitHub Actions doesn't preserve executable permissions + # https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss + - name: Set CPP test permissions + if: inputs.has_code_changes == 'true' && matrix.run_cpp_tests + run: | + chmod +x /tmp/test/bin/* + ls -l /tmp/test/bin + - name: Install test deps + if: inputs.has_code_changes == 'true' + shell: bash + run: | + # TODO: Add these in setup.py + pip install fsspec + pip install rich + pip install flax + - name: Checkout PyTorch Repo + if: inputs.has_code_changes == 'true' + uses: actions/checkout@v4 + with: + repository: pytorch/pytorch + path: pytorch + ref: ${{ inputs.torch-commit }} + - name: Checkout PyTorch/XLA Repo + if: inputs.has_code_changes == 'true' + uses: actions/checkout@v4 + with: + path: pytorch/xla + - name: Extra CI deps + if: inputs.has_code_changes == 'true' + shell: bash + run: | + set -x - pip install expecttest unittest-xml-reporting - pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip install expecttest unittest-xml-reporting + pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html - if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then - pip install -r pytorch/xla/benchmarks/requirements.txt - fi - - name: Test - if: inputs.has_code_changes == 'true' - shell: bash - run: pytorch/xla/.github/scripts/run_tests.sh pytorch/ pytorch/xla/ $USE_COVERAGE - - name: Upload coverage results - if: inputs.has_code_changes == 'true' && inputs.collect-coverage - shell: bash - env: - CIRCLE_WORKFLOW_ID: ${{ github.run_id }} - CIRCLE_BUILD_NUM: ${{ github.run_number }} - BENCHMARK_TEST_NAME: ${{ env.RUN_BENCHMARK_TESTS }} - PYTHON_TEST_NAME: ${{ env.RUN_PYTHON_TESTS }}${{ env.RUN_XLA_OP_TESTS1 }}${{ env.RUN_XLA_OP_TESTS2 }}${{ env.RUN_XLA_OP_TESTS3 }}${{ env.RUN_XLA_OP_TESTS4 }}${{ env.RUN_XLA_OP_TESTS5 }}${{ env.RUN_TORCH_MP_OP_TESTS }} - CPP_TEST_NAME: ${{ env.RUN_CPP_TESTS }} - run: | - # TODO(yeounoh) collect coverage report as needed. - if [ -n "${BENCHMARK_TEST_NAME}" ]; then - exit 0 - fi - docker cp "${pid}":/home/jenkins/htmlcov "${GITHUB_WORKSPACE}" - if [ -n "${PYTHON_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out - fi + if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then + pip install -r pytorch/xla/benchmarks/requirements.txt + fi + - name: Test + if: inputs.has_code_changes == 'true' + shell: bash + run: pytorch/xla/.github/scripts/run_tests.sh pytorch/ pytorch/xla/ $USE_COVERAGE + - name: Upload coverage results + if: inputs.has_code_changes == 'true' && inputs.collect-coverage + shell: bash + env: + CIRCLE_WORKFLOW_ID: ${{ github.run_id }} + CIRCLE_BUILD_NUM: ${{ github.run_number }} + BENCHMARK_TEST_NAME: ${{ env.RUN_BENCHMARK_TESTS }} + PYTHON_TEST_NAME: ${{ env.RUN_PYTHON_TESTS }}${{ env.RUN_XLA_OP_TESTS1 }}${{ env.RUN_XLA_OP_TESTS2 }}${{ env.RUN_XLA_OP_TESTS3 }}${{ env.RUN_XLA_OP_TESTS4 }}${{ env.RUN_XLA_OP_TESTS5 }}${{ env.RUN_TORCH_MP_OP_TESTS }} + CPP_TEST_NAME: ${{ env.RUN_CPP_TESTS }} + run: | + # TODO(yeounoh) collect coverage report as needed. + if [ -n "${BENCHMARK_TEST_NAME}" ]; then + exit 0 + fi + docker cp "${pid}":/home/jenkins/htmlcov "${GITHUB_WORKSPACE}" + if [ -n "${PYTHON_TEST_NAME}" ]; then + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_python_coverage_${PYTHON_TEST_NAME}.out + fi - if [ -n "${CPP_TEST_NAME}" ]; then - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out - gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out - fi + if [ -n "${CPP_TEST_NAME}" ]; then + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out + gsutil cp ${GITHUB_WORKSPACE}/htmlcov/cpp_lcov.info gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/cpu_cpp_coverage_${CPP_TEST_NAME}.out + fi - if [ "${CPP_TEST_NAME}" == "cpp_tests" ]; then - ABS_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "commit_id": '\"${GITHUB_SHA}\"', "ref": "HEAD", "source": "https://github.com/pytorch/xla", "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' - echo $ABS_METADATA > abs_metadata.json - gsutil cp abs_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json + if [ "${CPP_TEST_NAME}" == "cpp_tests" ]; then + ABS_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "commit_id": '\"${GITHUB_SHA}\"', "ref": "HEAD", "source": "https://github.com/pytorch/xla", "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' + echo $ABS_METADATA > abs_metadata.json + gsutil cp abs_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/absolute/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json - INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '${CIRCLE_BUILD_NUM}', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' - echo $INC_METADATA > inc_metadata.json - gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json - fi - - name: Report no code changes - if: inputs.has_code_changes == 'false' - run: | - echo "No code changes were detected that require running the full test suite." + INC_METADATA='{"host": "github", "project": "pytorchxla", "trace_type": "LCOV", "patchset_num": 1, "change_id": '${CIRCLE_BUILD_NUM}', "owner": "cloud-tpu-pt-dev", "bug_component": "587012"}' + echo $INC_METADATA > inc_metadata.json + gsutil cp inc_metadata.json gs://ng3-metrics/ng3-pytorchxla-coverage/incremental/pytorchxla/${CIRCLE_WORKFLOW_ID}/metadata.json + fi + - name: Report no code changes + if: inputs.has_code_changes == 'false' + run: | + echo "No code changes were detected that require running the full test suite." diff --git a/.github/workflows/_torchprime_ci.yml b/.github/workflows/_torchprime_ci.yml index 55ab65193df0..93172a7492ac 100644 --- a/.github/workflows/_torchprime_ci.yml +++ b/.github/workflows/_torchprime_ci.yml @@ -1,115 +1,115 @@ name: torchprime E2E tests description: | - This workflow builds a docker image with the PyTorch/XLA wheels and then - triggers a torchprime (https://github.com/AI-Hypercomputer/torchprime) - E2E test using that docker image. It is intended to catch performance - regressions and API breaking changes in PyTorch/XLA pull requests. + This workflow builds a docker image with the PyTorch/XLA wheels and then + triggers a torchprime (https://github.com/AI-Hypercomputer/torchprime) + E2E test using that docker image. It is intended to catch performance + regressions and API breaking changes in PyTorch/XLA pull requests. on: - workflow_call: - inputs: - timeout-minutes: - required: false - type: number - description: Timeout in minutes for the job run - default: 80 - has_code_changes: - required: false - type: string - description: Whether to run full workflow or not - default: 'true' - secrets: - # This is a token for the `torchxlabot2` user, which has access to the torchprime repo. - # It is used to trigger the torchprime E2E test workflow. - # The token should be managed in the "Settings > Secrets and variables > Actions" - # section of the repo. - TORCH_XLA_BOT_TOKEN: - required: true - GCLOUD_SERVICE_KEY: - required: true + workflow_call: + inputs: + timeout-minutes: + required: false + type: number + description: Timeout in minutes for the job run + default: 80 + has_code_changes: + required: false + type: string + description: Whether to run full workflow or not + default: "true" + secrets: + # This is a token for the `torchxlabot2` user, which has access to the torchprime repo. + # It is used to trigger the torchprime E2E test workflow. + # The token should be managed in the "Settings > Secrets and variables > Actions" + # section of the repo. + TORCH_XLA_BOT_TOKEN: + required: true + GCLOUD_SERVICE_KEY: + required: true jobs: - torchprime-e2e-test: - name: Run torchprime E2E tests - timeout-minutes: ${{ inputs.timeout-minutes }} - runs-on: ubuntu-22.04 - steps: - - name: Use Docker in rootless mode - if: inputs.has_code_changes == 'true' - uses: ScribeMD/rootless-docker@0.2.2 - - name: Add user to docker group - if: inputs.has_code_changes == 'true' - run: | - sudo usermod -aG docker $USER - newgrp docker - shell: bash - # Googlers: if this fails, follow go/ptxla-sa-key to debug. - - uses: google-github-actions/auth@v2 - if: inputs.has_code_changes == 'true' - with: - credentials_json: '${{ secrets.GCLOUD_SERVICE_KEY }}' - - uses: google-github-actions/setup-gcloud@v2 - if: inputs.has_code_changes == 'true' - with: - version: '>= 363.0.0' - install_components: 'beta,gke-gcloud-auth-plugin' - - name: Verify GCP setup - if: inputs.has_code_changes == 'true' - run: gcloud info - shell: bash - - name: Authenticate Docker - if: inputs.has_code_changes == 'true' - run: gcloud auth configure-docker --quiet - shell: bash - - name: Activate SA credentials - if: inputs.has_code_changes == 'true' - run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS - shell: bash - - name: Checkout infra - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - sparse-checkout: | - infra - fetch-depth: 1 - path: pytorch-xla - # Build a docker image for torchprime E2E test - # First download the torch-xla-wheels - - name: Fetch wheels - if: inputs.has_code_changes == 'true' - uses: actions/download-artifact@v4 - with: - name: torch-xla-wheels - path: /tmp/wheels/ - # Generate a 16-character random ID for the docker tag - - name: Generate random docker tag - if: inputs.has_code_changes == 'true' - id: random_tag - shell: bash - run: | - echo "random_id=$(openssl rand -hex 8)" >> $GITHUB_OUTPUT - # Then run docker to install them and push a docker - - name: Build and push docker image - if: inputs.has_code_changes == 'true' - id: build_docker - shell: bash - working-directory: pytorch-xla - run: | - . ./infra/ansible/publish_torchprime_e2e_test_docker.sh - echo "docker_url=gcr.io/${DOCKER_PROJECT}/${DOCKER_IMAGE_NAME}:${DOCKER_IMAGE_TAG}" >> $GITHUB_OUTPUT - env: - DEFAULT_CONTEXT_PATH: /tmp/wheels - DOCKER_IMAGE_NAME: for-torchprime-ci - DOCKER_IMAGE_TAG: ${{ steps.random_tag.outputs.random_id }} - DOCKER_PROJECT: tpu-pytorch - # Trigger torchprime E2E test workflow. - # (Googlers only) in case of infra failure, refer to go/ptxla-torchprime-trigger - # Refer to the same doc on the retention policy of the docker images. - - uses: convictional/trigger-workflow-and-wait@v1.6.5 - if: inputs.has_code_changes == 'true' - with: - owner: AI-Hypercomputer - repo: torchprime - github_token: ${{ secrets.TORCH_XLA_BOT_TOKEN }} - workflow_file_name: e2e_test.yml - wait_interval: 60 - ref: main - client_payload: '{"docker_url": "${{ steps.build_docker.outputs.docker_url }}"}' + torchprime-e2e-test: + name: Run torchprime E2E tests + timeout-minutes: ${{ inputs.timeout-minutes }} + runs-on: ubuntu-22.04 + steps: + - name: Use Docker in rootless mode + if: inputs.has_code_changes == 'true' + uses: ScribeMD/rootless-docker@0.2.2 + - name: Add user to docker group + if: inputs.has_code_changes == 'true' + run: | + sudo usermod -aG docker $USER + newgrp docker + shell: bash + # Googlers: if this fails, follow go/ptxla-sa-key to debug. + - uses: google-github-actions/auth@v2 + if: inputs.has_code_changes == 'true' + with: + credentials_json: "${{ secrets.GCLOUD_SERVICE_KEY }}" + - uses: google-github-actions/setup-gcloud@v2 + if: inputs.has_code_changes == 'true' + with: + version: ">= 363.0.0" + install_components: "beta,gke-gcloud-auth-plugin" + - name: Verify GCP setup + if: inputs.has_code_changes == 'true' + run: gcloud info + shell: bash + - name: Authenticate Docker + if: inputs.has_code_changes == 'true' + run: gcloud auth configure-docker --quiet + shell: bash + - name: Activate SA credentials + if: inputs.has_code_changes == 'true' + run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS + shell: bash + - name: Checkout infra + if: inputs.has_code_changes == 'true' + uses: actions/checkout@v4 + with: + sparse-checkout: | + infra + fetch-depth: 1 + path: pytorch-xla + # Build a docker image for torchprime E2E test + # First download the torch-xla-wheels + - name: Fetch wheels + if: inputs.has_code_changes == 'true' + uses: actions/download-artifact@v5 + with: + name: torch-xla-wheels + path: /tmp/wheels/ + # Generate a 16-character random ID for the docker tag + - name: Generate random docker tag + if: inputs.has_code_changes == 'true' + id: random_tag + shell: bash + run: | + echo "random_id=$(openssl rand -hex 8)" >> $GITHUB_OUTPUT + # Then run docker to install them and push a docker + - name: Build and push docker image + if: inputs.has_code_changes == 'true' + id: build_docker + shell: bash + working-directory: pytorch-xla + run: | + . ./infra/ansible/publish_torchprime_e2e_test_docker.sh + echo "docker_url=gcr.io/${DOCKER_PROJECT}/${DOCKER_IMAGE_NAME}:${DOCKER_IMAGE_TAG}" >> $GITHUB_OUTPUT + env: + DEFAULT_CONTEXT_PATH: /tmp/wheels + DOCKER_IMAGE_NAME: for-torchprime-ci + DOCKER_IMAGE_TAG: ${{ steps.random_tag.outputs.random_id }} + DOCKER_PROJECT: tpu-pytorch + # Trigger torchprime E2E test workflow. + # (Googlers only) in case of infra failure, refer to go/ptxla-torchprime-trigger + # Refer to the same doc on the retention policy of the docker images. + - uses: convictional/trigger-workflow-and-wait@v1.6.5 + if: inputs.has_code_changes == 'true' + with: + owner: AI-Hypercomputer + repo: torchprime + github_token: ${{ secrets.TORCH_XLA_BOT_TOKEN }} + workflow_file_name: e2e_test.yml + wait_interval: 60 + ref: main + client_payload: '{"docker_url": "${{ steps.build_docker.outputs.docker_url }}"}' diff --git a/.github/workflows/setup/action.yml b/.github/workflows/setup/action.yml index e1d6fdb8599d..c5c9a9ad1e66 100644 --- a/.github/workflows/setup/action.yml +++ b/.github/workflows/setup/action.yml @@ -1,53 +1,53 @@ name: Set up PyTorch/XLA inputs: - torch-commit: - type: string - description: PyTorch commit to check out, if provided - wheels-artifact: - type: string - description: | - Artifact containing `torch` (cpu) and `torch-xla` wheels to install + torch-commit: + type: string + description: PyTorch commit to check out, if provided + wheels-artifact: + type: string + description: | + Artifact containing `torch` (cpu) and `torch-xla` wheels to install runs: - using: "composite" - steps: - # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802 - - name: Clean up workspace - shell: bash - run: | - ls -la - rm -rvf ${GITHUB_WORKSPACE}/* - - name: Setup gcloud - shell: bash - run: | - echo "${GCLOUD_SERVICE_KEY}" > /tmp/default_credentials.json - echo "GOOGLE_APPLICATION_CREDENTIALS=/tmp/default_credentials.json" >> $GITHUB_ENV - # GCLOUD_SERVICE_KEY needs to be set from the outside because for some - # reason composite actions don't support secrets. - # https://docs.github.com/en/actions/using-workflows/avoiding-duplication - if: ${{ env.GCLOUD_SERVICE_KEY }} - - name: Checkout PyTorch Repo - uses: actions/checkout@v4 - with: - repository: pytorch/pytorch - path: pytorch - ref: ${{ inputs.torch-commit }} - submodules: recursive - if: ${{ inputs.torch-commit }} - - name: Checkout PyTorch/XLA Repo - uses: actions/checkout@v4 - with: - path: pytorch/xla - - name: Fetch PyTorch/XLA packages - uses: actions/download-artifact@v4 - with: - name: ${{ inputs.wheels-artifact }} - path: /tmp/wheels/ - if: ${{ inputs.wheels-artifact }} - - name: Install wheels - shell: bash - run: | - pip install /tmp/wheels/*.whl + using: "composite" + steps: + # See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802 + - name: Clean up workspace + shell: bash + run: | + ls -la + rm -rvf ${GITHUB_WORKSPACE}/* + - name: Setup gcloud + shell: bash + run: | + echo "${GCLOUD_SERVICE_KEY}" > /tmp/default_credentials.json + echo "GOOGLE_APPLICATION_CREDENTIALS=/tmp/default_credentials.json" >> $GITHUB_ENV + # GCLOUD_SERVICE_KEY needs to be set from the outside because for some + # reason composite actions don't support secrets. + # https://docs.github.com/en/actions/using-workflows/avoiding-duplication + if: ${{ env.GCLOUD_SERVICE_KEY }} + - name: Checkout PyTorch Repo + uses: actions/checkout@v4 + with: + repository: pytorch/pytorch + path: pytorch + ref: ${{ inputs.torch-commit }} + submodules: recursive + if: ${{ inputs.torch-commit }} + - name: Checkout PyTorch/XLA Repo + uses: actions/checkout@v4 + with: + path: pytorch/xla + - name: Fetch PyTorch/XLA packages + uses: actions/download-artifact@v5 + with: + name: ${{ inputs.wheels-artifact }} + path: /tmp/wheels/ + if: ${{ inputs.wheels-artifact }} + - name: Install wheels + shell: bash + run: | + pip install /tmp/wheels/*.whl - echo "Import check..." - python -c "import torch_xla" - if: ${{ inputs.wheels-artifact }} + echo "Import check..." + python -c "import torch_xla" + if: ${{ inputs.wheels-artifact }} diff --git a/WORKSPACE b/WORKSPACE index 70b7d9cc098d..78e928d2a0f0 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -42,6 +42,12 @@ new_local_repository( path = PYTORCH_LOCAL_DIR, ) +new_local_repository( + name = "fmt", + build_file = "//bazel:fmt.BUILD", + path = PYTORCH_LOCAL_DIR + "/third_party/fmt", +) + ############################# OpenXLA Setup ############################### # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to @@ -82,7 +88,7 @@ http_archive( # Initialize OpenXLA's external dependencies. There is an specific order # which those dependencies are initialized, because for bazel it's the # first definition that takes precedence. -# We follow what openxla/xla does exactly: +# We follow what openxla/xla does exactly: # https://github.com/openxla/xla/blob/main/WORKSPACE#L37 load("@xla//:workspace4.bzl", "xla_workspace4") diff --git a/bazel/fmt.BUILD b/bazel/fmt.BUILD new file mode 100644 index 000000000000..ea8c566b98a5 --- /dev/null +++ b/bazel/fmt.BUILD @@ -0,0 +1,9 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "fmt", + hdrs = glob(["include/fmt/*.h",]), + defines = ["FMT_HEADER_ONLY=1"], + includes = ["include"], + visibility = ["//visibility:public"], +) diff --git a/bazel/torch.BUILD b/bazel/torch.BUILD index afc6bb57af9e..cfbd620e0777 100644 --- a/bazel/torch.BUILD +++ b/bazel/torch.BUILD @@ -10,6 +10,9 @@ cc_library( ["torch/include/**/*.h"], ["torch/include/google/protobuf/**/*.h"], ), + deps = [ + "@fmt", + ], strip_include_prefix = "torch/include", ) diff --git a/scripts/build_torch_wheels.sh b/scripts/build_torch_wheels.sh index 0f30a1e4e623..e773b5c04499 100755 --- a/scripts/build_torch_wheels.sh +++ b/scripts/build_torch_wheels.sh @@ -119,9 +119,9 @@ function install_llvm_clang() { function install_gcc() { sudo apt-get -y install gcc-11 g++-11 - export CC=/usr/bin/gcc-10 export CXX=/usr/bin/g++-11 - sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100 + export CC=/usr/bin/gcc-11 export CXX=/usr/bin/g++-11 + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 100 } function install_req_packages() { From 03d4dc06772a97c8092ecc4de65e904a898c8ff8 Mon Sep 17 00:00:00 2001 From: Hsiang-Chieh Tsou <65450151+hsjts0u@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:09:36 -0700 Subject: [PATCH 112/133] Add default args for _aten_conv2d (#9623) Add default args for _aten_conv2d, which would otherwise fail in the following code snippet ```python import torch from torch.export import export_for_training import torchax from torchax import interop from torch.utils import _pytree as pytree import jax from torchax.ops import mappings class Simple(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=4, bias=False) def forward(self, x): x = self.conv1(x) return x model = Simple() exported = export_for_training(model, (torch.randn(1, 3, 224, 224),)) def make_shape_struct(x): return jax.ShapeDtypeStruct(x.shape, mappings.t2j_dtype(x.dtype)) def map_nth(v, i): def f(t): if isinstance(t, torch.Tensor): return t[i : i + 1] return t return pytree.tree_map(f, v) env = torchax.default_env() with env: model = exported.module().to("jax") def func_to_export(x): # hard code weights in model return model(x) example_inputs_jax = pytree.tree_map_only( torch.Tensor, lambda x: x.to("jax"), map_nth(exported.example_inputs, 0) ) res = jax.jit(interop.jax_view(func_to_export)).lower(*example_inputs_jax[0]) # TypeError: _aten_conv2d() missing 5 required positional arguments: 'bias', 'stride', 'padding', 'dilation', and 'groups' ``` cc @qihqi --- torchax/torchax/ops/jaten.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 851a2d6103ef..700d581d7736 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -1017,11 +1017,11 @@ def _aten_bucketize(input, def _aten_conv2d( input, weight, - bias, - stride, - padding, - dilation, - groups, + bias=None, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1, ): return _aten_convolution( input, From 302c3f1f7b2af22732aee43b6441d9f3e5ff75e1 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 30 Sep 2025 20:17:49 -0300 Subject: [PATCH 113/133] Pin `flax` and skip C++ test `SiLUBackward`. (#9660) Since https://github.com/pytorch/pytorch/pull/162659 was merged again, we observed that `SiLUBackward` C++ test was crashing with a segmentation fault #9561. Not only that, but TPU tests started failing because `flax` 0.12.0 (old: 0.11.2) started pulling a newer `jax` 0.7.2 (old: 0.7.1). - Old CI build: [link](https://github.com/pytorch/xla/actions/runs/17931468317/job/51089906800) - Recent broken CI build: [link](https://github.com/pytorch/xla/actions/runs/18008717023/job/51550125217?pr=9655) Therefore, in this PR: - Pin `flax` to version 0.11.2 - Skip `SiLUBackward` C++ test Additionally, it also installs `jax` and `libtpu` using the CI PyTorch/XLA wheels metadata instead of using PyPI wheels metadata. This should avoid other version compatibilities. --- .github/workflows/_tpu_ci.yml | 40 +++++++++++++++-- test/cpp/test_aten_xla_tensor_1.cpp | 32 +++++++++++++ test/cpp/test_aten_xla_tensor_2.cpp | 4 ++ test/cpp/test_aten_xla_tensor_3.cpp | 30 +++++++++++++ test/cpp/test_aten_xla_tensor_4.cpp | 8 ++++ test/cpp/test_aten_xla_tensor_5.cpp | 2 + test/cpp/test_aten_xla_tensor_6.cpp | 70 +++++++++++++++++++++++++++++ 7 files changed, 182 insertions(+), 4 deletions(-) diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index b67f695f81ed..656372263d28 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -37,25 +37,56 @@ jobs: sparse-checkout: | .github/workflows/setup path: .actions + - name: Setup if: inputs.has_code_changes == 'true' uses: ./.actions/.github/workflows/setup with: torch-commit: ${{ inputs.torch-commit }} wheels-artifact: torch-xla-wheels + - name: Install test dependencies if: inputs.has_code_changes == 'true' shell: bash run: | + set -x + # TODO: Add these in setup.py pip install --upgrade pip pip install fsspec pip install rich - # jax and libtpu is needed for pallas tests. - pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html' - pip install --pre 'torch_xla[tpu]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html' + + # PyTorch/XLA Optional Dependencies + # ================================= + # + # Install `JAX` and `libtpu` dependencies for pallas and TPU tests. + # + # Note that we might need to install pre-release versions of both, in + # external artifact repositories. + + # Retrieve the PyTorch/XLA ".whl" file. + # This assumes PyTorch/XLA wheels are downloaded in "/tmp/wheels". + WHL=$(ls /tmp/wheels/torch_xla*) + + # Links for finding `jax` and `libtpu` versions. + INDEX="https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ " + LINKS="https://storage.googleapis.com/jax-releases/libtpu_releases.html" + + pip install "$WHL[pallas]" --pre --index-url $INDEX --find-links $LINKS + pip install "$WHL[tpu]" --pre --index-url $INDEX --find-links $LINKS + pip install --upgrade protobuf - pip install flax + + # Flax Pin + # ======== + # + # Be careful when bumping the `flax` version, since it can cause tests that + # depend on `jax` to start breaking. + # + # Newer `flax` versions might pull newer `jax` versions, which might be incompatible + # with the current version of PyTorch/XLA. + pip install flax==0.11.2 + - name: Run Tests (${{ matrix.test_script }}) if: inputs.has_code_changes == 'true' env: @@ -64,6 +95,7 @@ jobs: run: | cd pytorch/xla ${{ matrix.test_script }} + - name: Report no code changes # Only report the first instance if: inputs.has_code_changes == 'false' && strategy.job-index == 0 diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index 2c79925bc161..bac303be96b4 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -356,6 +356,8 @@ TEST_F(AtenXlaTensorTest, TestSiLU) { } TEST_F(AtenXlaTensorTest, TestSiLUBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::silu(inputs[0]); }; @@ -681,6 +683,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuter) { } TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor a = torch::rand({5}, torch::TensorOptions(torch::kFloat).requires_grad(true)); torch::Tensor b = @@ -719,6 +723,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) { } TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMulBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; if (UsingTpu()) { GTEST_SKIP(); } @@ -759,6 +765,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinear) { } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBilinearBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor a = torch::rand( {3, 5, 4}, torch::TensorOptions(torch::kFloat).requires_grad(true)); torch::Tensor l = torch::rand( @@ -795,6 +803,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonal) { } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerDiagonalBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor input = torch::rand( {3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); std::string equation = "ii->i"; @@ -827,6 +837,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonal) { } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchDiagonalBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor input = torch::rand( {4, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); std::string equation = "...ii->...i"; @@ -859,6 +871,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermute) { } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerBatchPermuteBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor input = torch::rand( {2, 3, 4, 5}, torch::TensorOptions(torch::kFloat).requires_grad(true)); std::string equation = "...ij->...ji"; @@ -892,6 +906,8 @@ TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxis) { } TEST_F(AtenXlaTensorTest, TestEinsumPyTorchLowerRepeatedAxisBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor x = torch::rand( {2, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); torch::Tensor y = @@ -1036,6 +1052,8 @@ TEST_F(AtenXlaTensorTest, TestUpsampleNearest2D) { } TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int batch_size = 2; int h = 5; int w = 5; @@ -1094,6 +1112,8 @@ TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DWithScale) { } TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DBackwardWithScale) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; struct ImageInfo { int batch_size; int h; @@ -1223,6 +1243,8 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DWithScale) { } TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int batch_size = 2; int h = 5; int w = 5; @@ -1245,6 +1267,8 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackward) { } TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackwardWithScale) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; struct ImageInfo { int batch_size; int h; @@ -1610,6 +1634,8 @@ TEST_F(AtenXlaTensorTest, TestTake) { } TEST_F(AtenXlaTensorTest, TestTakeBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::take(inputs[0], inputs[1]); }; @@ -3499,6 +3525,8 @@ TEST_F(AtenXlaTensorTest, TestPrelu) { } TEST_F(AtenXlaTensorTest, TestPreluBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::prelu(inputs[0], inputs[1]); }; @@ -3583,6 +3611,8 @@ TEST_F(AtenXlaTensorTest, TestHardSigmoidInPlace) { } TEST_F(AtenXlaTensorTest, TestHardSigmoidBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::hardsigmoid(inputs[0]); }; @@ -3625,6 +3655,8 @@ TEST_F(AtenXlaTensorTest, TestHardSwishInPlace) { } TEST_F(AtenXlaTensorTest, TestHardSwishBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::hardswish(inputs[0]); }; diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp index dc3d605da34a..013abee5563d 100644 --- a/test/cpp/test_aten_xla_tensor_2.cpp +++ b/test/cpp/test_aten_xla_tensor_2.cpp @@ -1536,6 +1536,8 @@ TEST_F(AtenXlaTensorTest, TestGroupNorm) { } TEST_F(AtenXlaTensorTest, TestGroupNormBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int num_channels = 6; torch::Tensor input = torch::rand({20, num_channels, 10, 10}, @@ -1642,6 +1644,8 @@ TEST_F(AtenXlaTensorTest, TestLayerNorm) { } TEST_F(AtenXlaTensorTest, TestLayerNormBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor input = torch::rand( {2, 3, 3, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); double eps = 1e-05; diff --git a/test/cpp/test_aten_xla_tensor_3.cpp b/test/cpp/test_aten_xla_tensor_3.cpp index 7ea9ebb959bf..1bdb16c818b6 100644 --- a/test/cpp/test_aten_xla_tensor_3.cpp +++ b/test/cpp/test_aten_xla_tensor_3.cpp @@ -664,6 +664,8 @@ TEST_F(AtenXlaTensorTest, TestReflectionPad1dRank3) { } TEST_F(AtenXlaTensorTest, TestReflectionPad1dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; std::vector pad{2, 2}; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::reflection_pad1d(inputs[0], pad); @@ -709,6 +711,8 @@ TEST_F(AtenXlaTensorTest, TestReflectionPad2dRank4) { } TEST_F(AtenXlaTensorTest, TestReflectionPad2dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; std::vector pad{2, 3, 1, 2}; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::reflection_pad2d(inputs[0], pad); @@ -754,6 +758,8 @@ TEST_F(AtenXlaTensorTest, TestReflectionPad3dRank4) { } TEST_F(AtenXlaTensorTest, TestReflectionPad3dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; std::vector pad{1, 1, 1, 1, 1, 1}; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::reflection_pad3d(inputs[0], pad); @@ -801,6 +807,8 @@ TEST_F(AtenXlaTensorTest, TestReplicationPad1dZeroPad) { } TEST_F(AtenXlaTensorTest, TestReplicationPad1dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; std::vector pad{2, 3}; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::replication_pad1d(inputs[0], pad); @@ -848,6 +856,8 @@ TEST_F(AtenXlaTensorTest, TestReplicationPad2dZeroPad) { } TEST_F(AtenXlaTensorTest, TestReplicationPad2dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; std::vector pad{2, 3, 1, 1}; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::replication_pad2d(inputs[0], pad); @@ -895,6 +905,8 @@ TEST_F(AtenXlaTensorTest, TestReplicationPad3dZeroPad) { } TEST_F(AtenXlaTensorTest, TestReplicationPad3dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; std::vector pad{2, 3, 1, 1, 1, 1}; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::replication_pad3d(inputs[0], pad); @@ -1131,6 +1143,8 @@ TEST_F(AtenXlaTensorTest, TestAsStridedMultipleDimMismatch) { } TEST_F(AtenXlaTensorTest, TestAvgPool2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -1161,6 +1175,8 @@ TEST_F(AtenXlaTensorTest, TestAvgPool2DBackward) { } TEST_F(AtenXlaTensorTest, TestAvgPool3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -1192,6 +1208,8 @@ TEST_F(AtenXlaTensorTest, TestAvgPool3DBackward) { } TEST_F(AtenXlaTensorTest, TestAvgPool2DNoBatchBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -1222,6 +1240,8 @@ TEST_F(AtenXlaTensorTest, TestAvgPool2DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestAvgPool3DNoBatchBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 2; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -1253,6 +1273,8 @@ TEST_F(AtenXlaTensorTest, TestAvgPool3DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool3DNoBatchBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (int64_t output_size : {7, 4}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -1273,6 +1295,8 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool3DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (int64_t output_size : {7, 4}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -1293,6 +1317,8 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool3DBackward) { } TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (int64_t output_size : {7, 8}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -1312,6 +1338,8 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DBackward) { } TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatchBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (int64_t output_size : {7, 8}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -1329,6 +1357,8 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestConv3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int in_channels = 4; int out_channels = 8; int kernel_size = 5; diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index 5b1d99524b8c..1283cec89967 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -569,6 +569,8 @@ TEST_F(AtenXlaTensorTest, TestRsubScalar) { } TEST_F(AtenXlaTensorTest, TestConv2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int in_channels = 4; int out_channels = 8; int kernel_size = 5; @@ -609,6 +611,8 @@ TEST_F(AtenXlaTensorTest, TestConv2DBackward) { } TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int in_channels = 4; int out_channels = 8; int kernel_size = 5; @@ -746,6 +750,8 @@ TEST_F(AtenXlaTensorTest, TestL1Loss) { } TEST_F(AtenXlaTensorTest, TestL1LossBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (torch::Reduction::Reduction reduction : {torch::Reduction::None, torch::Reduction::Mean, torch::Reduction::Sum}) { @@ -784,6 +790,8 @@ TEST_F(AtenXlaTensorTest, TestMseLoss) { } TEST_F(AtenXlaTensorTest, TestMseLossBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (torch::Reduction::Reduction reduction : {torch::Reduction::None, torch::Reduction::Mean, torch::Reduction::Sum}) { diff --git a/test/cpp/test_aten_xla_tensor_5.cpp b/test/cpp/test_aten_xla_tensor_5.cpp index 07e4c2dae86b..19beae5789b2 100644 --- a/test/cpp/test_aten_xla_tensor_5.cpp +++ b/test/cpp/test_aten_xla_tensor_5.cpp @@ -1451,6 +1451,8 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2D) { } TEST_F(AtenXlaTensorTest, TestAdaptiveMaxPool2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; XlaDeviceType hw_type = static_cast(bridge::GetDefaultDevice()->type()); // skip this test until the tile mismatch bug is fixed. diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp index b9a669760b1b..ca2ad6498ca3 100644 --- a/test/cpp/test_aten_xla_tensor_6.cpp +++ b/test/cpp/test_aten_xla_tensor_6.cpp @@ -24,6 +24,8 @@ class AtenXlaTensorTest : public AtenXlaTensorTestBase {}; } // namespace TEST_F(AtenXlaTensorTest, TestTransposedConv3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int in_channels = 4; int out_channels = 8; int kernel_size = 5; @@ -69,6 +71,8 @@ TEST_F(AtenXlaTensorTest, TestTransposedConv3DBackward) { } TEST_F(AtenXlaTensorTest, TestMaxPool2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -99,6 +103,8 @@ TEST_F(AtenXlaTensorTest, TestMaxPool2DBackward) { } TEST_F(AtenXlaTensorTest, TestMaxPool3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -130,6 +136,8 @@ TEST_F(AtenXlaTensorTest, TestMaxPool3DBackward) { } TEST_F(AtenXlaTensorTest, TestMaxPool2DNoBatchBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -157,6 +165,8 @@ TEST_F(AtenXlaTensorTest, TestMaxPool2DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestMaxPool3DNoBatchBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { for (int padding = 0; padding <= 1; ++padding) { @@ -188,6 +198,8 @@ TEST_F(AtenXlaTensorTest, TestMaxPool3DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestMaxUnpool2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 2; torch::Tensor input = torch::rand({2, 2, 8, 8}, torch::TensorOptions(torch::kFloat)); @@ -223,6 +235,8 @@ TEST_F(AtenXlaTensorTest, TestMaxUnpool2DBackward) { } TEST_F(AtenXlaTensorTest, TestMaxUnpool3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int kernel_size = 2; torch::Tensor input = torch::rand({2, 2, 8, 8, 8}, torch::TensorOptions(torch::kFloat)); @@ -262,6 +276,8 @@ TEST_F(AtenXlaTensorTest, TestMaxUnpool3DBackward) { } TEST_F(AtenXlaTensorTest, TestTanhBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::tanh(inputs[0]); }; @@ -274,6 +290,8 @@ TEST_F(AtenXlaTensorTest, TestTanhBackward) { } TEST_F(AtenXlaTensorTest, TestSigmoidBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::sigmoid(inputs[0]); }; @@ -286,6 +304,8 @@ TEST_F(AtenXlaTensorTest, TestSigmoidBackward) { } TEST_F(AtenXlaTensorTest, TestLogSigmoidBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::log_sigmoid(inputs[0]); }; @@ -302,6 +322,8 @@ TEST_F(AtenXlaTensorTest, TestLogSigmoidBackward) { } TEST_F(AtenXlaTensorTest, TestLogSoftmaxBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (int dim = -4; dim < 4; ++dim) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -322,6 +344,8 @@ TEST_F(AtenXlaTensorTest, TestLogSoftmaxBackward) { } TEST_F(AtenXlaTensorTest, TestSoftmaxBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (int dim = -4; dim < 4; ++dim) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -339,6 +363,8 @@ TEST_F(AtenXlaTensorTest, TestSoftmaxBackward) { } TEST_F(AtenXlaTensorTest, TestSoftplusBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::softplus(inputs[0]); }; @@ -351,6 +377,8 @@ TEST_F(AtenXlaTensorTest, TestSoftplusBackward) { } TEST_F(AtenXlaTensorTest, TestReluBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::relu(inputs[0]); }; @@ -363,6 +391,8 @@ TEST_F(AtenXlaTensorTest, TestReluBackward) { } TEST_F(AtenXlaTensorTest, TestRreluBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::rrelu(inputs[0]); }; @@ -375,6 +405,8 @@ TEST_F(AtenXlaTensorTest, TestRreluBackward) { } TEST_F(AtenXlaTensorTest, TestHardshrinkBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::hardshrink(inputs[0]); }; @@ -387,6 +419,8 @@ TEST_F(AtenXlaTensorTest, TestHardshrinkBackward) { } TEST_F(AtenXlaTensorTest, TestHardshrinkBackwardWithMixedDataType) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; if (UsingTpu()) { GTEST_SKIP(); } @@ -406,6 +440,8 @@ TEST_F(AtenXlaTensorTest, TestHardshrinkBackwardWithMixedDataType) { } TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::softshrink(inputs[0]); }; @@ -418,6 +454,8 @@ TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) { } TEST_F(AtenXlaTensorTest, TestSoftshrinkBackwardWithMixedDataType) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; if (UsingTpu()) { GTEST_SKIP(); } @@ -437,6 +475,8 @@ TEST_F(AtenXlaTensorTest, TestSoftshrinkBackwardWithMixedDataType) { } TEST_F(AtenXlaTensorTest, TestHardtanhBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::hardtanh(inputs[0]); }; @@ -449,6 +489,8 @@ TEST_F(AtenXlaTensorTest, TestHardtanhBackward) { } TEST_F(AtenXlaTensorTest, TestEluBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Scalar alpha = 0.5; torch::Scalar scale = 2.5; torch::Scalar input_scale = 1.5; @@ -464,6 +506,8 @@ TEST_F(AtenXlaTensorTest, TestEluBackward) { } TEST_F(AtenXlaTensorTest, TestGeluBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; for (const auto& approximate : {"none", "tanh"}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -480,6 +524,8 @@ TEST_F(AtenXlaTensorTest, TestGeluBackward) { } TEST_F(AtenXlaTensorTest, TestLeakyReluBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; double negative_slope = 0.01; auto testfn = [=](const std::vector& inputs) -> torch::Tensor { return torch::leaky_relu(inputs[0], negative_slope); @@ -493,6 +539,8 @@ TEST_F(AtenXlaTensorTest, TestLeakyReluBackward) { } TEST_F(AtenXlaTensorTest, TestTransposeBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return torch::t(inputs[0]); }; @@ -505,6 +553,8 @@ TEST_F(AtenXlaTensorTest, TestTransposeBackward) { } TEST_F(AtenXlaTensorTest, TestAddMatMulBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int in_channels = 32; int out_channels = 320; int labels = 50; @@ -529,6 +579,8 @@ TEST_F(AtenXlaTensorTest, TestAddMatMulBackward) { } TEST_F(AtenXlaTensorTest, TestBinaryCrossEntropyBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; if (UsingTpu()) { GTEST_SKIP(); } @@ -570,6 +622,8 @@ TEST_F(AtenXlaTensorTest, TestBinaryCrossEntropyBackward) { } TEST_F(AtenXlaTensorTest, TestNllLossBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int batch = 6; int classes = 2; for (auto dtype : {torch::kFloat, torch::kDouble}) { @@ -611,6 +665,8 @@ TEST_F(AtenXlaTensorTest, TestNllLossBackward) { } TEST_F(AtenXlaTensorTest, TestNllLoss2dBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int batch = 6; int classes = 2; int height = 3; @@ -656,6 +712,8 @@ TEST_F(AtenXlaTensorTest, TestNllLoss2dBackward) { } TEST_F(AtenXlaTensorTest, TestSmoothL1LossBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor input = torch::randn( {2, 4}, torch::TensorOptions(torch::kFloat).requires_grad(true)); torch::Tensor target = @@ -681,6 +739,8 @@ TEST_F(AtenXlaTensorTest, TestSmoothL1LossBackward) { } TEST_F(AtenXlaTensorTest, TestViewBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { return inputs[0].view({-1, 320}); }; @@ -693,6 +753,8 @@ TEST_F(AtenXlaTensorTest, TestViewBackward) { } TEST_F(AtenXlaTensorTest, TestBatchNorm2DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; double momentum = 0.1; double eps = 0.5; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -739,6 +801,8 @@ TEST_F(AtenXlaTensorTest, TestBatchNorm2DBackward) { } TEST_F(AtenXlaTensorTest, TestBatchNorm3DBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; double momentum = 0.1; double eps = 0.5; auto testfn = [&](const std::vector& inputs) -> torch::Tensor { @@ -785,6 +849,8 @@ TEST_F(AtenXlaTensorTest, TestBatchNorm3DBackward) { } TEST_F(AtenXlaTensorTest, TestBCEWithLogitsBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int batch = 10; int classes = 5; torch::Tensor undef; @@ -828,6 +894,8 @@ TEST_F(AtenXlaTensorTest, TestBCEWithLogitsBackward) { } TEST_F(AtenXlaTensorTest, TestKlDivBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; torch::Tensor input = torch::rand( {4, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true)); torch::Tensor target = torch::rand( @@ -847,6 +915,8 @@ TEST_F(AtenXlaTensorTest, TestKlDivBackward) { } TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) { + GTEST_SKIP() << "failing due to PyTorch upstream changes. " + << "See: https://github.com/pytorch/xla/issues/9651."; int num_weights = 32; for (int padding_idx = -1; padding_idx < num_weights; ++padding_idx) { for (bool scale_grad_by_freq : {false, true}) { From a5116917da7fcb51942120edf8fcbaeea263c67f Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 08:59:14 -0300 Subject: [PATCH 114/133] `trace`: improve error handling and error messages. (#9630) This PR refactors the `trace` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::trace` return `StatusOr` - Improve error messages and error handling - Renamed `CheckMMInputIsMatrix` to `CheckInputIsMatrix` - Added a new parameter for specifying the operation name, so as to build a better error message --- test/test_ops_error_message.py | 13 +++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 7 +++++-- torch_xla/csrc/tensor_methods.cpp | 31 ++++++++++++++++--------------- torch_xla/csrc/tensor_methods.h | 2 +- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index bb23810f2349..42858fc84f67 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -222,6 +222,19 @@ def test(): expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2).""" ) + def test_trace_raises_error_on_non_matrix_input(self): + device = torch_xla.device() + a = torch.rand(2, 2, 2, device=device) + + def test(): + torch.trace(a) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor).""" + ) + if __name__ == "__main__": unittest.main() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 91101778d7aa..b6a8484a2505 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3875,8 +3875,11 @@ std::tuple XLANativeFunctions::topk( at::Tensor XLANativeFunctions::trace(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::trace(xla_self)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 0ec204b0fff1..a52c955ae558 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -511,13 +511,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input, return absl::OkStatus(); } -absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat, - const std::string_view arg) { - xla::Shape shape = mat->shape(); +absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor, + const std::string_view op, + const std::string_view arg = "") { + xla::Shape shape = tensor->shape(); if (shape.dimensions().size() != 2) { - return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( - absl::StrCat("mm(): expected the ", arg, " input tensor ", - shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); + const std::string arg_with_trailing_space = + arg.empty() ? std::string("") : absl::StrCat(arg, " "); + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): expected the ", arg_with_trailing_space, "input tensor ", + shape.ToString(), " to be a matrix (i.e. a 2D tensor)."))); } return absl::OkStatus(); } @@ -2452,8 +2455,8 @@ XLATensorPtr mish(const XLATensorPtr& input) { absl::StatusOr mm(const XLATensorPtr& input, const XLATensorPtr& weight) { - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first")); - XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second")); + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first")); + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second")); XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight)); return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue())); } @@ -3648,13 +3651,11 @@ std::tuple topk(const XLATensorPtr& input, return std::make_tuple(t1, t2); } -XLATensorPtr trace(const XLATensorPtr& input) { - auto input_shape_ref = input->shape(); - XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2) - << "invalid argument for trace: expected a matrix"; - torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0), - (*input_shape_ref).dimensions(1), - (*input_shape_ref).element_type()); +absl::StatusOr trace(const XLATensorPtr& input) { + XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace")); + xla::Shape shape = input->shape(); + torch::lazy::NodePtr eye = + Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type()); return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false, input->dtype()); } diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index e91e92ad96e5..fb0e39cc8617 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -973,7 +973,7 @@ std::tuple topk(const XLATensorPtr& input, bool stable); // Returns the sum of the elements of the diagonal of the input 2-D matrix. -XLATensorPtr trace(const XLATensorPtr& input); +absl::StatusOr trace(const XLATensorPtr& input); // Swap given dimensions of the input. XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1); From 3240166361abf32612c132811b9326210f23d073 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 10:32:27 -0300 Subject: [PATCH 115/133] Fix Terraform usage of `cuda_version`. (#9655) This PR removes `cuda_version` usage inside `artifacts_build.tf` remaining from #9618. This might be the cause of errors in nightly (see [comment](https://github.com/pytorch/xla/issues/9589#issuecomment-3299701521)). --- .../tpu-pytorch-releases/artifacts_builds.tf | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index b4e469b617be..3546b23692b5 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -94,7 +94,7 @@ locals { for b in local.nightly_builds : format("%s_%s%s%s", b.python_version, - b.accelerator == "tpu" ? "tpuvm" : format("cuda_%s", b.cuda_version), + "tpuvm", try(b.cxx11_abi == "0", false) ? "_precxx11" : "", try(b.bundle_libtpu == "1", false) ? "_libtpu" : "" ) => b @@ -106,7 +106,7 @@ locals { format("r%s_%s_%s%s%s", replace(b.package_version, "+", "_"), b.python_version, - b.accelerator == "tpu" ? "tpuvm" : format("cuda_%s", b.cuda_version), + "tpuvm", try(b.cxx11_abi == "0", false) ? "_precxx11" : "", try(b.bundle_libtpu == "1", false) ? "_libtpu" : "" ) => b @@ -136,20 +136,13 @@ module "nightly_builds" { ] description = join(" ", [ - "Builds nightly xla:nightly_${each.key}' ${ - each.value.accelerator == "tpu" - ? "TPU" - : format("CUDA %s", each.value.cuda_version) - } docker image and corresponding wheels for PyTorch/XLA.", + "Builds nightly xla:nightly_${each.key}' TPU docker image and ", + "corresponding wheels for PyTorch/XLA.", "Trigger managed by Terraform setup in", "infra/tpu-pytorch-releases/artifacts_builds.tf." ]) - wheels_dest = "${module.releases_storage_bucket.url}/wheels/${ - each.value.accelerator == "tpu" - ? "tpuvm" - : "cuda/${each.value.cuda_version}" - }" + wheels_dest = "${module.releases_storage_bucket.url}/wheels/tpuvm" wheels_srcs = ["/dist/*.whl"] build_args = { python_version = each.value.python_version @@ -183,20 +176,13 @@ module "versioned_builds" { image_tags = [each.key] description = join(" ", [ - "Builds official xla:${each.key}' ${ - each.value.accelerator == "tpu" - ? "TPU" - : format("CUDA %s", each.value.cuda_version) - } docker image and corresponding wheels for PyTorch/XLA.", + "Builds official xla:${each.key}' TPU docker image and ", + "corresponding wheels for PyTorch/XLA.", "Trigger managed by Terraform setup in", "infra/tpu-pytorch-releases/artifacts_builds.tf." ]) - wheels_dest = "${module.releases_storage_bucket.url}/wheels/${ - each.value.accelerator == "tpu" - ? "tpuvm" - : "cuda/${each.value.cuda_version}" - }" + wheels_dest = "${module.releases_storage_bucket.url}/wheels/tpuvm" wheels_srcs = ["/dist/*.whl"] # Pass docker build args to infra/ansible/Dockerfile, other than `ansible_vars`. build_args = { From 3862b8784f8b254d0230ce79565add32de5a6cb0 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 15:29:13 -0300 Subject: [PATCH 116/133] Create PyTorch commit pin. (#9654) This PR creates a PyTorch pin. **Key Changes:** - Removed `get-torch-commit` (and, consequently, `torch-commit` parameters) from GitHub actions files - Modified the `setup.yml` action by: 1. Checking out PyTorch/XLA 2. Retrieving the contents of `.torch_commit` 3. Checking out PyTorch using the contents of the retrieved PyTorch commit 4. Moving PyTorch/XLA inside the PyTorch directory - Add `.torch_commit` file pointing to the commit just before #9651 started happening This should prevent our CI to break due to some PyTorch breaking changes as we have experienced recently (e.g. #9653, and #9651). From now on, in theory, we should only see our CI breaking because of PyTorch changes whenever we update this pin. --- .github/ci.md | 50 +++++++++++-------- .github/workflows/_build_torch_xla.yml | 6 --- .github/workflows/_test.yml | 17 ------- .github/workflows/_tpu_ci.yml | 5 -- .github/workflows/build_and_test.yml | 17 ++----- .github/workflows/setup/action.yml | 23 ++++++--- .torch_commit | 2 + CONTRIBUTING.md | 24 ++++++++- .../ansible/roles/build_srcs/tasks/main.yaml | 7 ++- scripts/build_developer.sh | 5 ++ scripts/update_deps.py | 45 ++++++++++++++++- 11 files changed, 125 insertions(+), 76 deletions(-) create mode 100644 .torch_commit diff --git a/.github/ci.md b/.github/ci.md index cc3994c884e7..5c3671339c7d 100644 --- a/.github/ci.md +++ b/.github/ci.md @@ -3,22 +3,22 @@ PyTorch and PyTorch/XLA use CI to lint, build, and test each PR that is submitted. All CI tests should succeed before the PR is merged into master. PyTorch CI pins PyTorch/XLA to a specific commit. On the other hand, PyTorch/XLA -CI pulls PyTorch from master unless a pin is manually provided. This README will -go through the reasons of these pins, how to pin a PyTorch/XLA PR to an upstream -PyTorch PR, and how to coordinate a merge for breaking PyTorch changes. +CI pulls PyTorch from `.torch_commit` unless a pin is manually provided. This +README will go through the reasons of these pins, how to pin a PyTorch/XLA PR +to an upstream PyTorch PR, and how to coordinate a merge for breaking PyTorch +changes. ## Usage -### Pinning PyTorch PR in PyTorch/XLA PR +### Temporarily Pinning PyTorch PR in PyTorch/XLA PR Sometimes a PyTorch/XLA PR needs to be pinned to a specific PyTorch PR to test -new features, fix breaking changes, etc. Since PyTorch/XLA CI pulls from PyTorch -master by default, we need to manually provide a PyTorch pin. In a PyTorch/XLA -PR, PyTorch can be manually pinned by creating a `.torch_pin` file at the root -of the repository. The `.torch_pin` should have the corresponding PyTorch PR -number prefixed by "#". Take a look at [example -here](https://github.com/pytorch/xla/pull/7313). Before the PyTorch/XLA PR gets -merged, the `.torch_pin` must be deleted. +new features, fix breaking changes, etc. In a PyTorch/XLA PR, PyTorch can be +manually pinned by creating a `.torch_pin` file at the root of the repository. +The `.torch_pin` should have the corresponding PyTorch PR number prefixed by +"#". Take a look at [example here](https://github.com/pytorch/xla/pull/7313). +Before the PyTorch/XLA PR gets merged, the `.torch_pin` must be deleted and +`.torch_commit` updated. ### Coordinating merges for breaking PyTorch PRs @@ -35,10 +35,11 @@ fail. Steps for fixing and merging such breaking PyTorch change is as following: PyTorch PR to pin the PyTorch/XLA to the commit hash created in step 1 by updating `pytorch/.github/ci_commit_pins/xla.txt`. 1. Once CI tests are green on both ends, merge PyTorch PR. -1. Remove the `.torch_pin` in PyTorch/XLA PR and merge. To be noted, `git commit - --amend` should be avoided in this step as PyTorch CI will keep using the - commit hash created in step 1 until other PRs update that manually or the - nightly buildbot updates that automatically. +1. Remove the `.torch_pin` in PyTorch/XLA PR and update the `.torch_commit` to + the hash of the merged PyTorch PR. To be noted, `git commit --amend` should + be avoided in this step as PyTorch CI will keep using the commit hash + created in step 1 until other PRs update that manually or the nightly + buildbot updates that automatically. 1. Finally, don't delete your branch until 2 days later. See step 4 for explanations. @@ -47,6 +48,18 @@ fail. Steps for fixing and merging such breaking PyTorch change is as following: The `build_and_test.yml` workflow runs tests on the TPU in addition to CPU. The set of tests run on the TPU is defined in `test/tpu/run_tests.sh`. +## Update the PyTorch Commit Pin + +In order to reduce development burden of PyTorch/XLA, starting from #9654, we +started pinning PyTorch using the `.torch_commit` file. This should reduce the +number of times a PyTorch PR breaks our most recent commits. However, this also +requires maintenance, i.e. someone has to keep updating the PyTorch commit so +as to make sure it's always supporting (almost) the latest PyTorch versions. + +Updating the PyTorch commit pin is, theoretically, simple. You just have to run +`scripts/update_deps.py --pytorch` file, and open a PR. In practice, you may +encounter a few compilation errors, or even segmentation faults. + ## CI Environment Before the CI in this repository runs, we build a base dev image. These are the @@ -152,13 +165,6 @@ good" commit to prevent accidental changes from PyTorch/XLA to break PyTorch CI without warning. PyTorch has hundreds of commits each week, and this pin ensures that PyTorch/XLA as a downstream package does not cause failures in PyTorch CI. -#### Why does PyTorch/XLA CI pull from PyTorch master? - -[PyTorch/XLA CI pulls PyTorch from master][pull-pytorch-master] unless a PyTorch -pin is manually provided. PyTorch/XLA is a downstream package to PyTorch, and -pulling from master ensures that PyTorch/XLA will stay up-to-date and works with -the latest PyTorch changes. - #### TPU CI is broken If the TPU CI won't run, try to debug using the following steps: diff --git a/.github/workflows/_build_torch_xla.yml b/.github/workflows/_build_torch_xla.yml index aaff9c2bf9dc..49ae9227372f 100644 --- a/.github/workflows/_build_torch_xla.yml +++ b/.github/workflows/_build_torch_xla.yml @@ -6,10 +6,6 @@ on: required: true type: string description: Base image for builds - torch-commit: - required: true - type: string - description: torch-commit runner: required: false type: string @@ -53,8 +49,6 @@ jobs: - name: Setup if: inputs.has_code_changes == 'true' uses: ./.actions/.github/workflows/setup - with: - torch-commit: ${{ inputs.torch-commit }} - name: Build if: inputs.has_code_changes == 'true' shell: bash diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 10bc92327dee..b0d6c988fab8 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -23,10 +23,6 @@ on: description: | Set the maximum (in minutes) how long the workflow should take to finish timeout-minutes: - torch-commit: - required: true - type: string - description: torch-commit has_code_changes: required: false type: string @@ -89,7 +85,6 @@ jobs: if: inputs.has_code_changes == 'true' uses: ./.actions/.github/workflows/setup with: - torch-commit: ${{ inputs.torch-commit }} wheels-artifact: torch-xla-wheels - name: Fetch CPP test binaries if: inputs.has_code_changes == 'true' && matrix.run_cpp_tests @@ -112,18 +107,6 @@ jobs: pip install fsspec pip install rich pip install flax - - name: Checkout PyTorch Repo - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - repository: pytorch/pytorch - path: pytorch - ref: ${{ inputs.torch-commit }} - - name: Checkout PyTorch/XLA Repo - if: inputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - path: pytorch/xla - name: Extra CI deps if: inputs.has_code_changes == 'true' shell: bash diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index 656372263d28..2f48391c96ea 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -2,10 +2,6 @@ name: TPU Integration Test on: workflow_call: inputs: - torch-commit: - required: false - type: string - description: torch-commit timeout-minutes: required: false type: number @@ -42,7 +38,6 @@ jobs: if: inputs.has_code_changes == 'true' uses: ./.actions/.github/workflows/setup with: - torch-commit: ${{ inputs.torch-commit }} wheels-artifact: torch-xla-wheels - name: Install test dependencies diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 27fca6a00446..8381ed2823ad 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -26,29 +26,21 @@ jobs: base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }} head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - get-torch-commit: + report-no-code-changes: needs: [check_code_changes] runs-on: ubuntu-24.04 - outputs: - torch_commit: ${{ steps.commit.outputs.torch_commit }} steps: - - name: Get latest torch commit - id: commit - if: needs.check_code_changes.outputs.has_code_changes == 'true' - run: | - echo "torch_commit=$(git ls-remote https://github.com/pytorch/pytorch.git HEAD | awk '{print $1}')" >> "$GITHUB_OUTPUT" - name: Report no code changes - if: needs.check_code_changes.outputs.has_code_changes == 'false' run: | echo "No code changes were detected that require running the full test suite." + if: needs.check_code_changes.outputs.has_code_changes == 'false' build-torch-xla: name: "Build PyTorch/XLA" uses: ./.github/workflows/_build_torch_xla.yml - needs: [check_code_changes, get-torch-commit] + needs: [check_code_changes] with: dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm - torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}} timeout-minutes: 45 # Takes ~20m as of 2025/5/30. has_code_changes: ${{ needs.check_code_changes.outputs.has_code_changes }} runner: linux.24xlarge @@ -58,13 +50,12 @@ jobs: test-python-cpu: name: "CPU tests" uses: ./.github/workflows/_test.yml - needs: [build-torch-xla, check_code_changes, get-torch-commit] + needs: [build-torch-xla, check_code_changes] with: dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.12_tpuvm timeout-minutes: 45 # Takes ~26m as of 2025/5/30. collect-coverage: false runner: linux.24xlarge - torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}} has_code_changes: ${{ needs.check_code_changes.outputs.has_code_changes }} secrets: gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }} diff --git a/.github/workflows/setup/action.yml b/.github/workflows/setup/action.yml index c5c9a9ad1e66..fe8c953281f6 100644 --- a/.github/workflows/setup/action.yml +++ b/.github/workflows/setup/action.yml @@ -1,8 +1,5 @@ name: Set up PyTorch/XLA inputs: - torch-commit: - type: string - description: PyTorch commit to check out, if provided wheels-artifact: type: string description: | @@ -16,6 +13,7 @@ runs: run: | ls -la rm -rvf ${GITHUB_WORKSPACE}/* + - name: Setup gcloud shell: bash run: | @@ -25,24 +23,37 @@ runs: # reason composite actions don't support secrets. # https://docs.github.com/en/actions/using-workflows/avoiding-duplication if: ${{ env.GCLOUD_SERVICE_KEY }} + - name: Checkout PyTorch Repo uses: actions/checkout@v4 with: repository: pytorch/pytorch path: pytorch - ref: ${{ inputs.torch-commit }} - submodules: recursive - if: ${{ inputs.torch-commit }} + - name: Checkout PyTorch/XLA Repo uses: actions/checkout@v4 with: path: pytorch/xla + + # Fetch and checkout to the pinned PyTorch commit. + - name: Checkout to PyTorch Commit Pin + working-directory: pytorch + shell: bash + env: + TORCH_COMMIT_FILE: ".torch_commit" + run: | + COMMIT=$(tail -1 "xla/$TORCH_COMMIT_FILE") + git fetch --no-recurse-submodules origin $COMMIT + git checkout --no-recurse-submodules FETCH_HEAD + git submodule update --init --recursive + - name: Fetch PyTorch/XLA packages uses: actions/download-artifact@v5 with: name: ${{ inputs.wheels-artifact }} path: /tmp/wheels/ if: ${{ inputs.wheels-artifact }} + - name: Install wheels shell: bash run: | diff --git a/.torch_commit b/.torch_commit new file mode 100644 index 000000000000..715a8bee47e6 --- /dev/null +++ b/.torch_commit @@ -0,0 +1,2 @@ +# 2025-09-17 +928ac57c2ab03f9f79376f9995553eea2e6f4ca8 \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b8d233c87002..348a91da3300 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -37,7 +37,7 @@ working with: Next, we need to clone the forked repos locally so that we can make changes. -On your Linuc machine, decide a directory as your workspace. Make sure that +On your Linux machine, decide a directory as your workspace. Make sure that this directory and all of its ancestors are publically readable. Then run the following commands on this machine: @@ -58,6 +58,28 @@ git clone --recursive git@github.com:/vision.git git clone --recursive git@github.com:/pytorch-xla.git pytorch/xla ``` +### Pinned PyTorch Version + +Since PR #9654, PyTorch/XLA started pinnning a PyTorch version. The pinned +commit can be found in `.torch_commit` file at the root directory. Note that +the pinned PyTorch version guarantees all PyTorch/XLA tests are passing +whenever the underlying PyTorch is compiled at that specific commit. Therefore, +specially for development, it's recommended that PyTorch is compiled at that +specific commit. Otherwise you might end up with all kinds of errors: from +build errors, to segmentation faults. So, make sure to check out that version: + +```bash +# Go to PyTorch directory. +cd $WORKSPACE_DIR/pytorch + +# Retrieve the PyTorch commit pin inside PyTorch/XLA directory. +# Note: it's located in the last line of `.torch_commit`. +COMMIT=$(tail -1 "xla/.torch_commit") + +# Create a branch (optional) and jump at that commit. +git checkout -b pin "$COMMIT" +``` + ### Setting up Remote Tracking From time to time, we'll need to bring our forked repos up to date with the diff --git a/infra/ansible/roles/build_srcs/tasks/main.yaml b/infra/ansible/roles/build_srcs/tasks/main.yaml index 3d6489b6162f..213370f665a0 100644 --- a/infra/ansible/roles/build_srcs/tasks/main.yaml +++ b/infra/ansible/roles/build_srcs/tasks/main.yaml @@ -1,8 +1,8 @@ - name: Read PyTorch pin - ansible.builtin.command: cat {{ (src_root, 'pytorch/xla/.torch_pin') | path_join }} + ansible.builtin.shell: | + cat {{ (src_root, 'pytorch/xla/.torch_pin') | path_join }} 2> /dev/null || + tail -1 {{ (src_root, 'pytorch/xla/.torch_commit') | path_join }} register: torch_pin - # Pin may not exist - ignore_errors: true - name: Checkout PyTorch pin # ansible.builtin.git wants to fetch the entire history, so check out the pin manually @@ -21,7 +21,6 @@ chdir: "{{ (src_root, 'pytorch') | path_join }}" args: executable: /bin/bash - when: torch_pin is succeeded - name: Build PyTorch ansible.builtin.command: diff --git a/scripts/build_developer.sh b/scripts/build_developer.sh index 792d19fadb06..beba52916e2b 100755 --- a/scripts/build_developer.sh +++ b/scripts/build_developer.sh @@ -56,6 +56,11 @@ if [ "$_BUILD_BASE" == "pytorch" ]; then # Change to the pytorch directory. cd $_SCRIPT_DIR/../.. + TORCH_COMMIT="xla/.torch_commit" + if [ -e "$TORCH_COMMIT" ]; then + git checkout $(tail -1 "$TORCH_COMMIT") + fi + # Remove any leftover old wheels and old installation. pip uninstall torch -y python3 setup.py clean diff --git a/scripts/update_deps.py b/scripts/update_deps.py index 6c369406d699..8e7897928dde 100755 --- a/scripts/update_deps.py +++ b/scripts/update_deps.py @@ -43,6 +43,9 @@ _JAX_PROJECT_URL = _JAX_INDEX_URL + 'jax/' _JAXLIB_PROJECT_URL = _JAX_INDEX_URL + 'jaxlib/' +_TORCH_COMMIT_FORMAT = "# %cs%n%H" +_TORCH_COMMIT_FILE = os.path.join(_PTXLA_DIR, ".torch_commit") + class PEP503Parser(HTMLParser): """Parser for PEP 503 simple repository API pages. @@ -478,6 +481,38 @@ def update_jax(use_latest: bool) -> bool: return success +def update_pytorch(use_latest: bool) -> bool: + clean_tmp_dir() + + torch_temp_dir = os.path.join(_TMP_DIR, "pytorch") + branch = "main" if use_latest else "viable/strict" + + cmd_clone = [ + "git", + "clone", + "--branch", + branch, + "--depth=1", + "https://github.com/pytorch/pytorch", + torch_temp_dir, + ] + os.system(" ".join(cmd_clone)) + + cmd_commit_show = [ + "git", + f"--git-dir={torch_temp_dir}/.git", + "show", + "--no-patch", + f"--pretty=format:\"{_TORCH_COMMIT_FORMAT}\"", + ] + commit = os.popen(" ".join(cmd_commit_show)).read().strip() + + with open(_TORCH_COMMIT_FILE, "w") as f: + f.write(commit) + + return True + + def main() -> None: logging.basicConfig(level=logging.INFO) @@ -496,7 +531,9 @@ def main() -> None: openxla_updated = update_openxla() libtpu_updated = update_libtpu() jax_updated = update_jax(use_latest=True) - if not (openxla_updated and libtpu_updated and jax_updated): + pytorch_updated = update_pytorch(use_latest=True) + if not (openxla_updated and libtpu_updated and jax_updated and + pytorch_updated): sys.exit(1) else: logger.info('Updating to latest stable versions...') @@ -513,7 +550,11 @@ def main() -> None: libtpu_updated = update_libtpu( target_date=jax_release_date.replace('-', '')) jax_updated = update_jax(use_latest=False) - if not (openxla_updated and libtpu_updated and jax_updated): + + pytorch_updated = update_pytorch(use_latest=False) + + if not (openxla_updated and libtpu_updated and jax_updated and + pytorch_updated): sys.exit(1) From 6ac4a7cdea634b2a9705c4809318b006ed9270bf Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 15:39:40 -0300 Subject: [PATCH 117/133] Accept conda channels' ToS when building the upstream docker image. (#9661) This PR should fix [the error](https://github.com/pytorch/xla/actions/runs/18161391171/job/51692786522) we've been getting when trying to build the upstream image. It simply adds to the `install_conda.sh` script what `conda` suggests us to do. --- .github/upstream/install_conda.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/upstream/install_conda.sh b/.github/upstream/install_conda.sh index 2908334f0923..8f9bd3675316 100644 --- a/.github/upstream/install_conda.sh +++ b/.github/upstream/install_conda.sh @@ -27,6 +27,10 @@ function install_and_setup_conda() { fi export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" + # Accept Conda channel ToS. + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r + conda update -y -n base conda conda install -y python=$PYTHON_VERSION From cc300f792e95e51db614a69c6e3c69eb2d3a8a48 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 16:14:07 -0300 Subject: [PATCH 118/133] Revert "Fix Terraform usage of `cuda_version`. (#9655)" (#9664) This reverts commit 3240166361abf32612c132811b9326210f23d073. PR #9655 actually introduced errors in Terraform. --- .../tpu-pytorch-releases/artifacts_builds.tf | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/infra/tpu-pytorch-releases/artifacts_builds.tf b/infra/tpu-pytorch-releases/artifacts_builds.tf index 3546b23692b5..b4e469b617be 100644 --- a/infra/tpu-pytorch-releases/artifacts_builds.tf +++ b/infra/tpu-pytorch-releases/artifacts_builds.tf @@ -94,7 +94,7 @@ locals { for b in local.nightly_builds : format("%s_%s%s%s", b.python_version, - "tpuvm", + b.accelerator == "tpu" ? "tpuvm" : format("cuda_%s", b.cuda_version), try(b.cxx11_abi == "0", false) ? "_precxx11" : "", try(b.bundle_libtpu == "1", false) ? "_libtpu" : "" ) => b @@ -106,7 +106,7 @@ locals { format("r%s_%s_%s%s%s", replace(b.package_version, "+", "_"), b.python_version, - "tpuvm", + b.accelerator == "tpu" ? "tpuvm" : format("cuda_%s", b.cuda_version), try(b.cxx11_abi == "0", false) ? "_precxx11" : "", try(b.bundle_libtpu == "1", false) ? "_libtpu" : "" ) => b @@ -136,13 +136,20 @@ module "nightly_builds" { ] description = join(" ", [ - "Builds nightly xla:nightly_${each.key}' TPU docker image and ", - "corresponding wheels for PyTorch/XLA.", + "Builds nightly xla:nightly_${each.key}' ${ + each.value.accelerator == "tpu" + ? "TPU" + : format("CUDA %s", each.value.cuda_version) + } docker image and corresponding wheels for PyTorch/XLA.", "Trigger managed by Terraform setup in", "infra/tpu-pytorch-releases/artifacts_builds.tf." ]) - wheels_dest = "${module.releases_storage_bucket.url}/wheels/tpuvm" + wheels_dest = "${module.releases_storage_bucket.url}/wheels/${ + each.value.accelerator == "tpu" + ? "tpuvm" + : "cuda/${each.value.cuda_version}" + }" wheels_srcs = ["/dist/*.whl"] build_args = { python_version = each.value.python_version @@ -176,13 +183,20 @@ module "versioned_builds" { image_tags = [each.key] description = join(" ", [ - "Builds official xla:${each.key}' TPU docker image and ", - "corresponding wheels for PyTorch/XLA.", + "Builds official xla:${each.key}' ${ + each.value.accelerator == "tpu" + ? "TPU" + : format("CUDA %s", each.value.cuda_version) + } docker image and corresponding wheels for PyTorch/XLA.", "Trigger managed by Terraform setup in", "infra/tpu-pytorch-releases/artifacts_builds.tf." ]) - wheels_dest = "${module.releases_storage_bucket.url}/wheels/tpuvm" + wheels_dest = "${module.releases_storage_bucket.url}/wheels/${ + each.value.accelerator == "tpu" + ? "tpuvm" + : "cuda/${each.value.cuda_version}" + }" wheels_srcs = ["/dist/*.whl"] # Pass docker build args to infra/ansible/Dockerfile, other than `ansible_vars`. build_args = { From 420adaa4feac9f9231ee75a92ba5d704f45014fe Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 16:17:11 -0300 Subject: [PATCH 119/133] Bump Python version of `ci-tpu-test-trigger` to 3.12. (#9665) `ci-tpu-test-trigger` was failing with: ``` Downloading https://us-python.pkg.dev/ml-oss-artifacts-published/jax/meson-python/meson_python-0.18.0-py3-none-any.whl ... meson-python: error: The package requires Python version >=3.11, running on 3.10.18 ``` --- infra/tpu-pytorch/test_triggers.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infra/tpu-pytorch/test_triggers.tf b/infra/tpu-pytorch/test_triggers.tf index 0ca1f96f5f04..9f146b7dbfb8 100644 --- a/infra/tpu-pytorch/test_triggers.tf +++ b/infra/tpu-pytorch/test_triggers.tf @@ -23,7 +23,7 @@ module "tpu_e2e_tests" { ]) build_args = { - python_version = "3.10" + python_version = "3.12" } ansible_vars = { From 13485454e88c52cdaa99acd54df0f3dec6afb0ad Mon Sep 17 00:00:00 2001 From: Hoomaaan <33916130+Hoomaaan@users.noreply.github.com> Date: Wed, 1 Oct 2025 14:24:52 -0700 Subject: [PATCH 120/133] fix(xla): convert group-local to global ranks in broadcast (#9657) Related AWS Neuron ticket: https://t.corp.amazon.com/V1941917988/overview broadcast was passing group-local ranks directly to xm.collective_broadcast() which expects global ranks, causing data curroption in single-member process groups TEST: ``` import os import torch import torch.distributed as dist import torch_xla as xla import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.runtime as xr def main(): dist.init_process_group(backend="xla") rank = dist.get_rank() world_size = dist.get_world_size() tp = dist.new_group(ranks=[rank]) tp_rank = dist.get_rank(group=tp) tp_size = dist.get_world_size(group=tp) print( f">>>> pid={os.getpid()}, rank={rank}\n" f">>> world_size={world_size}\n" f">>> tp_rank={tp_rank}, tp_size={tp_size}, tp_members={dist.get_process_group_ranks(tp)}" ) do_train, do_valid, do_test = 0.1, 0.2, 0.3 # breakpoint() flags = torch.tensor([do_train, do_valid, do_test], dtype=torch.float32, device='xla') # breakpoint() dist.broadcast(flags, rank, group=tp) print(f">>>> pid={os.getpid()}, rank={rank}\n" f">>> do_train={flags[0].item()}, do_valid={flags[1].item()}, do_test={flags[2].item()}\n" f">>> global_ordinal={xr.global_ordinal()}") if __name__ == "__main__": main() ``` Results after this fix: ``` torchrun --nproc-per-node=2 --nnodes=1 ./bug.py W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] ***************************************** W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] ***************************************** >>>> pid=1081679, rank=0 >>> world_size=2 >>> tp_rank=0, tp_size=1, tp_members=[0] >>>> pid=1081680, rank=1 >>> world_size=2 >>> tp_rank=0, tp_size=1, tp_members=[1] . . . 2.19.8089.0+8ab9f450/MODULE_10344927339446294134+e30acd3a/model.neff >>>> pid=1081680, rank=1 >>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896 >>> global_ordinal=1 >>>> pid=1081679, rank=0 >>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896 ``` Now both ranks have the correct values. Previously Rank1 was all zeros. --- test/test_torch_distributed_xla_backend.py | 87 ++++++++++++++++++++++ torch_xla/distributed/xla_backend.py | 7 +- 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index bb0dfd3efd7f..39576e9b011b 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -44,6 +44,18 @@ def patch_world(rank, size): yield +@contextlib.contextmanager +def patch_world_with_xla_runtime(rank, size): + assert isinstance(dist.group.WORLD, + torch_xla.distributed.xla_backend.ProcessGroupXla) + + with mock.patch.object(dist.group.WORLD, 'rank', return_value=rank), \ + mock.patch.object(dist.group.WORLD, 'size', return_value=size), \ + mock.patch.object(xr, 'global_ordinal', return_value=rank), \ + mock.patch.object(xr, 'world_size', return_value=size): + yield + + class XlaBackendTest(parameterized.TestCase): @classmethod @@ -328,6 +340,81 @@ def test_unimplemented_op(self, op): with self.assertRaises(NotImplementedError): getattr(pg_xla, op)(tensor) + @patch_world_with_xla_runtime(rank=0, size=2) + def test_broadcast_single_rank_group_rank0(self): + """Test broadcast in single-member process group for rank 0""" + device = torch_xla.device() + + with new_group_barrier_disabled(): + tp = dist.new_group(ranks=[0]) + + # Create flags tensor with initial values (simulating rank 0's values) + flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device) + + # Broadcast within the single-member group (should be a no-op but shouldn't crash) + dist.broadcast(flags, src=0, group=tp) + + # Values should remain unchanged since it's a single-member group + self.assertAlmostEqual(flags[0].item(), 0.1, places=5) + self.assertAlmostEqual(flags[1].item(), 0.2, places=5) + self.assertAlmostEqual(flags[2].item(), 0.3, places=5) + + # Verify the process group properties + self.assertEqual(dist.get_rank(group=tp), 0) + self.assertEqual(dist.get_world_size(group=tp), 1) + + @patch_world_with_xla_runtime(rank=1, size=2) + def test_broadcast_single_rank_group_rank1(self): + """Test broadcast in single-member process group for rank 1""" + device = torch_xla.device() + + with new_group_barrier_disabled(): + tp = dist.new_group(ranks=[1]) + + # Create flags tensor with initial values (simulating rank 1's values) + flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device) + + # Broadcast within the single-member group (should be a no-op but shouldn't crash) + dist.broadcast(flags, src=1, group=tp) + + # Values should remain unchanged since it's a single-member group + self.assertAlmostEqual(flags[0].item(), 0.1, places=5) + self.assertAlmostEqual(flags[1].item(), 0.2, places=5) + self.assertAlmostEqual(flags[2].item(), 0.3, places=5) + + # Verify the process group properties + self.assertEqual(dist.get_rank(group=tp), + 0) # Local rank in single-member group is 0 + self.assertEqual(dist.get_world_size(group=tp), 1) + + @patch_world_with_xla_runtime(rank=0, size=2) + def test_broadcast_global_rank_conversion_single_member(self): + """Test that global rank conversion works correctly for single-member groups""" + device = torch_xla.device() + + # Create single-member group for rank 0 + with new_group_barrier_disabled(): + tp = dist.new_group(ranks=[0]) + + flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device) + + # Get the ProcessGroupXla instance to test directly + self.assertIsInstance(tp, torch_xla.distributed.xla_backend.ProcessGroupXla) + + # Test broadcast options - local rank 0 should map to global rank 0 + opts = dist.BroadcastOptions() + opts.rootRank = 0 + opts.rootTensor = 0 + + # This should work without variable name errors + work = tp.broadcast([flags], opts) + self.assertIsNotNone(work) + + # Values should be preserved + self.assertAlmostEqual(flags[0].item(), 0.1, places=5) + self.assertAlmostEqual(flags[1].item(), 0.2, places=5) + self.assertAlmostEqual(flags[2].item(), 0.3, places=5) + if __name__ == '__main__': if xr.device_type() != 'CPU': diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 3c7848d6e327..ae3868ac3d7e 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -131,9 +131,14 @@ def allgather_coalesced(self, output_tensors_list, input_tensors, opts=None): # Call site: # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1129 def broadcast(self, tensors, opts): + import torch.distributed as dist + root_tensor = tensors[opts.rootTensor] + # Convert group local rank to global rank for xla collectives + group_source = opts.rootRank + global_src = dist.get_global_rank(self, group_source) xm.collective_broadcast([root_tensor], - opts.rootRank, + global_src, groups=self._mesh, pin_layout=False) From 1ab678799f2a9a448d202df0262a9cca9ea97bb1 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 1 Oct 2025 20:44:05 -0300 Subject: [PATCH 121/133] Accept conda channels' ToS with environment variable. (#9666) Follow-up: #9661 Apparently, we need to accept the ToS of those same channels, again ([link](https://github.com/pytorch/xla/actions/runs/18172128014/job/51729124142)). Instead, I'm using the `CONDA_PLUGINS_AUTO_ACCEPT_TOS` environment variable, documented [here](https://github.com/pytorch/pytorch/issues/158438#issuecomment-3084935777). --- .github/upstream/install_conda.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/upstream/install_conda.sh b/.github/upstream/install_conda.sh index 8f9bd3675316..e1085d65e697 100644 --- a/.github/upstream/install_conda.sh +++ b/.github/upstream/install_conda.sh @@ -27,9 +27,9 @@ function install_and_setup_conda() { fi export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" - # Accept Conda channel ToS. - conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main - conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r + # Accept Conda channels' ToS automatically. + # Ref: https://github.com/pytorch/pytorch/issues/158438#issuecomment-3084935777 + export CONDA_PLUGINS_AUTO_ACCEPT_TOS="yes" conda update -y -n base conda conda install -y python=$PYTHON_VERSION From 2a9138a26ee257fef05310ad3fecf7c55fe80d73 Mon Sep 17 00:00:00 2001 From: Sungjoon Shon Date: Fri, 3 Oct 2025 14:06:16 -0400 Subject: [PATCH 122/133] mul: remove opmath cast sequence (#9663) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the explicit opmath-driven cast chain (bf16→f32→bf16, etc.) from `mul`. The op now executes in the dtype chosen by standard dtype promotion, without inserting unconditional upcast/downcast steps. But leave its functionality for future usage. --- test/test_operations_hlo.py | 16 ++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 1 - 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 9b254d1464b2..a7f34d6efb65 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -67,6 +67,22 @@ def test_dropout_by_u8_mask(self): hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b]) assert 'u8' in hlo_text + def test_bfloat16_mul_not_upcast(self): + a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') + b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') + c = a * b + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c]) + # Check that the output is not upcasted to float32 + assert 'f32' not in hlo_text + + def test_bfloat16_float32_mul_upcast(self): + a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') + b = torch.rand(5, 5, dtype=torch.float32).to('xla') + c = a * b + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c]) + # Check that the output is upcasted to float32 + assert 'f32' in hlo_text + if __name__ == '__main__': torch.set_default_dtype(torch.float32) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b6a8484a2505..d48251056c86 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2535,7 +2535,6 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, .add_input(self) .add_input(other) .cast_inputs_to_common_dtype() - .use_opmathtype_for_compute() .run(); } From d36ded218a492cecbfcf963fd24911a7b639daa8 Mon Sep 17 00:00:00 2001 From: Het Shah Date: Fri, 18 Jul 2025 15:59:45 -0400 Subject: [PATCH 123/133] [Experimental] Add initial implementation of GSPMD->Shardy pass within PyTorch/XLA (#1) Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things: - Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]). - Converts the new GSPMD module with the V2 annotations into a Shardy module. --- torch_xla/csrc/init_python_bindings.cpp | 7 ++ torch_xla/csrc/runtime/BUILD | 1 + .../csrc/runtime/pjrt_computation_client.cpp | 4 ++ torch_xla/csrc/runtime/stablehlo_helper.cpp | 10 +++ torch_xla/csrc/runtime/stablehlo_helper.h | 2 + torch_xla/csrc/xla_sharding_util.cpp | 17 +++++ torch_xla/csrc/xla_sharding_util.h | 5 ++ torch_xla/distributed/spmd/xla_sharding.py | 67 ++++++++++++++++++- 8 files changed, 112 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a52ecc8124e7..37150718c027 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1571,12 +1571,19 @@ void InitXlaModuleBindings(py::module m) { // Define the _XLAC.OpSharding class. PythonScope>(m, "OpSharding") + // Constructor for V1 shardings .def_init([](const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, int sharding_type) { return ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)); + }) + // Constructor for V2 shardings. + .def_init([](const py::list& dims, const py::list& reshape_dims, + const py::list& transpose_perm) { + return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims, + transpose_perm); }); // Define the _XLAC.PjRtPlugin class. diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 4f0f3bf384ed..5ca04a2965da 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -358,6 +358,7 @@ cc_library( "@xla//xla/mlir_hlo:all_passes", "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import", ], ) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 280b50964d82..d0b552613d13 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -14,6 +14,7 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/pjrt_registry.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" @@ -638,6 +639,9 @@ std::vector PjRtComputationClient::Compile( mlir::ModuleOp mlir_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module); + if (runtime::sys_util::GetEnvBool("CONVERT_SHLO_TO_SHARDY", false)) { + ConvertStableHloToSdy(&mlir_module); + } executable = util::RaisePythonValueErrorOnFailure([&] { return fake_xla_compile_ ? fake_xla_compile_() diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cpp b/torch_xla/csrc/runtime/stablehlo_helper.cpp index 08856778fd88..857ec5809175 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.cpp +++ b/torch_xla/csrc/runtime/stablehlo_helper.cpp @@ -18,6 +18,7 @@ #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" namespace torch_xla { @@ -89,6 +90,7 @@ static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module, torch_xla::runtime::CreateRemoveXlaMarkTensorOpsPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); + if (!mlir::succeeded(pm.run(*mlir_module))) { return absl::Status( absl::StatusCode::kInternal, @@ -111,6 +113,14 @@ void ConvertHloToStableHlo(const xla::HloModuleProto* proto, << getHloModuleStr(proto); } +void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module) { + mlir::PassManager pm(mlir_module->getContext()); + xla::sdy::addStablehloImportPipeline(pm, false, false); + if (!mlir::succeeded(pm.run(*mlir_module))) { + XLA_ERROR() << "StableHLO -> SDY conversion failed.\n"; + } +} + std::string hloToStablehlo(const xla::HloModuleProto* proto, bool emit_bytecode) { mlir::MLIRContext context; diff --git a/torch_xla/csrc/runtime/stablehlo_helper.h b/torch_xla/csrc/runtime/stablehlo_helper.h index bdef7b975400..2298ecfb2d18 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.h +++ b/torch_xla/csrc/runtime/stablehlo_helper.h @@ -13,6 +13,8 @@ namespace torch_xla { std::string hloToStablehlo(const xla::HloModuleProto* proto, bool emit_bytecode); +void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module); + void ConvertHloToStableHlo(const xla::HloModuleProto* proto, mlir::ModuleOp* mlir_module); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index b3f6346020d8..4a9b372899e6 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -218,6 +218,23 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, return xla::protobuf_util::HaveSameSerialization(a, b); } +xla::OpSharding ShardingUtil::CreateIotaOpSharding( + const py::list& dims, const py::list& reshape_dims, + const py::list& transpose_perm) { + auto dims_vec = dims.cast>(); + auto reshape_dims_vec = reshape_dims.cast>(); + auto transpose_perm_vec = transpose_perm.cast>(); + std::vector subgroup_types; + if (dims_vec.size() > transpose_perm.size()) { + subgroup_types.push_back(xla::OpSharding::REPLICATED); + } + return xla::HloSharding::Subgroup( + xla::TileAssignment(dims_vec, reshape_dims_vec, + transpose_perm_vec), + subgroup_types) + .ToProto(); +} + xla::OpSharding ShardingUtil::CreateOpSharding( const py::list& tile_assignment, const py::list& group_assignment, const py::list& replication_groups, ShardingType sharding_type) { diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 8b8b98653b2f..0bcceb905611 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -51,6 +51,11 @@ class ShardingUtil { const py::list& group_assignment, const py::list& replication_groups, ShardingType sharding_type); + // Creates an xla::OpSharding for TILED and PARTIAL types using the + // HloShardingV2 system. + static xla::OpSharding CreateIotaOpSharding(const py::list& dims, + const py::list& reshape_dims, + const py::list& transpose_perm); // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index be6daca582ed..acc54b6ec08a 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -1,6 +1,7 @@ import collections from collections.abc import Generator, MutableMapping import math +import os from collections import OrderedDict, defaultdict from dataclasses import dataclass, field import torch @@ -118,9 +119,18 @@ def get_axis_name_idx(self, name: str) -> int: return None return self.axis_names.index(name) + def _validate_translated_partition_spec(self, partition_spec: tuple): + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." + assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." + @functools.lru_cache(maxsize=None) def _get_op_sharding_args(self, partition_spec: PartitionSpec): partition_spec = _translate_named_partition_spec(self, partition_spec) + self._validate_translated_partition_spec(partition_spec) flat_specs = np.hstack([d for d in partition_spec]) specs = [d for d in flat_specs if d is not None] assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ @@ -142,6 +152,57 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec): sharding_type = int(sharding_type) return tile_assignment, group_assignment, replication_groups, sharding_type + @functools.lru_cache(maxsize=None) + def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec): + """ + Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec. + """ + partition_spec = _translate_named_partition_spec(self, partition_spec) + self._validate_translated_partition_spec(partition_spec) + + dims = [] + used_axes = OrderedDict() + for axis in partition_spec: + if isinstance(axis, tuple): + dim_size = 1 + for i in axis: + assert i is not None, "None not allowed within tuple" + dim_size *= self.mesh_shape[i] + used_axes[i] = True + dims.append(dim_size) + elif axis is not None: + assert isinstance(axis, int), "Axis must be an int or a tuple of ints" + dims.append(self.mesh_shape[axis]) + used_axes[axis] = True + else: + # Replicated mesh axis + dims.append(1) + + transpose_perm = [k for k in used_axes.keys()] + for i in range(len(self.mesh_shape)): + if i not in used_axes: + dims.append(self.mesh_shape[i]) + transpose_perm.append(i) + reshape_dims = list(self.mesh_shape) + + return dims, reshape_dims, transpose_perm + + @functools.lru_cache(maxsize=None) + def get_op_sharding_v2( + self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding: + """ + Return the OpSharding for the given partition spec using V2 annotations. + """ + if len(partition_spec) == 0: + return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED) + sharding_type = _get_sharding_type(partition_spec, self.size()) + if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL): + return torch_xla._XLAC.OpSharding([], [], [0], sharding_type) + + dims, reshape_dims, transpose_perm = self._get_op_sharding_args_v2( + partition_spec) + return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm) + @functools.lru_cache(maxsize=None) def get_op_sharding( self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding: @@ -157,6 +218,7 @@ def get_op_sharding( tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args( partition_spec) + return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment, replication_groups, sharding_type) @@ -654,7 +716,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, t.shard_(NamedSharding(jmesh, P(*partition_spec))) return t - op_sharding = mesh.get_op_sharding(partition_spec) + if os.environ.get('CONVERT_SHLO_TO_SHARDY', False): + op_sharding = mesh.get_op_sharding_v2(partition_spec) + else: + op_sharding = mesh.get_op_sharding(partition_spec) annotate_func = torch_xla._XLAC._xla_mark_sharding annotate_func(unwrap_sharded_tensor(t), op_sharding) # Pass mesh and partition spec information for DTensor compatibility From 036321ae164af9e88353876179b4d7a9faeb1211 Mon Sep 17 00:00:00 2001 From: Jonathan Azpur Date: Fri, 25 Jul 2025 20:40:32 +0000 Subject: [PATCH 124/133] Create job to build torch-xla wheel and publish to tt-pypi --- .github/workflows/_build_torch_xla_3.10.yml | 76 +++++++++++++++++++++ .github/workflows/_publish_torch_xla.yml | 63 +++++++++++++++++ .github/workflows/build_and_publish.yml | 33 +++++++++ 3 files changed, 172 insertions(+) create mode 100644 .github/workflows/_build_torch_xla_3.10.yml create mode 100644 .github/workflows/_publish_torch_xla.yml create mode 100644 .github/workflows/build_and_publish.yml diff --git a/.github/workflows/_build_torch_xla_3.10.yml b/.github/workflows/_build_torch_xla_3.10.yml new file mode 100644 index 000000000000..8499647fc8f7 --- /dev/null +++ b/.github/workflows/_build_torch_xla_3.10.yml @@ -0,0 +1,76 @@ +name: build-torch-xla +on: + workflow_call: + inputs: + torch_version: + description: 'Torch version to build (default: 2.7.0)' + required: false + type: string + default: '2.7.0' + outputs: + artifact_name: + description: 'Name of uploaded wheels artifact' + value: ${{ jobs.build-wheels.outputs.artifact_name }} + workflow_dispatch: +jobs: + build-wheels: + runs-on: ubuntu-latest + env: + ARTIFACT_NAME: install-artifact-torch-xla-release + GIT_VERSIONED_XLA_BUILD: 1 + container: + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu + options: --user root + outputs: + artifact_name: ${{ steps.set_upload_name.outputs.artifact_name }} + steps: + - name: "Build Torch/XLA wheel" + id: build_wheels + run: | + cmake --version + apt-get update && apt-get install -y curl git build-essential + + # Clean up any existing pyenv installation + rm -rf $HOME/.pyenv + + curl https://pyenv.run | bash + export PATH="$HOME/.pyenv/bin:$PATH" + eval "$(pyenv init -)" + pyenv install 3.10.12 + pyenv global 3.10.12 + ln -sf $HOME/.pyenv/versions/3.10.12/bin/python3.10 /usr/local/bin/python3.10 + + # Install essential packages for Python 3.10 + python3.10 -m pip install --upgrade pip + python3.10 -m pip install pyyaml setuptools wheel numpy typing_extensions requests + + cd /tmp + git clone --recursive --branch v${{ inputs.torch_version || '2.7.0' }} https://github.com/pytorch/pytorch.git + cd pytorch/ + git clone --recursive https://github.com/tenstorrent/pytorch-xla.git xla + + # copy pre-built wheels from cache + python3.10 setup.py bdist_wheel + python3.10 setup.py develop + + # Build PyTorch/XLA + cd xla/ + python3.10 setup.py bdist_wheel + + # Collect wheels + mkdir -p /dist + cp dist/*.whl /dist/ + + # Clean up any existing pyenv installation + rm -rf $HOME/.pyenv + + - name: "Upload Wheels Artifact" + id: upload + uses: actions/upload-artifact@v4 + with: + name: ${{ env.ARTIFACT_NAME }} + path: /dist/*.whl + + - name: Set artifact name output + id: set_upload_name + run: echo "artifact_name=${{ env.ARTIFACT_NAME }}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_publish_torch_xla.yml b/.github/workflows/_publish_torch_xla.yml new file mode 100644 index 000000000000..07f0785bed2f --- /dev/null +++ b/.github/workflows/_publish_torch_xla.yml @@ -0,0 +1,63 @@ +name: publish-wheel +on: + workflow_call: + inputs: + artifact_name: + required: true + type: string + description: 'Name of the artifact containing the wheel' + +jobs: + + publish-wheels: + name: "Publish wheels to internal PyPI" + runs-on: ubuntu-latest + permissions: + id-token: write + steps: + - name: Validate inputs + run: | + if [ -z "${{ inputs.artifact_name }}" ]; then + echo "ERROR: artifact_name input is empty or not provided!" + exit 1 + fi + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Configure AWS Credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.PYPI_ROLE }} + aws-region: ${{ secrets.PYPI_REGION }} + + - name: Install s3pypi + run: | + pip install s3pypi + + - name: Download wheel artifacts + uses: actions/download-artifact@v4 + with: + name: ${{ inputs.artifact_name }} + path: ./dist + + - name: Publish wheels to internal PyPI + run: | + wheel_count=$(find ./dist -type f -name "torch_xla*.whl" | wc -l) + if [ "$wheel_count" -ne 1 ]; then + echo "ERROR: Expected exactly 1 wheel file, but found $wheel_count!" + exit 1 + fi + + wheel_file=$(find ./dist -type f -name "torch_xla*.whl" -exec realpath {} \;) + wheel_basename=$(basename "$wheel_file") + echo "Wheel file found, publishing $wheel_basename to PyPi server" + + s3pypi upload "$wheel_file" --bucket ${{ secrets.PYPI_BUCKET }} --put-root-index --force + if [ $? -ne 0 ]; then + echo "ERROR: Failed to upload $wheel_basename to S3 PyPI" + exit 1 + fi + echo "Successfully uploaded $wheel_basename" diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml new file mode 100644 index 000000000000..80394eae9cf4 --- /dev/null +++ b/.github/workflows/build_and_publish.yml @@ -0,0 +1,33 @@ +name: Build and Publish PyTorch/XLA +on: + push: + branches: + - master + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + check_code_changes: + name: Check Code Changes + uses: ./.github/workflows/_check_code_changes.yml + with: + event_name: ${{ github.event_name }} + base_sha: ${{ github.event.before }} + head_sha: ${{ github.sha }} + + build-torch-xla: + name: "Build PyTorch/XLA for Python 3.10" + if: needs.check_code_changes.outputs.has_code_changes == 'true' + uses: ./.github/workflows/_build_torch_xla_3.10.yml + needs: check_code_changes + + publish-torch-xla: + name: "Publish PyTorch/XLA" + uses: ./.github/workflows/_publish_torch_xla.yml + needs: build-torch-xla + secrets: inherit + with: + artifact_name: ${{ needs.build-torch-xla.outputs.artifact_name }} From 58da15c84f155cd0c77ed3d28d4f674c67877156 Mon Sep 17 00:00:00 2001 From: Jonathan Azpur Date: Tue, 29 Jul 2025 16:20:52 -0400 Subject: [PATCH 125/133] Add permision from caller workflow to enable job (#4) --- .github/workflows/build_and_publish.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml index 80394eae9cf4..a3e9fe57c93e 100644 --- a/.github/workflows/build_and_publish.yml +++ b/.github/workflows/build_and_publish.yml @@ -29,5 +29,7 @@ jobs: uses: ./.github/workflows/_publish_torch_xla.yml needs: build-torch-xla secrets: inherit + permissions: + id-token: write with: artifact_name: ${{ needs.build-torch-xla.outputs.artifact_name }} From 24bb34c8621f2c70d702fb08b28c81f8ed103796 Mon Sep 17 00:00:00 2001 From: Sungjoon Shon Date: Sat, 2 Aug 2025 13:02:22 -0400 Subject: [PATCH 126/133] Add V2 sharding support and improve partition spec handling for multichip training (#2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add V2 sharding support and improve partition spec handling for multi-chip training These changes are required to support multi-chip training for real models on the torch-xla side. - Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings. - Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy. - Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec. The new logic now correctly handles cases that were previously unsupported: case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None) -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3] case 2: mesh_shape=(2,1,1,1), partition_spec=(0,) Ã-> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3] case 3: mesh_shape=(2,4), partition_spec=(0,None) -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1] * Fix formatting according to Torch-XLA style guide --------- Co-authored-by: Het Shah --- torch_xla/csrc/init_python_bindings.cpp | 25 +++++++--------- torch_xla/csrc/xla_sharding_util.cpp | 19 ++++++++++++ torch_xla/csrc/xla_sharding_util.h | 3 ++ torch_xla/distributed/spmd/xla_sharding.py | 35 +++++++++++++++++----- 4 files changed, 61 insertions(+), 21 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 37150718c027..45da5ee57614 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1534,25 +1534,22 @@ void InitXlaModuleBindings(py::module m) { const py::list& replication_groups, int sharding_type, bool minibatch) { xla::Shape global_shape = - CreateComputationShapeFromTensor(tensor, nullptr); - if (minibatch) { - XLA_ASSIGN_OR_THROW( - runtime::ComputationClient * absl_nonnull const client, - runtime::GetComputationClient()); - int num_local_devices = client->GetLocalDevices().size(); - int num_global_devices = client->GetAllDevices().size(); - XLA_CHECK(tile_assignment.size() == num_global_devices) - << "Minibatch sharding only supports sharding along the batch " - "dimension"; - int batch_dim_shape = - tensor.sizes()[0] * num_global_devices / num_local_devices; - global_shape.set_dimensions(0, batch_dim_shape); - } + ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch); return std::make_shared( ShardingUtil::CreateOpSharding( tile_assignment, group_assignment, replication_groups, ShardingUtil::ShardingType(sharding_type)), global_shape, minibatch); + }) + .def_init([](at::Tensor tensor, const py::list& dims, + const py::list& reshape_dims, const py::list& transpose_perm, + bool minibatch) { + xla::Shape global_shape = + ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch); + return std::make_shared( + ShardingUtil::CreateIotaOpSharding(dims, reshape_dims, + transpose_perm), + global_shape, minibatch); }); // Define the _XLAC.IrValue class. diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 4a9b372899e6..e83318cf5675 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -889,4 +889,23 @@ bool ShardingUtil::GetAutoSharding() { } return use_auto_sharding; } + +xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor, + bool minibatch) { + xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr); + if (minibatch) { + int num_local_devices = + runtime::GetComputationClientOrDie()->GetLocalDevices().size(); + int num_global_devices = + runtime::GetComputationClientOrDie()->GetAllDevices().size(); + XLA_CHECK(tile_assignment.size() == num_global_devices) + << "Minibatch sharding only supports sharding along the batch " + "dimension"; + int batch_dim_shape = + tensor.sizes()[0] * num_global_devices / num_local_devices; + global_shape.set_dimensions(0, batch_dim_shape); + } + return global_shape; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 0bcceb905611..2cae399e2931 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -155,6 +155,9 @@ class ShardingUtil { static void SetAutoSharding(); static bool GetAutoSharding(); + + static xla::Shape GetAdjustedGlobalShape(const at::Tensor& tensor, + bool minibatch); }; } // namespace torch_xla diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index acc54b6ec08a..fe82ed47fcaa 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -154,12 +154,10 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec): @functools.lru_cache(maxsize=None) def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec): - """ - Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec. - """ partition_spec = _translate_named_partition_spec(self, partition_spec) self._validate_translated_partition_spec(partition_spec) + # 1. Calculate the initial part of dims based on the partition_spec. dims = [] used_axes = OrderedDict() for axis in partition_spec: @@ -175,14 +173,22 @@ def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec): dims.append(self.mesh_shape[axis]) used_axes[axis] = True else: - # Replicated mesh axis dims.append(1) - transpose_perm = [k for k in used_axes.keys()] + # 2. If the product of dims is less than the total number of devices, + # append the sizes of the unused mesh axes. + if math.prod(dims) < math.prod(self.mesh_shape): + for i in range(len(self.mesh_shape)): + if i not in used_axes: + dims.append(self.mesh_shape[i]) + + # 3. Calculate transpose_perm (sharded axes first, then unused axes). + transpose_perm = list(used_axes.keys()) for i in range(len(self.mesh_shape)): if i not in used_axes: - dims.append(self.mesh_shape[i]) transpose_perm.append(i) + + # 4. reshape_dims is always the physical mesh shape. reshape_dims = list(self.mesh_shape) return dims, reshape_dims, transpose_perm @@ -592,6 +598,11 @@ def _mark_manual_sharding( return wrap_as_sharded_tensor(t) +def _use_shlo_to_shardy() -> bool: + return os.environ.get("CONVERT_SHLO_TO_SHARDY", + "").lower() in ("1", "true", "yes") + + def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor], partition_spec: PartitionSpec, *, @@ -716,7 +727,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, t.shard_(NamedSharding(jmesh, P(*partition_spec))) return t - if os.environ.get('CONVERT_SHLO_TO_SHARDY', False): + if _use_shlo_to_shardy(): op_sharding = mesh.get_op_sharding_v2(partition_spec) else: op_sharding = mesh.get_op_sharding(partition_spec) @@ -898,6 +909,9 @@ def __post_init__(self): self._group_assignment, self._replication_groups = _get_group_assignment( self._sharding_type, tile_assignment, len(partition_spec), replicate_dims) + if _use_shlo_to_shardy(): + self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2( + partition_spec) def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: """ @@ -906,6 +920,13 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: """ if not self.can_apply(t): return None + + if _use_shlo_to_shardy(): + # Convert to Shardy spec if the environment variable is set. + return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims, + self.transpose_dims, + self.minibatch) + return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment, self._group_assignment, self._replication_groups, From 686cb76105512f3f3b263daccf29fc01a28aa50e Mon Sep 17 00:00:00 2001 From: Sungjoon Shon Date: Mon, 11 Aug 2025 21:06:44 +0000 Subject: [PATCH 127/133] feat: add support for custom compile options in torch_xla.compile and PJRT backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code. Key changes: * Python API * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values). * Added torch_xla.set_custom_compile_options() utility for setting compile options globally. * Added internal binding _XLAC._set_custom_compile_options(). * C++ Runtime * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient. * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation. * Options are stringified before being passed to XLA for compatibility. Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows. --- torch_xla/csrc/init_python_bindings.cpp | 10 ++++++++++ torch_xla/csrc/runtime/computation_client.h | 3 +++ .../csrc/runtime/ifrt_computation_client.h | 5 +++++ .../csrc/runtime/pjrt_computation_client.cpp | 14 +++++++++++++ .../csrc/runtime/pjrt_computation_client.h | 4 ++++ torch_xla/torch_xla.py | 20 ++++++++++++++++++- 6 files changed, 55 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 45da5ee57614..58836598897a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -3310,6 +3310,16 @@ void InitXlaModuleBindings(py::module m) { XLA_ERROR() << "Could not get the buffer pointer for XLATensor " "without a data handle or an IR."; }) + .def("_set_custom_compile_options", + [](const py::dict& compile_options) { + std::unordered_map options; + for (const auto& item : compile_options) { + std::string key = item.first.cast(); + options[key] = py::str(item.second).cast(); + } + runtime::GetComputationClientOrDie()->SetCustomCompileOptions( + options); + }) .def( // from an XLA tensor to a PyCapsule. // When consuming the PyCapsule, we should synchronize diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 79ff199eb2ff..beee8f4b90dc 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -446,6 +446,9 @@ class ComputationClient { // after the last ':' character of the device string. static int64_t GetDeviceOrdinal(const std::string& device); + virtual void SetCustomCompileOptions( + const std::unordered_map& options) = 0; + protected: static constexpr auto spmd_device_str = "SPMD:0"; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 8b45922c397f..6e7875457105 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -172,6 +172,11 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } + void SetCustomCompileOptions( + const std::unordered_map& options) override { + XLA_ERROR() << __FUNCTION__ << " not implemented"; + } + // Creates a new instance of IfrtComputationClient and initializes it. static absl::StatusOr> Create(); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index d0b552613d13..b015b8555f66 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -555,6 +555,10 @@ std::vector PjRtComputationClient::Compile( for (auto& instance : instances) { xla::CompileOptions compile_options; + for (auto& option : custom_compile_options_) { + compile_options.env_option_overrides.push_back( + {option.first, option.second}); + } if (enable_cm_in_mp) { compile_options.executable_build_options.set_use_spmd_partitioning(true); compile_options.env_option_overrides.push_back( @@ -562,6 +566,7 @@ std::vector PjRtComputationClient::Compile( compile_options.env_option_overrides.push_back( {"xla_tpu_decompose_einsum_reduce_scatter", true}); } + if (instance.is_sharded) { // TODO(yeounoh) multi-host, multi-slice configurations compile_options.executable_build_options.set_use_spmd_partitioning(true); @@ -1056,5 +1061,14 @@ void PjRtComputationClient::OnReadyCallback( [callback](absl::Status unused) { callback(); }); } +void PjRtComputationClient::SetCustomCompileOptions( + const std::unordered_map& options) { + // Stringfy values + custom_compile_options_.clear(); + for (const auto& [key, value] : options) { + custom_compile_options_[key] = value; + } +} + } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index d550f1cce0cb..db9798bd8214 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -165,6 +165,9 @@ class PjRtComputationClient : public ComputationClient { void OnReadyCallback(DataPtr data, const std::function& callback) override; + void SetCustomCompileOptions( + const std::unordered_map& options) override; + // Creates a new instance of PjRtComputationClient and initializes it. static absl::StatusOr> Create(); @@ -197,6 +200,7 @@ class PjRtComputationClient : public ComputationClient { // If not nullptr, invoke this instead of the actual XLA compilation. Used // only for testing. std::function fake_xla_compile_ = nullptr; + std::unordered_map custom_compile_options_; xla::PjRtDevice* StringToPjRtDevice(const std::string& device); diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 9062d6a9ef21..f0ecb6dcf687 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -116,6 +116,7 @@ def compile( full_graph: Optional[bool] = False, name: Optional[str] = None, max_different_graphs: Optional[int] = None, + custom_compile_options: Optional[dict] = None, ): """ Optimizes given model/function using torch_xla's LazyTensor tracing mode. @@ -136,6 +137,8 @@ def compile( max_different_graphs (Optional[int]): number of different traced graphs of the given model/function that we are allowed to have. An error will be raised in case this limit is exceeded. + custom_compile_options (Optional[dict]): A dictionary of custom compile options to be set. + The keys are strings and the values can be of type bool, float, int, or str. Example:: @@ -214,7 +217,8 @@ def _compile(): sync() torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status) torch_xla._XLAC._set_current_graph_name(saved_current_graph_name) - + if custom_compile_options is not None and len(custom_compile_options) > 0: + torch_xla._XLAC._set_custom_compile_options(custom_compile_options) return _compile() if f is None else _compile()(f) @@ -264,3 +268,17 @@ def launch( fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args) else: xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method) + +def set_custom_compile_options( + options: Optional[dict] = None, +): + """Sets custom compile options for the XLA compilation. + + Args: + options: A dictionary of custom compile options to be set. + The keys are strings and the values can be of type bool, float, int, or str. + """ + if options is None: + options = {} + torch_xla._XLAC._set_custom_compile_options(options) + \ No newline at end of file From 5dfbb4dbd79a5104811d7f97966ae1106e4907c1 Mon Sep 17 00:00:00 2001 From: Het Shah Date: Wed, 3 Sep 2025 12:23:04 -0400 Subject: [PATCH 128/133] Change V2 sharding spec algorithm + Fix tensor sharding spec visualization (#7) This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs. See https://github.com/pytorch/xla/pull/9541 for the upstream PR discussion and additional context. * Add some tests and reviewer suggestions. Will update V2 op sharding logic in a later commit soon. * New implementation (WIP) * Fix new implementation * Fix visualize_tensor_sharding function for V2 shardings --- test/spmd/test_spmd_debugging.py | 72 ++++++++++++++++ test/spmd/test_xla_sharding.py | 99 ++++++++++++++-------- torch_xla/csrc/init_python_bindings.cpp | 43 ++++++++-- torch_xla/csrc/xla_sharding_util.cpp | 13 +-- torch_xla/csrc/xla_sharding_util.h | 3 +- torch_xla/distributed/spmd/debugging.py | 25 +++++- torch_xla/distributed/spmd/xla_sharding.py | 96 +++++++++++---------- torch_xla/torch_xla.py | 7 +- 8 files changed, 262 insertions(+), 96 deletions(-) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 2f126f00955e..4d5acc008b20 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -17,6 +17,7 @@ import torch_xla.distributed.spmd as xs from torch_xla.distributed.spmd import XLAShardedTensor from torch_xla.distributed.spmd import Mesh +from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str import test_xla_sharding_base @@ -822,6 +823,77 @@ def test_multi_host_replicated_cpu(self): fake_output = fake_capture.get() assert output == fake_output + +class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + os.environ["CONVERT_SHLO_TO_SHARDY"] = "1" + + def run_test(self): + mesh = self._get_mesh(self.device_mesh_shape) + t = torch.randn(self.tensor_shape).to(torch_xla.device()) + xs.mark_sharding(t, mesh, self.partition_spec) + actual_str = construct_v1_sharding_str(t) + self.assertEqual(self.expected_str, actual_str) + + def test_tiled_sharding(self): + self.device_mesh_shape = (1, self.n_devices) + self.tensor_shape = (1, 128) + self.partition_spec = (0, 1) + self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( + [str(i) for i in range(self.n_devices)])) + self.run_test() + + @unittest.skipIf(xr.global_runtime_device_count() < 2, + f"Requires at least 2 devices.") + def test_tupled_tiled_sharding(self): + self.device_mesh_shape = (2, self.n_devices // 2) + self.tensor_shape = (16,) + self.partition_spec = ((0, 1),) + self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join( + str(x) for x in range(self.n_devices))) + self.run_test() + + def test_replicated_sharding(self): + self.device_mesh_shape = (1, self.n_devices) + self.tensor_shape = (4, 4) + self.partition_spec = (None, None) + self.expected_str = '{replicated}' + self.run_test() + + @unittest.skipIf(xr.global_runtime_device_count() < 4, + f"Requires at least 4 devices.") + def test_partial_replication_sharding(self): + self.device_mesh_shape = (2, self.n_devices // 2) + self.tensor_shape = (4, 4) + self.partition_spec = (0, None) + self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % ( + self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))) + self.run_test() + + @unittest.skipIf(xr.global_runtime_device_count() < 4, + f"Requires at least 4 devices.") + def test_tupled_partial_replication_sharding(self): + self.device_mesh_shape = (1, 2, self.n_devices // 2) + self.tensor_shape = (16, 16) + self.partition_spec = ((0, 1), None) + self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % ( + self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))) + self.run_test() + + def test_tupled_partial_replication_sharding_with_transpose(self): + self.device_mesh_shape = (1, 2, self.n_devices // 2) + self.tensor_shape = (16, 16) + self.partition_spec = (None, (2, 1)) + device_order = self.device_ids.reshape(self.device_mesh_shape).transpose( + (2, 1, 0)).flatten() + self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join( + str(x) for x in device_order)) + self.run_test() + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 48b760f6e3f0..b74709d66f60 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): super().setUpClass() + cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY") def test_xla_sharded_tensor(self): partition_spec = (0, 1) @@ -238,6 +239,8 @@ def test_custom_tile_assignment(self): if self.n_devices > 1: annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( [str(i) for i in reversed(range(self.n_devices))])) + if self.convert_to_shardy: + annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt)) def test_mark_sharding_2d(self): @@ -252,6 +255,8 @@ def test_mark_sharding_2d(self): if self.n_devices > 1: annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( [str(i) for i in range(self.n_devices)])) + if self.convert_to_shardy: + annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1)) actual = (xt1 + xt2).cpu() @@ -271,6 +276,9 @@ def test_mark_sharding_4d(self): annotation = '{devices=[1,1,%d,%d]%s}' % ( z_dim, self.n_devices // z_dim, ','.join( [str(i) for i in range(self.n_devices)])) + if self.convert_to_shardy: + annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices // + z_dim, self.n_devices) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt)) actual = (xt + xt).cpu() @@ -403,9 +411,11 @@ def test_tupled_partition_spec(self): mesh = self._get_mesh((2, self.n_devices // 2)) t = torch.randn(16).to('xla') xs.mark_sharding(t, mesh, ((0, 1),)) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" % - (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) @unittest.skipUnless(xr.global_runtime_device_count() >= 4, "Multiple devices required for tupled partition spec") @@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self): # Shard the first dimension on `r` and `b`, replicate the second dimension t = torch.randn(16, 16).to('xla') xs.mark_sharding(t, mesh, (('r', 'b'), None)) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), - "{devices=[2,1,%d]%s last_tile_dim_replicate}" % - (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % ( + self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % ( + self.n_devices // 2, self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) # Replicate the first dimension, shard the second on `b` and `m` u = torch.randn(16, 16).to('xla') xs.mark_sharding(u, mesh, (None, ('b', 'm'))) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" % - (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation) # Replicate the first dimension, shard the second on `r` and `m` v = torch.randn(16, 16).to('xla') xs.mark_sharding(v, mesh, (None, ('r', 'm'))) device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten() - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(v), - "{devices=[1,%d,2]%s last_tile_dim_replicate}" % - (self.n_devices // 2, ','.join(str(x) for x in device_order))) + annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % ( + self.n_devices // 2, ','.join(str(x) for x in device_order)) + if self.convert_to_shardy: + annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % ( + self.n_devices // 2, self.n_devices // 2) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation) # Replicate the first dimension, shard the second on `m` and `b` v = torch.randn(16, 16).to('xla') xs.mark_sharding(v, mesh, (None, ('m', 'b'))) device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten() - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" % - (self.n_devices, ','.join(str(x) for x in device_order))) + annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join( + str(x) for x in device_order)) + if self.convert_to_shardy: + annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices, + self.n_devices // 2) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation) @unittest.skipUnless(xr.global_runtime_device_count() > 1, 'Multiple devices required for tupled partition spec') @@ -452,9 +471,12 @@ def test_multiple_tuples_in_spec(self): ('a', 'b', 'c', 'd')) t = torch.randn(2, 2).to('xla') xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd'))) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" % - (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2, + self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) @unittest.skipUnless(xr.global_runtime_device_count() > 1, 'At least 2 devices needed for 2D mesh') @@ -462,9 +484,12 @@ def test_3d_tensor_2d_mesh(self): mesh = self._get_mesh((2, self.n_devices // 2)) t = torch.randn(16, 16, 16).to('xla') xs.mark_sharding(t, mesh, (None, 0, 1)) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' % - (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) + annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2, + self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) def test_partial_replication_addmm(self): device = torch_xla.device() @@ -983,18 +1008,20 @@ def test_op_sharding_cache(self): t = torch.randn(1, self.n_devices).to('xla') xs.mark_sharding(t, mesh, (0, 1)) - self.assertIn("CreateOpSharding", met.counter_names()) - self.assertEqual(met.counter_value("CreateOpSharding"), 1) + counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding" + self.assertIn(counter_name, met.counter_names()) + self.assertEqual(met.counter_value(counter_name), 1) # Sharding with the same partition spec should not result in another call u = torch.randn(1, self.n_devices).to('xla') xs.mark_sharding(u, mesh, (0, 1)) - self.assertEqual(met.counter_value("CreateOpSharding"), 1) + self.assertEqual(met.counter_value(counter_name), 1) - # Changing the partition spec will result in another CreateOpSharding + # Changing the partition spec will result in another + # CreateOpSharding or CreatingIotaOpSharding call v = torch.randn(1, self.n_devices).to('xla') xs.mark_sharding(v, mesh, (0, None)) - self.assertEqual(met.counter_value("CreateOpSharding"), 2) + self.assertEqual(met.counter_value(counter_name), 2) def test_from_cpu_shards_replicated(self): from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards @@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self): input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None))) data, _ = iter(train_device_loader).__next__() self.assertEqual(data.size(), torch.Size([8, 3, 64, 64])) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(data), - f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" - ) + annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" + if self.convert_to_shardy: + annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}" + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation) @unittest.skipUnless( xr.global_runtime_device_count() > 1, @@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self): input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None))) data, _ = iter(train_device_loader).__next__() self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64])) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(data), - f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" - ) + annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" + if self.convert_to_shardy: + annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}" + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation) @unittest.skipUnless( xr.global_runtime_device_count() > 1, diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 58836598897a..755606f8a8c0 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -760,6 +760,16 @@ std::string GetTensorsHloGraph(const std::vector& tensors, return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode); } +std::optional GetXLAOpSharding(const at::Tensor& input) { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensor::ShardingSpecPtr sharding_spec = + xtensor ? xtensor->sharding_spec() : nullptr; + if (sharding_spec != nullptr) { + return sharding_spec->sharding; + } + return std::nullopt; +} + std::string GetXLAShardingSpec(const XLATensorPtr xtensor) { auto sharding_spec = xtensor->sharding_spec(); if (sharding_spec != nullptr) { @@ -1526,6 +1536,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) { void InitXlaModuleBindings(py::module m) { PythonScope module(m); + using TileAssignmentDims = std::vector; + using ReshapeDims = std::vector; + using TransposePerm = std::vector; + // Define the _XLAC.XlaShardingSpec class. PythonScope>( m, "XlaShardingSpec") @@ -1543,12 +1557,12 @@ void InitXlaModuleBindings(py::module m) { }) .def_init([](at::Tensor tensor, const py::list& dims, const py::list& reshape_dims, const py::list& transpose_perm, - bool minibatch) { + const py::list& types, bool minibatch) { xla::Shape global_shape = ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch); return std::make_shared( ShardingUtil::CreateIotaOpSharding(dims, reshape_dims, - transpose_perm), + transpose_perm, types), global_shape, minibatch); }); @@ -1578,9 +1592,9 @@ void InitXlaModuleBindings(py::module m) { }) // Constructor for V2 shardings. .def_init([](const py::list& dims, const py::list& reshape_dims, - const py::list& transpose_perm) { + const py::list& transpose_perm, const py::list& types) { return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims, - transpose_perm); + transpose_perm, types); }); // Define the _XLAC.PjRtPlugin class. @@ -2703,7 +2717,26 @@ void InitXlaModuleBindings(py::module m) { if (sharding_spec != nullptr) { return sharding_spec->sharding; } - return std::nullopt; + return GetXLAOpSharding(input); + }) + .def("_get_xla_op_sharding_v2_params", + [](const at::Tensor& input) -> std::optional> { + std::optional maybe_sharding = + GetXLAOpSharding(input); + if (!maybe_sharding) { + return std::nullopt; + } + const xla::OpSharding& sharding = maybe_sharding.value(); + TileAssignmentDims tile_assignment_dims( + sharding.tile_assignment_dimensions().begin(), + sharding.tile_assignment_dimensions().end()); + ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(), + sharding.iota_reshape_dims().end()); + TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(), + sharding.iota_transpose_perm().end()); + return std::make_tuple(tile_assignment_dims, reshape_dims, + transpose_perm, + sharding.replicate_on_last_tile_dim()); }) .def("_get_xla_sharding_specs", [](const std::vector& tensors) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index e83318cf5675..16f498299123 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -220,18 +220,21 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, xla::OpSharding ShardingUtil::CreateIotaOpSharding( const py::list& dims, const py::list& reshape_dims, - const py::list& transpose_perm) { + const py::list& transpose_perm, const py::list& types) { + TORCH_LAZY_COUNTER("CreateIotaOpSharding", 1); auto dims_vec = dims.cast>(); auto reshape_dims_vec = reshape_dims.cast>(); auto transpose_perm_vec = transpose_perm.cast>(); - std::vector subgroup_types; - if (dims_vec.size() > transpose_perm.size()) { - subgroup_types.push_back(xla::OpSharding::REPLICATED); + std::vector subgroup_types_vec; + for (auto type : types) { + subgroup_types_vec.push_back( + static_cast(type.cast())); } + CHECK_EQ(reshape_dims_vec.size(), transpose_perm_vec.size()); return xla::HloSharding::Subgroup( xla::TileAssignment(dims_vec, reshape_dims_vec, transpose_perm_vec), - subgroup_types) + subgroup_types_vec) .ToProto(); } diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 2cae399e2931..a925c470748a 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -55,7 +55,8 @@ class ShardingUtil { // HloShardingV2 system. static xla::OpSharding CreateIotaOpSharding(const py::list& dims, const py::list& reshape_dims, - const py::list& transpose_perm); + const py::list& transpose_perm, + const py::list& types); // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. This assumes the shards will be padded to ensure they all diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py index e5f53d04aea1..2cb9368aff08 100644 --- a/torch_xla/distributed/spmd/debugging.py +++ b/torch_xla/distributed/spmd/debugging.py @@ -157,6 +157,27 @@ def visualize_sharding(sharding: str, return table +def construct_v1_sharding_str(t: torch.Tensor) -> str: + """ + Returns the corresponding HLO V1 sharding string from the tensor + """ + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + if "<=" not in sharding: + # This is already in the V1 format + return sharding + sharding_params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t) + assert sharding_params is not None + tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_dim = sharding_params + num_devices = np.prod(reshape_dims) + device_list = np.arange(num_devices).reshape(reshape_dims).transpose( + transpose_perm).reshape(num_devices) + + tile_assignment_str = ",".join(str(dim) for dim in tile_assignment_dims) + device_list_str = ",".join(str(i) for i in device_list) + replicate_str = " last_tile_dim_replicate" if replicate_on_last_dim else "" + return f"{{devices=[{tile_assignment_str}]{device_list_str}{replicate_str}}}" + + def visualize_tensor_sharding(t, **kwargs): """Visualizes an array's sharding.""" @@ -164,5 +185,7 @@ def visualize_tensor_sharding(t, **kwargs): def maybe_unwrap(t: torch.Tensor) -> torch.Tensor: return t.global_tensor if isinstance(t, XLAShardedTensor) else t - sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t)) + t = maybe_unwrap(t) + sharding = construct_v1_sharding_str(t) + return visualize_sharding(sharding, **kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index fe82ed47fcaa..f081e6111429 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -131,12 +131,6 @@ def _validate_translated_partition_spec(self, partition_spec: tuple): def _get_op_sharding_args(self, partition_spec: PartitionSpec): partition_spec = _translate_named_partition_spec(self, partition_spec) self._validate_translated_partition_spec(partition_spec) - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." tile_assignment = _get_tile_assignment(self, partition_spec) if len(tile_assignment.shape) > len(partition_spec): @@ -154,44 +148,58 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec): @functools.lru_cache(maxsize=None) def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec): + """ + This function returns all the sharding parameters needed for TILED or PARTIAL sharding. + (All other sharding types are handled separately by the V1 OpSharding function) + """ partition_spec = _translate_named_partition_spec(self, partition_spec) self._validate_translated_partition_spec(partition_spec) - # 1. Calculate the initial part of dims based on the partition_spec. - dims = [] - used_axes = OrderedDict() - for axis in partition_spec: - if isinstance(axis, tuple): - dim_size = 1 - for i in axis: - assert i is not None, "None not allowed within tuple" - dim_size *= self.mesh_shape[i] - used_axes[i] = True - dims.append(dim_size) - elif axis is not None: - assert isinstance(axis, int), "Axis must be an int or a tuple of ints" - dims.append(self.mesh_shape[axis]) - used_axes[axis] = True - else: - dims.append(1) - - # 2. If the product of dims is less than the total number of devices, - # append the sizes of the unused mesh axes. - if math.prod(dims) < math.prod(self.mesh_shape): - for i in range(len(self.mesh_shape)): - if i not in used_axes: - dims.append(self.mesh_shape[i]) + # This algorithm is adapted from + # https://github.com/openxla/xla/blob/256b633e0adaee80588a8c3a5e4b2eaa005b5414/xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.cc#L288 + tile_assignment_dims = [1] * len(partition_spec) + axisRefToShardedPos = {} + subgroup_types = [] + shardedPos = 0 - # 3. Calculate transpose_perm (sharded axes first, then unused axes). - transpose_perm = list(used_axes.keys()) - for i in range(len(self.mesh_shape)): - if i not in used_axes: - transpose_perm.append(i) + for idx, axes in enumerate(partition_spec): + if axes is None: + # Tensor dim is being replicated + continue + elif isinstance(axes, tuple): + # Tensor dim is being sharded over multiple axes + for axis in axes: + tile_assignment_dims[idx] *= self.mesh_shape[axis] + axisRefToShardedPos[axis] = shardedPos + shardedPos += 1 + else: + # Tensor dim is being sharded over just 1 axis + tile_assignment_dims[idx] *= self.mesh_shape[axes] + axisRefToShardedPos[axes] = shardedPos + shardedPos += 1 + + all_axes_ordered = [i for i in range(len(self.mesh_shape))] + reshape_dims = [0] * len(all_axes_ordered) + transpose_perm = [0] * len(all_axes_ordered) + + totalReplicatedSize = 1 + replicatedPos = shardedPos + for idx, axis in enumerate(all_axes_ordered): + reshape_dims[idx] = self.mesh_shape[axis] + if axis in axisRefToShardedPos: + # Axis is sharded + transpose_perm[axisRefToShardedPos[axis]] = idx + else: + # Axis is replicated + transpose_perm[replicatedPos] = idx + replicatedPos += 1 + totalReplicatedSize *= self.mesh_shape[axis] - # 4. reshape_dims is always the physical mesh shape. - reshape_dims = list(self.mesh_shape) + if totalReplicatedSize > 1: + tile_assignment_dims.append(totalReplicatedSize) + subgroup_types.append(ShardingType.REPLICATED) - return dims, reshape_dims, transpose_perm + return tile_assignment_dims, reshape_dims, transpose_perm, subgroup_types @functools.lru_cache(maxsize=None) def get_op_sharding_v2( @@ -203,11 +211,11 @@ def get_op_sharding_v2( return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED) sharding_type = _get_sharding_type(partition_spec, self.size()) if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL): - return torch_xla._XLAC.OpSharding([], [], [0], sharding_type) + return torch_xla._XLAC.OpSharding([], [], [], sharding_type) - dims, reshape_dims, transpose_perm = self._get_op_sharding_args_v2( + dims, reshape_dims, transpose_perm, types = self._get_op_sharding_args_v2( partition_spec) - return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm) + return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm, types) @functools.lru_cache(maxsize=None) def get_op_sharding( @@ -910,7 +918,7 @@ def __post_init__(self): self._sharding_type, tile_assignment, len(partition_spec), replicate_dims) if _use_shlo_to_shardy(): - self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2( + self.dims, self.reshape_dims, self.transpose_perm, self.subgroup_types = mesh._get_op_sharding_args_v2( partition_spec) def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: @@ -922,9 +930,9 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: return None if _use_shlo_to_shardy(): - # Convert to Shardy spec if the environment variable is set. return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims, - self.transpose_dims, + self.transpose_perm, + self.subgroup_types, self.minibatch) return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment, diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index f0ecb6dcf687..1ab241dcf12f 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -217,6 +217,7 @@ def _compile(): sync() torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status) torch_xla._XLAC._set_current_graph_name(saved_current_graph_name) + if custom_compile_options is not None and len(custom_compile_options) > 0: torch_xla._XLAC._set_custom_compile_options(custom_compile_options) return _compile() if f is None else _compile()(f) @@ -269,9 +270,8 @@ def launch( else: xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method) -def set_custom_compile_options( - options: Optional[dict] = None, -): + +def set_custom_compile_options(options: Optional[dict] = None,): """Sets custom compile options for the XLA compilation. Args: @@ -281,4 +281,3 @@ def set_custom_compile_options( if options is None: options = {} torch_xla._XLAC._set_custom_compile_options(options) - \ No newline at end of file From 7bc474accaef4035447e2cccbea6093b171c95df Mon Sep 17 00:00:00 2001 From: ddilbazTT Date: Tue, 2 Sep 2025 09:58:08 -0400 Subject: [PATCH 129/133] Uplift wheel python 3.10 to 3.11 --- .github/workflows/_build_torch_xla_3.11.yml | 76 +++++++++++++++++++++ .github/workflows/_publish_torch_xla.yml | 2 +- .github/workflows/build_and_publish.yml | 4 +- 3 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/_build_torch_xla_3.11.yml diff --git a/.github/workflows/_build_torch_xla_3.11.yml b/.github/workflows/_build_torch_xla_3.11.yml new file mode 100644 index 000000000000..5306455b0ada --- /dev/null +++ b/.github/workflows/_build_torch_xla_3.11.yml @@ -0,0 +1,76 @@ +name: build-torch-xla +on: + workflow_call: + inputs: + torch_version: + description: 'Torch version to build (default: 2.7.0)' + required: false + type: string + default: '2.7.0' + outputs: + artifact_name: + description: 'Name of uploaded wheels artifact' + value: ${{ jobs.build-wheels.outputs.artifact_name }} + workflow_dispatch: +jobs: + build-wheels: + runs-on: ubuntu-latest + env: + ARTIFACT_NAME: install-artifact-torch-xla-release + GIT_VERSIONED_XLA_BUILD: 1 + container: + image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu + options: --user root + outputs: + artifact_name: ${{ steps.set_upload_name.outputs.artifact_name }} + steps: + - name: "Build Torch/XLA wheel" + id: build_wheels + run: | + cmake --version + apt-get update && apt-get install -y curl git build-essential + + # Clean up any existing pyenv installation + rm -rf $HOME/.pyenv + + curl https://pyenv.run | bash + export PATH="$HOME/.pyenv/bin:$PATH" + eval "$(pyenv init -)" + pyenv install 3.11 + pyenv global 3.11 + ln -sf $HOME/.pyenv/versions/3.11/bin/python3.11 /usr/local/bin/python3.11 + + # Install essential packages for Python 3.11 + python3.11 -m pip install --upgrade pip + python3.11 -m pip install pyyaml setuptools wheel numpy typing_extensions requests + + cd /tmp + git clone --recursive --branch v${{ inputs.torch_version || '2.7.0' }} https://github.com/pytorch/pytorch.git + cd pytorch/ + git clone --recursive https://github.com/tenstorrent/pytorch-xla.git xla + + # copy pre-built wheels from cache + python3.11 setup.py bdist_wheel + python3.11 setup.py develop + + # Build PyTorch/XLA + cd xla/ + python3.11 setup.py bdist_wheel + + # Collect wheels + mkdir -p /dist + cp dist/*.whl /dist/ + + # Clean up any existing pyenv installation + rm -rf $HOME/.pyenv + + - name: "Upload Wheels Artifact" + id: upload + uses: actions/upload-artifact@v4 + with: + name: ${{ env.ARTIFACT_NAME }} + path: /dist/*.whl + + - name: Set artifact name output + id: set_upload_name + run: echo "artifact_name=${{ env.ARTIFACT_NAME }}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_publish_torch_xla.yml b/.github/workflows/_publish_torch_xla.yml index 07f0785bed2f..cb77a07c9c5d 100644 --- a/.github/workflows/_publish_torch_xla.yml +++ b/.github/workflows/_publish_torch_xla.yml @@ -25,7 +25,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - name: Configure AWS Credentials uses: aws-actions/configure-aws-credentials@v4 diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml index a3e9fe57c93e..35d8f9510ebb 100644 --- a/.github/workflows/build_and_publish.yml +++ b/.github/workflows/build_and_publish.yml @@ -19,9 +19,9 @@ jobs: head_sha: ${{ github.sha }} build-torch-xla: - name: "Build PyTorch/XLA for Python 3.10" + name: "Build PyTorch/XLA for Python 3.11" if: needs.check_code_changes.outputs.has_code_changes == 'true' - uses: ./.github/workflows/_build_torch_xla_3.10.yml + uses: ./.github/workflows/_build_torch_xla_3.11.yml needs: check_code_changes publish-torch-xla: From a2514dd3ad2c4ad9192f2f11d76e141071173e12 Mon Sep 17 00:00:00 2001 From: Jonathan Azpur Date: Fri, 5 Sep 2025 19:16:42 -0400 Subject: [PATCH 130/133] Update jax dependency to 0.7.1 to align with tt front ends (#8) --- .github/workflows/_build_torch_xla_3.11.yml | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/_build_torch_xla_3.11.yml b/.github/workflows/_build_torch_xla_3.11.yml index 5306455b0ada..57f4d2bed6ea 100644 --- a/.github/workflows/_build_torch_xla_3.11.yml +++ b/.github/workflows/_build_torch_xla_3.11.yml @@ -38,7 +38,7 @@ jobs: eval "$(pyenv init -)" pyenv install 3.11 pyenv global 3.11 - ln -sf $HOME/.pyenv/versions/3.11/bin/python3.11 /usr/local/bin/python3.11 + ln -sf $(pyenv which python3.11) /usr/local/bin/python3.11 # Install essential packages for Python 3.11 python3.11 -m pip install --upgrade pip diff --git a/setup.py b/setup.py index 33642f9a3f6e..c6a7e2465f8b 100644 --- a/setup.py +++ b/setup.py @@ -110,14 +110,14 @@ # 4. After the local build succeeds, create a PR and wait for the CI result. Fix # CI errors as needed until all required checks pass. -USE_NIGHTLY = True # Whether to use nightly or stable libtpu and JAX. +USE_NIGHTLY = False # Whether to use nightly or stable libtpu and JAX. _libtpu_version = '0.0.21' _libtpu_date = '20250813' _jax_version = '0.7.1' _jaxlib_version = '0.7.1' -_jax_date = '20250813' # Date for jax and jaxlib. +_jax_date = '20250617' # Date for jax and jaxlib. if USE_NIGHTLY: _libtpu_version += f".dev{_libtpu_date}+nightly" From 86bac8b47f7454e73c174bb8b0b4623feaac8cfd Mon Sep 17 00:00:00 2001 From: Sungjoon Shon Date: Mon, 6 Oct 2025 18:53:38 +0000 Subject: [PATCH 131/133] Fix for API match fix for api match --- torch_xla/csrc/init_python_bindings.cpp | 13 +++++++++---- torch_xla/csrc/runtime/pjrt_computation_client.cpp | 1 - torch_xla/csrc/xla_sharding_util.cpp | 11 ++++------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 755606f8a8c0..a80f05a4cd59 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -761,9 +761,12 @@ std::string GetTensorsHloGraph(const std::vector& tensors, } std::optional GetXLAOpSharding(const at::Tensor& input) { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto xtensor = bridge::GetXlaTensor(input); + if (!xtensor.ok()) { + return std::nullopt; + } XLATensor::ShardingSpecPtr sharding_spec = - xtensor ? xtensor->sharding_spec() : nullptr; + xtensor.value() ? xtensor.value()->sharding_spec() : nullptr; if (sharding_spec != nullptr) { return sharding_spec->sharding; } @@ -3350,8 +3353,10 @@ void InitXlaModuleBindings(py::module m) { std::string key = item.first.cast(); options[key] = py::str(item.second).cast(); } - runtime::GetComputationClientOrDie()->SetCustomCompileOptions( - options); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + client->SetCustomCompileOptions(options); }) .def( // from an XLA tensor to a PyCapsule. diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index b015b8555f66..e1307ba04e6d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -140,7 +140,6 @@ absl::Status PjRtComputationClient::Initialize() { auto tracked_devices = GetLocalDevices(); tracked_devices.emplace_back(spmd_device_str); operation_manager_ = std::move(OperationManager(std::move(tracked_devices))); - return absl::OkStatus(); } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 16f498299123..2c6de2e8b33e 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -897,13 +897,10 @@ xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor, bool minibatch) { xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr); if (minibatch) { - int num_local_devices = - runtime::GetComputationClientOrDie()->GetLocalDevices().size(); - int num_global_devices = - runtime::GetComputationClientOrDie()->GetAllDevices().size(); - XLA_CHECK(tile_assignment.size() == num_global_devices) - << "Minibatch sharding only supports sharding along the batch " - "dimension"; + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + int num_local_devices = client->GetLocalDevices().size(); + int num_global_devices = client->GetAllDevices().size(); int batch_dim_shape = tensor.sizes()[0] * num_global_devices / num_local_devices; global_shape.set_dimensions(0, batch_dim_shape); From 27f7792a180b103980c0769527e2663e3d9bd759 Mon Sep 17 00:00:00 2001 From: Sungjoon Shon Date: Tue, 7 Oct 2025 16:00:02 +0000 Subject: [PATCH 132/133] Torch build option change Torch build option change to avoid build warning and error. --- .github/workflows/_build_torch_xla_3.11.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/_build_torch_xla_3.11.yml b/.github/workflows/_build_torch_xla_3.11.yml index 57f4d2bed6ea..29dedeb95de7 100644 --- a/.github/workflows/_build_torch_xla_3.11.yml +++ b/.github/workflows/_build_torch_xla_3.11.yml @@ -25,6 +25,7 @@ jobs: artifact_name: ${{ steps.set_upload_name.outputs.artifact_name }} steps: - name: "Build Torch/XLA wheel" + shell: bash id: build_wheels run: | cmake --version @@ -49,9 +50,15 @@ jobs: cd pytorch/ git clone --recursive https://github.com/tenstorrent/pytorch-xla.git xla - # copy pre-built wheels from cache - python3.11 setup.py bdist_wheel - python3.11 setup.py develop + ( + # Build PyTorch + # From https://docs.pytorch.org/FBGEMM/fbgemm/development/BuildInstructions.html, section "Build Issues with GCC 12+" + export CFLAGS+=" -Wno-error=maybe-uninitialized -Wno-error=uninitialized -Wno-error=restrict" + export CXXFLAGS+=" -Wno-error=maybe-uninitialized -Wno-error=uninitialized -Wno-error=restrict" + # copy pre-built wheels from cache + python3.11 setup.py bdist_wheel + python3.11 setup.py develop + ) # Build PyTorch/XLA cd xla/ From 7a014f815748e6a33722afb25187dc1d64114947 Mon Sep 17 00:00:00 2001 From: Sungjoon Shon Date: Thu, 9 Oct 2025 20:06:31 +0000 Subject: [PATCH 133/133] Temporary adding checkout branch checkout branch --- .github/workflows/_build_torch_xla_3.11.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/_build_torch_xla_3.11.yml b/.github/workflows/_build_torch_xla_3.11.yml index 29dedeb95de7..6995a93fcc48 100644 --- a/.github/workflows/_build_torch_xla_3.11.yml +++ b/.github/workflows/_build_torch_xla_3.11.yml @@ -50,9 +50,12 @@ jobs: cd pytorch/ git clone --recursive https://github.com/tenstorrent/pytorch-xla.git xla + ( # Build PyTorch # From https://docs.pytorch.org/FBGEMM/fbgemm/development/BuildInstructions.html, section "Build Issues with GCC 12+" + export CFLAGS="${CFLAGS:+$CFLAGS }-w" + export CXXFLAGS="${CXXFLAGS:+$CXXFLAGS }-w" export CFLAGS+=" -Wno-error=maybe-uninitialized -Wno-error=uninitialized -Wno-error=restrict" export CXXFLAGS+=" -Wno-error=maybe-uninitialized -Wno-error=uninitialized -Wno-error=restrict" # copy pre-built wheels from cache @@ -62,6 +65,7 @@ jobs: # Build PyTorch/XLA cd xla/ + git checkout sshon/rebase-to-upstream python3.11 setup.py bdist_wheel # Collect wheels