diff --git a/.github/workflows/lintercheck.yml b/.github/workflows/lintercheck.yml index edd012e9c008..9b5b63a5afa0 100644 --- a/.github/workflows/lintercheck.yml +++ b/.github/workflows/lintercheck.yml @@ -1,113 +1,113 @@ name: Linter check on: - pull_request: - push: - branches: - - master - tags: - - r[0-9]+.[0-9]+ + pull_request: + push: + branches: + - master + tags: + - r[0-9]+.[0-9]+ jobs: - check_code_changes: - name: Check Code Changes - uses: ./.github/workflows/_check_code_changes.yml - with: - event_name: ${{ github.event_name }} - # For pull_request, use PR's base and head. For push, use event's before and sha. - base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }} - head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - linter_check: - runs-on: ubuntu-24.04 - needs: [check_code_changes] - steps: - - name: Checkout repo - if: needs.check_code_changes.outputs.has_code_changes == 'true' - uses: actions/checkout@v3 - - name: Setup Python - if: needs.check_code_changes.outputs.has_code_changes == 'true' - uses: actions/setup-python@v4 + check_code_changes: + name: Check Code Changes + uses: ./.github/workflows/_check_code_changes.yml with: - python-version: '3.10' - cache: 'pip' - - run: pip install yapf==0.40.2 # N.B.: keep in sync with `torchax/dev-requirements.txt`, `infra/ansible/config/pip.yaml` + event_name: ${{ github.event_name }} + # For pull_request, use PR's base and head. For push, use event's before and sha. + base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }} + head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + linter_check: + runs-on: ubuntu-24.04 + needs: [check_code_changes] + steps: + - name: Checkout repo + if: needs.check_code_changes.outputs.has_code_changes == 'true' + uses: actions/checkout@v3 + - name: Setup Python + if: needs.check_code_changes.outputs.has_code_changes == 'true' + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: "pip" + - run: pip install yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml` - - name: Check no TORCH_PIN - if: > - (github.event_name == 'push' && github.event.ref == 'refs/heads/master') && - needs.check_code_changes.outputs.has_code_changes == 'true' - shell: bash - run: | - TORCH_PIN=./.torch_pin - if [[ -f "${TORCH_PIN}" ]]; then - echo "Please remove ${TORCH_PIN} before landing." - exit 1 - else - echo "No ${TORCH_PIN} found, safe to land..." - fi - - name: Check .cc file extension - shell: bash - run: | - # Find *.cc files recursively in the current directory, limiting to files only. - found_files=$(find . -type f -name "*.cc") + - name: Check no TORCH_PIN + if: > + (github.event_name == 'push' && github.event.ref == 'refs/heads/master') && + needs.check_code_changes.outputs.has_code_changes == 'true' + shell: bash + run: | + TORCH_PIN=./.torch_pin + if [[ -f "${TORCH_PIN}" ]]; then + echo "Please remove ${TORCH_PIN} before landing." + exit 1 + else + echo "No ${TORCH_PIN} found, safe to land..." + fi + - name: Check .cc file extension + shell: bash + run: | + # Find *.cc files recursively in the current directory, limiting to files only. + found_files=$(find . -type f -name "*.cc") - # Check if any files were found. - if [ -n "$found_files" ]; then - echo "Found *.cc files:" - echo "$found_files" - echo "Please rename them to *.cpp for consistency." - exit 1 - else - echo "PASSED *.cc file extension check" - fi - - name: Run clang-format - if: needs.check_code_changes.outputs.has_code_changes == 'true' - shell: bash - env: - CLANG_FORMAT: clang-format-16 - run: | - sudo apt-get update - sudo apt install -y "${CLANG_FORMAT}" - git_status=$(git status --porcelain) - if [[ $git_status ]]; then - echo "Checkout code is not clean" - echo "${git_status}" - exit 1 - fi + # Check if any files were found. + if [ -n "$found_files" ]; then + echo "Found *.cc files:" + echo "$found_files" + echo "Please rename them to *.cpp for consistency." + exit 1 + else + echo "PASSED *.cc file extension check" + fi + - name: Run clang-format + if: needs.check_code_changes.outputs.has_code_changes == 'true' + shell: bash + env: + CLANG_FORMAT: clang-format-16 + run: | + sudo apt-get update + sudo apt install -y "${CLANG_FORMAT}" + git_status=$(git status --porcelain) + if [[ $git_status ]]; then + echo "Checkout code is not clean" + echo "${git_status}" + exit 1 + fi - find . -name '*.cpp' -o -name '*.h' -o -name '*.cc' | xargs "${CLANG_FORMAT}" -i -style=file - git_status=$(git status --porcelain) - if [[ $git_status ]]; then - git diff - echo "${CLANG_FORMAT} recommends the changes above, please manually apply them OR automatically apply the changes " - echo "by running \"${CLANG_FORMAT} -i -style=file /PATH/TO/foo.cpp\" to the following files" - echo "${git_status}" - exit 1 - else - echo "PASSED C++ format" - fi - - name: Run yapf - if: needs.check_code_changes.outputs.has_code_changes == 'true' - shell: bash - run: | - git_status=$(git status --porcelain) - if [[ $git_status ]]; then - echo "Checkout code is not clean" - echo "${git_status}" - exit 1 - fi + find . -name '*.cpp' -o -name '*.h' -o -name '*.cc' | xargs "${CLANG_FORMAT}" -i -style=file + git_status=$(git status --porcelain) + if [[ $git_status ]]; then + git diff + echo "${CLANG_FORMAT} recommends the changes above, please manually apply them OR automatically apply the changes " + echo "by running \"${CLANG_FORMAT} -i -style=file /PATH/TO/foo.cpp\" to the following files" + echo "${git_status}" + exit 1 + else + echo "PASSED C++ format" + fi + - name: Run yapf + if: needs.check_code_changes.outputs.has_code_changes == 'true' + shell: bash + run: | + git_status=$(git status --porcelain) + if [[ $git_status ]]; then + echo "Checkout code is not clean" + echo "${git_status}" + exit 1 + fi - yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ torchax/ - git_status=$(git status --porcelain) - if [[ $git_status ]]; then - git diff - echo "yapf recommends the changes above, please manually apply them OR automatically apply the changes " - echo "by running `yapf -i /PATH/TO/foo.py` to the following files" - echo "${git_status}" - exit 1 - else - echo "PASSED Python format" - fi - - name: Report no code changes - if: needs.check_code_changes.outputs.has_code_changes == 'false' - run: | - echo "No code changes were detected that require running the full test suite." + yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ + git_status=$(git status --porcelain) + if [[ $git_status ]]; then + git diff + echo "yapf recommends the changes above, please manually apply them OR automatically apply the changes " + echo "by running `yapf -i /PATH/TO/foo.py` to the following files" + echo "${git_status}" + exit 1 + else + echo "PASSED Python format" + fi + - name: Report no code changes + if: needs.check_code_changes.outputs.has_code_changes == 'false' + run: | + echo "No code changes were detected that require running the full test suite." diff --git a/.github/workflows/torchax.yml b/.github/workflows/torchax.yml deleted file mode 100644 index 2f1e930f48b5..000000000000 --- a/.github/workflows/torchax.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: torchax -on: - pull_request: - branches: - - master - - r[0-9]+.[0-9]+ - push: - branches: - - master - - r[0-9]+.[0-9]+ - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - check_code_changes: - name: Check Code Changes - uses: ./.github/workflows/_check_code_changes.yml - with: - event_name: ${{ github.event_name }} - # For pull_request, use PR's base and head. For push, use event's before and sha. - base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }} - head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - torchax-cpu: - runs-on: ubuntu-24.04 - needs: [check_code_changes] - strategy: - matrix: - python-version: ['3.10', '3.11', '3.12'] - steps: - - name: Checkout repo - if: needs.check_code_changes.outputs.has_code_changes == 'true' - uses: actions/checkout@v4 - with: - sparse-checkout: | - torchax - - name: Setup Python - if: needs.check_code_changes.outputs.has_code_changes == 'true' - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install - if: needs.check_code_changes.outputs.has_code_changes == 'true' - shell: bash - working-directory: torchax - run: | - pip install -r test-requirements.txt - pip install -e .[cpu] - - name: Run tests - if: needs.check_code_changes.outputs.has_code_changes == 'true' - working-directory: torchax - shell: bash - run: | - export JAX_PLATFORMS=cpu - # Find all Python test files recursively - find ./test -name "test_*.py" -type f | while IFS= read -r test_file; do - # Skip tests with known issues - if [[ "$test_file" == *"test_tf_integration.py"* ]]; then - echo "Skipping ${test_file}. TODO(https://github.com/pytorch/xla/issues/8770): Investigate" - continue - fi - echo "Running tests for $test_file" - pytest "$test_file" - done - # Run distributed tests. - XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/ - echo "Tests completed." - - name: Report no code changes - if: needs.check_code_changes.outputs.has_code_changes == 'false' - run: | - echo "No code changes were detected that require running the full test suite." diff --git a/setup.py b/setup.py index 33642f9a3f6e..2a9e86625e02 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,8 @@ _jaxlib_version = '0.7.1' _jax_date = '20250813' # Date for jax and jaxlib. +_torchax_version = '0.0.7' # likely stay the same + if USE_NIGHTLY: _libtpu_version += f".dev{_libtpu_date}+nightly" _jax_version += f'.dev{_jax_date}' @@ -335,19 +337,6 @@ def build_extension(self, ext: Extension) -> None: # 1. Find `torch_xla` and its subpackages automatically from the root. packages_to_include = find_packages(include=['torch_xla', 'torch_xla.*']) -# 2. Explicitly find the contents of the nested `torchax` package. -# Find all sub-packages within the torchax directory (e.g., 'ops'). -torchax_source_dir = 'torchax/torchax' -torchax_subpackages = find_packages(where=torchax_source_dir) -# Construct the full list of packages, starting with the top-level -# 'torchax' and adding all the discovered sub-packages. -packages_to_include.extend(['torchax'] + - ['torchax.' + pkg for pkg in torchax_subpackages]) - -# 3. The package_dir mapping explicitly tells setuptools where the 'torchax' -# package's source code begins. `torch_xla` source code is inferred. -package_dir_mapping = {'torchax': torchax_source_dir} - class Develop(develop.develop): """ @@ -372,7 +361,7 @@ def link_packages(self): and `.pth` files. setuptools uses `.egg-link` by default. However, `.egg-link` only supports linking a single directory containg one editable package. This function removes the `.egg-link` file and generates a `.pth` file that can - be used to link multiple packages, in particular, `torch_xla` and `torchax`. + be used to link multiple packages. Note that this function is only relevant in the editable package development path (`python setup.py develop`). Nightly and release wheel builds work out of the box @@ -409,18 +398,13 @@ def link_packages(self): pth_filename = os.path.join(target_dir, f"{dist_name}.pth") project_root = os.path.dirname(os.path.abspath(__file__)) - paths_to_add = { - project_root, # For `torch_xla` - os.path.abspath(os.path.join(project_root, 'torchax')), # For `torchax` - } - with open(pth_filename, "w", encoding='utf-8') as f: - for path in sorted(paths_to_add): - f.write(path + "\n") + f.write(project_root + "\n") def _get_jax_install_requirements(): return [ + f'torchax=={_torchax_version}', f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}', ] @@ -452,7 +436,6 @@ def _get_jax_install_requirements(): ], python_requires=">=3.10.0", packages=packages_to_include, - package_dir=package_dir_mapping, ext_modules=[ BazelExtension('//:_XLAC.so'), ], diff --git a/torchax/LICENSE b/torchax/LICENSE deleted file mode 100644 index 1d064b89dc7c..000000000000 --- a/torchax/LICENSE +++ /dev/null @@ -1,28 +0,0 @@ -BSD 3-Clause License - -Copyright (c) 2023, pytorch-tpu - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/torchax/README.md b/torchax/README.md index 57a212b4838d..941ffc29ed00 100644 --- a/torchax/README.md +++ b/torchax/README.md @@ -1,222 +1,17 @@ -# torchax: Running PyTorch on TPU via JAX +# torchax: a torch frontend for JAX, and JAX - torch interoperability layer. -**torchax** is a backend for PyTorch, allowing users to run -PyTorch on Google Cloud TPUs. **torchax** is also a library for providing -graph-level interoperability between PyTorch and JAX. +**torchax** is a frontend for JAX, allowing users to write JAX programs +using PyTorch syntax. +**torchax** is also a library for providing +graph-level interoperability between PyTorch and JAX; meaning +we can reuse PyTorch models in a JAX program. -This means, with **torchax** you can: -* Run PyTorch code on TPUs with as little as 2 lines of code change. -* Call a JAX function from a PyTorch function, passing in `jax.Array`s. -* Call a PyTorch function from a JAX function, passing in a `torch.Tensor`s. -* Use JAX features such as `jax.grad`, `optax`, and `GSPMD` to train a PyTorch - model. -* Use a PyTorch model as feature extractor and use it with a JAX model. -etc etc. +## New Location: -## Install +As of 2025-10-06, **torchax** has been permantly moved to https://github.com/google/torchax. -First install torch CPU: +This file only serves as a reference. Thanks. -```bash -# On Linux. -pip install torch --index-url https://download.pytorch.org/whl/cpu - -# Or on Mac. -pip install torch -``` - -Then install JAX for the accelerator you want to use: - -```bash -# On Google Cloud TPU. -pip install -U jax[tpu] - -# Or, on GPU machines. -pip install -U jax[cuda12] - -# Or, on Linux CPU machines or Macs (see the note below). -pip install -U jax -``` - -NOTE: if you like metal support for Apple devices then install the -metal version of JAX: https://developer.apple.com/metal/jax/ - -Finally install torchax: - -```bash -# Install pre-built torchax. -pip install torchax - -# Or, install torchax from source. -pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax -``` - -## Run a model - -Now let's execute a model under torchax. We'll start with a simple 2-layer model. -In theory, we can use any instance of `torch.nn.Module`. - -```python -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class MyModel(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - -m = MyModel() - -# Execute this model using torch. -inputs = torch.randn(3, 3, 28, 28) -print(m(inputs)) -``` - -To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops: - -```python -import torchax -torchax.enable_globally() -``` - -Then, we can use a `jax` device: - -```python -inputs = torch.randn(3, 3, 28, 28, device='jax') -m = MyModel().to('jax') -res = m(inputs) -print(type(res)) # outputs torchax.tensor.Tensor -print(res.jax()) # print the underlying Jax Array -``` - -`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds -a `jax.Array`. You can inspect that JAX array with `res.jax()`. - -In other words, despite that the code above looks like PyTorch, it is actually running JAX! - -## What is happening behind the scene - -We took the approach detailed in the -[new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py) -recipe by Alban (@albanD), using `jax.Array` for `raw_data`. - -In other words, when a torch op is executed inside an `env` context manager, -which is enabled by `torchax.enable_globally()`, we will swap out the -implementation of that op with JAX. - -When a model's constructor runs, it will call some tensor constructor, such as -`torch.rand`, `torch.ones`, or `torch.zeros` to create its weights. When torchax -is enabled, these constructors will create a `torchax.tensor.Tensor`, which -contains a `jax.Array`. - -Then, each subsequent op will extract the `jax.Array`, call the op's JAX -implementation, and wrap the result back into a `torchax.tensor.Tensor`, - -See more at [how it works](docs/how_it_works.md) and\ -[ops registry](docs/ops_registry.md). - -### Executing with jax.jit - -The above script will execute the model using eager mode JAX as the backend. This -does allow executing torch models on TPUs, but is often slower than what we can -achieve with `jax.jit`. - -`jax.jit` is a function that takes a JAX function (i.e. a function that takes JAX arrays -and returns JAX arrays) into a compiled (thus faster) version of the same function. - -We have made a `jax_jit` decorator that would accomplish the same with functions -that takes and returns `torch.Tensor`s. To use this, the first step is to create -a functional version of this model: this means the parameters should be passed in -as input instead of being attributes of the class: - -```python -def model_func(param, inputs): - return torch.func.functional_call(m, param, inputs) -``` - -Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html) -from PyTorch to replace the model weights with `param` and then call the -model. This is roughly equivalent to: - -```python -def model_func(param, inputs): - m.load_state_dict(param) - return m(*inputs) -``` - -Now, we can apply `jax_jit` on `module_func`: - -```python -from torchax.interop import jax_jit - -model_func_jitted = jax_jit(model_func) -print(model_func_jitted(new_state_dict, inputs)) -``` - -See more examples at [eager_mode.py](examples/eager_mode.py) and the -[examples folder](examples/). - -To ease the idiom of creating functional model and calling it with parameters, -we also created the `JittableModule` helper class. It lets us rewrite the -above as: - -```python -from torchax.interop import JittableModule - -m_jitted = JittableModule(m) -res = m_jitted(...) -``` - -The first time `m_jitted` is called, it will trigger `jax.jit` to compile the -compile for the given input shapes. Subsequent calls with the same input shapes -will be fast as the compilation is cached. - -## Saving and Loading Checkpoints - -You can use `torchax.save_checkpoint` and `torchax.load_checkpoint` to save and load your training state. The state can be a dictionary containing the model's weights, optimizer state, and any other information you want to save. - -```python -import torchax -import torch -import optax - -# Assume model, optimizer, and other states are defined -model = MyModel() -optimizer = optax.adam(1e-3) -opt_state = optimizer.init(model.parameters()) -weights = model.parameters() -buffers = model.buffers() -epoch = 10 - -state = { - 'weights': weights, - 'buffers': buffers, - 'opt_state': opt_state, - 'epoch': epoch, -} - -# Save checkpoint -torchax.save_checkpoint(state, '/path/to/checkpoint.pt') - -# Load checkpoint -loaded_state = torchax.load_checkpoint('/path/to/checkpoint.pt') - -# Restore state -model.load_state_dict(loaded_state['weights']) -opt_state = loaded_state['opt_state'] -epoch = loaded_state['epoch'] -``` ## Citation diff --git a/torchax/build_nightly.sh b/torchax/build_nightly.sh deleted file mode 100755 index 885a90c6d44d..000000000000 --- a/torchax/build_nightly.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash -set -ex - -NIGHTLY_VERSION=$(date '+%Y%m%d%H%M') - -# Update the version to .devYYYYMMDDHHMM in __init__.py -VERSION_UPDATE_PATTERN="s/^__version__\s*=\s*\"([^\"]+)\"/__version__ = \"\1.dev$NIGHTLY_VERSION\"/g;" -sed -r "$VERSION_UPDATE_PATTERN" torchax/__init__.py --in-place - -hatch build -t wheel diff --git a/torchax/dev-requirements.txt b/torchax/dev-requirements.txt deleted file mode 100644 index 2da02ae8599b..000000000000 --- a/torchax/dev-requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ --f https://download.pytorch.org/whl/torch -torch==2.8.0 ; sys_platform == 'darwin' # macOS -torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU -yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml` -flax==0.10.6 diff --git a/torchax/docs/dispatch.png b/torchax/docs/dispatch.png deleted file mode 100644 index fcdd5e9e58a3..000000000000 Binary files a/torchax/docs/dispatch.png and /dev/null differ diff --git a/torchax/docs/fixing_op_info_test.md b/torchax/docs/fixing_op_info_test.md deleted file mode 100644 index 186b70f79a48..000000000000 --- a/torchax/docs/fixing_op_info_test.md +++ /dev/null @@ -1,254 +0,0 @@ -# How to fix an op info test. - -## What is OpInfo test - -PyTorch created a list of python objects (OpInfo) to keep -track how to test each op. This is useful to us because it -ensures that the ops we implement produces the same results -pytorch would produce. - -Context: -* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253 -* https://github.com/pytorch/pytorch/issues/54261 - - -## How to fix one - -### Remove one op from skiplist - -Open [test/test_ops.py](../test/test_ops.py) with your -favorite text editor. -Remove one line from the `skiplist` set. - -i.e. - -```bash -(base) hanq-macbookpro:torchax hanq$ git diff -diff --git a/experimental/torchax/test/test_ops.py b/experimental/torchax/test/test_ops.py -index 72a39ae85..2a156cbce 100644 ---- a/experimental/torchax/test/test_ops.py -+++ b/experimental/torchax/test/test_ops.py -@@ -15,7 +15,6 @@ skiplist = { - "_native_batch_norm_legit", - "_segment_reduce", - "_upsample_bilinear2d_aa", -- "addbmm", - "addmm", - "addmv", - "addr", -``` - -### Run test to see what failure -For errors you might get after running test, there are two kind: -- Target op failure - - error shows related to target op, such as `No lowering found for 'aten::addbmm'`, please follow instruction like [Fix Target op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torchax/docs/fixing_op_info_test.md#fix-target-op-failure) -- Decomposed op failure - - no implementation found for target ops, but error is not `no lowering`, error shows target op has been implemented somewhere; for sitution like this, please follow instruction like [Fix Decomposed op failure](https://github.com/pytorch/xla/blob/ManfeiBai-patch-99/experimental/torchax/docs/fixing_op_info_test.md#fix-other-op-failure) - -#### Fix Target op failure -Error gotten: - -``` -(base) hanq-macbookpro:torchax hanq$ python test/test_ops.py -... -E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm') -``` - -From here we have 2 strategies for fixing this test: - -1. Add an implementation to `aten::addbmm` operator using Jax ops. Or, -2. Add an implementation `aten::addbmm` operator using torch ops (this commonly known as "decompositions"). - -Either way works for torchax. For ops that are not "Core Aten" sometimes we implement in torch ops with the goal of -upstreaming this decomposition to [pytorch decompositon](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py) -so other projects can benefit from it. - -For illustration purposes, let's implement this op in Jax. - -(NOTE: this doesn't stop us from upstreaming a decomposition later if we want) - -#### Fix Decomposed op failure -For situation that no target op(`trapezoid`) implemention found in `experimental/torchax/torchax/ops/jaten.py`, but error shows target op(`trapezoid`) has been implemented somewhere: -``` -====================================================================== -FAIL: test_reference_eager_trapezoid_cpu_int64 (__main__.TestOpInfoCPU) [torchax_diff:0.001] ----------------------------------------------------------------------- -... -AssertionError: The values for attribute 'dtype' do not match: torch.float64 != torch.float32. -``` -Please try to fix it by following these steps: - 1. confirm your target op `trapezoid` is decomposed by running this code to print each sub ops: - ``` - import torch - import torchax - - env = torchax.default_env() - env.config.debug_print_each_op = True - env.config.debug_accuracy_for_each_op = True - - with env: - y = torch.tensor([1, 5, 10]) - print(torch.trapezoid(y)) - ``` - 2. (optional) Debug by modify [debug_accuracy()](https://github.com/pytorch/xla/blob/c26b19ebdefccd3a4300763e1085724d3d4cd3d0/experimental/torchax/torchax/tensor.py#L171C1-L194C14) to check `res`(from jax) and `expected_res`(from torch)'s value and dtype/type. - 3. you might need to debug/modify/add implementation of sub ops(found in step1) to support `trapezoid` by using step 2, like: - ``` - @op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) - def _aten_mul(x, y): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = x * y - if isinstance(x, float) or isinstance(y, float): - res = res.astype(new_dtype) - return res - ``` - -### First Impl - -To implement this op using jax ops, we first find what -is the exact semantics in this page: -https://pytorch.org/docs/stable/generated/torch.addbmm.html - -From it's math formula: we can implement it as follows. - -``` -+@op(torch.ops.aten.addbmm.default) -+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): -+ -+ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) -+ return beta * input + alpha * mm -``` - -Now running test again: - -``` -python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64 -``` - -(NOTE: the exact test command is printed out when we run -`pytest test/test_ops.py` so we can only run the failed test instead of running all tests.) - -We now see this error: - -``` -FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torchax_diff:0.001] ----------------------------------------------------------------------- -Traceback (most recent call last): - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/test/test_ops.py", line 654, in run_export_and_compare - diff_output( - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/test/test_ops.py", line 617, in diff_output - testcase.assertTrue( -AssertionError: False is not true -``` - -This is telling me that our implementation did not produce -the same result as the ops in PyTorch. - -To debug this, let's figure out what exact input caused this. -We can achieve this by setting a break point [here](https://github.com/pytorch/xla/blob/master/experimental/torchax/test/test_ops.py#L644), right before the diff. Here we can -inspect values of `res` and `res2`, as well as the `sample_input`. - -The sample input we get is -``` -SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2], - [-5, 1, -9, 9, 1, -5, 6, 1, -4, -5], - [-2, -1, 5, -2, -3, 0, 5, -4, 9, -6], - [-1, -7, 6, 3, 8, 3, 8, 9, -5, 7], - [-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8], - [-6, -2, 5, 7, 7], - [-8, -3, 2, 5, -3], - [-4, 7, 0, -9, 8], - [ 3, 9, -9, -2, 0]], - - [[-7, 1, -3, 7, -4], - [ 3, 5, 4, 6, 5], - [-2, 8, 3, 5, 7], - [ 8, -2, -8, 2, 0], - [ 6, 1, -8, 8, 0]], - - [[ 2, -1, -5, -8, -9], - [ 5, 0, -4, -1, -6], - [-6, 2, -5, -2, -5], - [-5, -3, -5, -4, 9], - [-3, 4, -9, -9, 7]], - - [[ 2, 5, -7, -3, 8], - [-5, -7, -8, -4, 4], - [-4, -6, -3, 0, 6], - [ 8, 0, -3, -8, 2], - [-4, 3, -9, -6, 7]], - - [[ 2, 1, -6, 2, 8], - [ 2, 6, 4, 1, 8], - [-9, 9, -5, 8, 3], - [-5, 0, -2, 4, 0], - [ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5], - [-4, -7, 2, 2, 1, -9, 2, 7, -1, -1], - [ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4], - [-4, 1, -9, 3, 4, 6, 0, -2, -2, -7], - [ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]], - - [[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5], - [-5, -3, -2, 8, 1, -2, 4, -7, 5, 3], - [-4, 4, 1, -4, -8, 2, -5, 2, 9, -7], - [ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4], - [-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]], - - [[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3], - [-3, 3, -9, -7, -9, -8, 1, -3, 7, -2], - [ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7], - [-1, 6, -8, 7, -1, -5, -8, 6, -2, 8], - [-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]], - - [[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8], - [-6, -4, 3, 9, -9, -8, -7, 3, 9, 0], - [ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7], - [-6, 9, 5, -1, 7, 7, 8, -3, -8, 0], - [-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]], - - [[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8], - [-8, -8, -2, -5, -7, -8, 1, 0, 0, -6], - [ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8], - [-4, 0, 9, 6, -1, -6, 6, -6, -2, -1], - [ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='') -``` - -And the `res` from torch is - -``` -tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) -``` - -So few observation is: -1. Input tensor are of type int64 -2. alpha and beta are both floats. - -So one can suspect that it has to do with rounding. -Reading the doc more carefully, we can find this sentence - - For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers. - -So likely torch first casted the float alpha and beta to integer, which yields 0, then used them in math to get a matrix with all zeros. - -### Second Impl - -```python -+@op(torch.ops.aten.addbmm.default) -+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): -+ alpha = jnp.array(alpha).astype(batch1.dtype) -+ beta = jnp.array(beta).astype(batch1.dtype) -+ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2) -+ return jax.lax.cond(beta == 0, -+ lambda: alpha * mm, -+ lambda: beta*input + alpha*mm) -+ -``` - -Adding type casts makes the tests passes. - -### Submit -Now, let's remove the pdb and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993 - diff --git a/torchax/docs/how_it_works.md b/torchax/docs/how_it_works.md deleted file mode 100644 index f352773b6da4..000000000000 --- a/torchax/docs/how_it_works.md +++ /dev/null @@ -1,134 +0,0 @@ -How it works -============ - - -## Tensor subclass and eager mode - -The class `Tensor` is a `torch.Tensor` subclass -that overrides `__torch_dispatch__`. - -It roughly looks like this (with some details removed): - -The complete class impl is at [tensor.py](../torchax/tensor.py). - -```python -class Tensor(torch.Tensor): - - @staticmethod - def __new__(cls, elem): - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - dtype=dtype, - device='meta', - requires_grad=False, - ) - - def __init__(self, elem: jax.Array): - super().__init__() - self._elem = elem - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # here assumes ALL tensors in args / kwargs are - # instances of Tensor - args, kwargs = unwrap((args, kwargs)) - jax_func = some_registry[func] - res = jax_func(*args, **kwargs) - return wrap(res) - -def wrap(tree): - # wrap jax.Array with Tensor - return pytree.tree_map_only( - jax.Array, Tensor, tree) - -def unwrap(tree): - # get jax.Array out ofTensor - return pytree.tree_map_only( - Tensor, lambda x: x._elem, tree) -``` - -In other words, assuming that we have a function -that takes `jax.Array` as input and returns `jax.Array` -but otherwise implement the same semantics -as a `ATen` op; then, using this tensor we would -be able to route the call to this jax function. - -[_ops.py](../torchax/_ops.py) files defines some of those ops. - -Let's take `aten::add` as example: - -```python -@op(torch.ops.aten.add) -def _aten_add(x, y, *, alpha=1): - """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): - - assert x.dtype == y.dtype, (x.dtype, y.dtype) - """ - return x + y * alpha -``` - -The `@op` decorator just puts this function into `some_registry` dictionary. - -`_aten_add` has same signature as `torch.ops.aten.add` but takes `jax.Array` as -input. - -![](dispatch.png) - - -## fx Interpreter and dynamo mode - -Now, assuming we have this `some_registry` dict with key core Aten ops, -and value the equivalent python Jax functions. We can also build a `fx.Interpreter` -subclass that executes the jax function given a `fx.GraphModule`. - - -```python -class JaxInterpreter(torch.fx.Interpreter): - - def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: - if not isinstance(target, - (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): - return super().call_function(target, args, kwargs) - - op = some_registry[target] - return op.func(*args, **kwargs) -``` - -There is no wrapping and unwrapping needed because `args` and `kwargs` are -already `jax.Array`'s. - -Using this interpreter we can build a dynamo backend: - -```python -def backend(fxgraph): - - def tojit(*args, *kwargs): - return JaxInterpreter(fxgraph).run(*args, **kwargs) - jitted = jax.jit(to_jit) - - def f(*torchtensor): - jaxarrays = unwrap(torchtensors) - res = jitted(jax_array) - return wrap(res) - - return f -``` - -The inner function `tojit` is a function that takes and returns -`jax.Array`'s. So it's suitable to be jitted with `jax.jit`. - -`f` is returned callable that takes `Tensor`; so can interop with -other torch codes. - -## nn.Modules and state management - -See [README.md](../README.md) for using `torch.func.functional_call` to -make `nn.Module`s interact well with `jax.jit`. - -See [Examples](../examples/README.md) for training using torch's optimizers or jax's -optimizers. - -[def]: dispatch.png diff --git a/torchax/docs/ops_registry.md b/torchax/docs/ops_registry.md deleted file mode 100644 index 242208ab0517..000000000000 --- a/torchax/docs/ops_registry.md +++ /dev/null @@ -1,41 +0,0 @@ -# Ops Registry - -## Background - -In the [How it works](how_it_works.md) doc, we mentioned 2 important pieces: - -1. A mechanism to route `ATen` ops to implementation written in - Jax or in PyTorch, and - -2. The ops themselves. - - -Ops Registry is there to help us to organize the ops themselves. - -An op implementation can written in terms of Jax, or in other PyTorch ops. -The latter is also known as "decompositions". For decompositions, -one need to be careful of not introducing circular dependencies. - -Here we simply store the operator implementations in a dictionary, -which key the torch / Aten callable that we wish to override, and -value an instance of `Operator` class. - -`Operator` class has this schema: - -```python -@dataclasses.dataclass -class Operator: - torch_op: TorchCallable - func: Union[TorchCallable, JaxCallable] - is_jax_function: bool - is_user_defined: bool - needs_env: bool - is_view_op: bool -``` - -The `torch_op` is the corresponding torch callable, and `func` the implementation. `is_jax_function` is True if `func` is implemented using Jax, False if `func` is implemented using other torch ops. We can use this information to decide how to call it. - -If `needs_env` is true, `func` will recieve an extra kwarg with name `env`. -This will be the "Environment" in which this op operate on. In particular, -the environment will contain the Jax random number generator key, that might be useful for ops like `aten::rand`. - diff --git a/torchax/docs/support_a_new_model.md b/torchax/docs/support_a_new_model.md deleted file mode 100644 index 09d70c144c01..000000000000 --- a/torchax/docs/support_a_new_model.md +++ /dev/null @@ -1,137 +0,0 @@ -# Run a model under torchax - -Supporting a new model in torchax means -having this model run using torchax and succeeds. - -A model usually consists of executing a list of torch ops -on a set of tensors (i.e. the parameters and inputs) and -produce a new tensor(s). These ops should just work. - -However, there are cases that the model doesn't run on -torchax, because: - -1. Some op it needs is not implemented. -2. Some op it needs is implemented incorrectly -3. There are some non-torch-op code that interacts with torchax in a non-friendly matter. - -Here we present few steps to attempt to fix the related issues. Using dlrm model as -example. - -This assumes that you already installed torchax with `pip install -e .` locally. -Following the instructions in [README](../README.md) - - -### Get torchbench scripts - -Following the instructions in https://github.com/pytorch-tpu/run_torchbench - - -### Run script from run_torchbench: - -```bash -(xla2) hanq-macbookpro:run_torchbench hanq$ python models/dlrm.py -Traceback (most recent call last): - File "/Users/hanq/git/qihqi/run_torchbench/models/dlrm.py", line 16, in - module = importlib.import_module(model_name) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/importlib/__init__.py", line 126, in import_module - return _bootstrap._gcd_import(name[level:], package, level) - File "", line 1050, in _gcd_import - File "", line 1027, in _find_and_load - File "", line 1006, in _find_and_load_unlocked - File "", line 688, in _load_unlocked - File "", line 883, in exec_module - File "", line 241, in _call_with_frames_removed - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/dlrm/__init__.py", line 15, in - from .tricks.qr_embedding_bag import QREmbeddingBag -ModuleNotFoundError: No module named 'torchbenchmark.models.dlrm.tricks' -``` - -Turns out I forgot to run `python install.py dlrm` in the benchmarks folder (cloned from pytorch/benchmark) - - -### Fixing missing ops: - -Rerunning: -```bash -(xla2) hanq-macbookpro:run_torchbench hanq$ python models/dlrm.py -Traceback (most recent call last): - File "/Users/hanq/git/qihqi/run_torchbench/models/dlrm.py", line 28, in - print(model(*example)) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl - return forward_call(*args, **kwargs) - File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 355, in forward - return self.sequential_forward(dense_x, lS_o, lS_i) - File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 367, in sequential_forward - ly = self.apply_emb(lS_o, lS_i, self.emb_l) - File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 308, in apply_emb - V = E(sparse_index_group_batch, sparse_offset_group_batch) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl - return forward_call(*args, **kwargs) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 390, in forward - return F.embedding_bag(input, self.weight, offsets, - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2360, in embedding_bag - return handle_torch_function( - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/overrides.py", line 1619, in handle_torch_function - result = mode.__torch_function__(public_api, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 215, in __torch_function__ - return func(*args, **(kwargs or {})) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag - ret, _, _, _ = torch.embedding_bag( - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 230, in __torch_dispatch__ - return self.env.dispatch(func, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 310, in dispatch - raise OperatorNotFound( -torchax.tensor.OperatorNotFound: Operator with name aten::_embedding_bag has no lowering -``` - -Now let's implement this op. - -Few tricks while implementing the ops: - -1. Feel free to edit the script `models/dlrm.py` while debugging. -2. Useful options to set `env.config.debug_print_each_op = True` will print out each - op that goes through the dispatcher. -3. Set `env.config.debug_accuracy_for_each_op = True` will in addition of running Jax - op, it also runs it again in Torch CPU. Then it diffs the result. If the diff is too - large, then it drops you into pdb for inspection. -4. After inspecting input / output / shapes of the op, maybe it's enough hint for - you to fix this op. Or, if it's not, then it's adviced to save the inputs / outputs - and write a unit test for it and iterate on that. Usually a unit test is faster - to iterate than running a whole model. - -After finishing `embedding_bag` badly, I reached the next op - -```bash - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 390, in forward - return F.embedding_bag(input, self.weight, offsets, - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag - ret, _, _, _ = torch.embedding_bag( - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 124, in __torch_dispatch__ - return func(*args, **(kwargs or {})) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ - return self_._op(*args, **kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 212, in __torch_function__ - return func(*args, **(kwargs or {})) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ - return self_._op(*args, **kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 227, in __torch_dispatch__ - return self.env.dispatch(func, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torchax/torchax/tensor.py", line 308, in dispatch - raise OperatorNotFound( -torchax.tensor.OperatorNotFound: Operator with name aten::_embedding_bag_forward_only has no lowering -``` - -Turns out, that is the same operator. so adding the @op(torch.ops.aten._embedding_bag_forward_only) -on top of the same op works. - -Now the resulting PR is: https://github.com/pytorch/xla/pull/7583 - -After this `python models/dlrm.py` runs. - -NOTE: -The _embedding_bag implementation is actually very crude, just sufficient to make -the model pass. diff --git a/torchax/docs/torch_dispatch/README.md b/torchax/docs/torch_dispatch/README.md deleted file mode 100644 index 3310c4d2f997..000000000000 --- a/torchax/docs/torch_dispatch/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# How torch dispatch works - -References: -* [__torch_dispatch__](https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557) -* [Dispatcher](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) Note: old but not outdated. - -## torch ops vs. Aten ops - -torch ops - regular python functions / methods -> `__torch_functions__` -aten ops - pybind11 registered C++ functions + `OpOverload` wrapper -> `__torch_dispatch__` - - -## Issues with torch functions: - -* More torch functions than *core* Aten ops -* Some torch functions are NOT overridable. - -## Issues with torch dispatch: - -* Undesired Decompositions -* overloads - -## Ways to override - -* Subclass -* Decorator - -## extension poster -* https://docs.google.com/presentation/d/1piuv9nBzyoqdH49D1SoE5OZUPSMpOOFqfSKOhr-ab2c/edit#slide=id.p1 - -## How does it works - -**TODO:** replace with github links - -https://source.corp.google.com/piper///depot/google3/third_party/py/torch/_tensor.py;l=439?q=class%20Tensor&ss=piper%2FGoogle%2FPiper:google3%2Fthird_party%2Fpy%2Ftorch%2F - -https://source.corp.google.com/piper///depot/google3/third_party/py/torch/torch/csrc/utils/python_arg_parser.cpp;l=394?q=%22%5C%22__torch_dispatch__%5C%22%22&ss=piper%2FGoogle%2FPiper:google3%2Fthird_party%2Fpy%2Ftorch%2F - -https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml diff --git a/torchax/docs/torch_dispatch/example.py b/torchax/docs/torch_dispatch/example.py deleted file mode 100644 index 12780fc8cf83..000000000000 --- a/torchax/docs/torch_dispatch/example.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F -from torch.utils import _pytree as pytree - - -class Subclass(torch.Tensor): - - def __new__(cls, raw_data, requires_grad=False): - return torch.Tensor._make_subclass( - cls, - raw_data, - require_grad=requires_grad, - ) - - def __init__(self, raw_data=None, requires_grad=False): - # Store any provided user raw_data - self.raw_data = raw_data - - @classmethod - #def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - print(f'func is {func}') - - def unpack(x): - if isinstance(x, Subclass): - return x.raw_data - return x - - (args, kwargs) = pytree.tree_map(unpack, (args, kwargs)) - res = func(*args, **kwargs) - return pytree.tree_map_only(torch.Tensor, Subclass, res) - - def __str__(self): - return f'Subclass of shape {self.shape}' - - def add(self, a): - print('HERE: add') - return super().add(a) - - __repr__ = __str__ - - -class MyModel(nn.Module): - - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -model = MyModel() - -x = torch.randn(10, 28 * 28) -x2 = Subclass(x) -print(model(x2)) - -x2.add(2) diff --git a/torchax/docs/torch_dispatch/run_env.py b/torchax/docs/torch_dispatch/run_env.py deleted file mode 100644 index ab66c58986b6..000000000000 --- a/torchax/docs/torch_dispatch/run_env.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -import torchax - -env = torchax.default_env() -env.config.debug_print_each_op = True -env.config.debug_accuracy_for_each_op = True - -with env: - y = torch.tensor([1, 5, 10]) - print(torch.trapezoid(y)) - print(torch.trapz(y, y)) diff --git a/torchax/docs/torch_xla2_dynamo.md b/torchax/docs/torch_xla2_dynamo.md deleted file mode 100644 index d375feddfcc7..000000000000 --- a/torchax/docs/torch_xla2_dynamo.md +++ /dev/null @@ -1,194 +0,0 @@ -# Dynamo backend for torchxla2 - -## Goal - -Have a dynamo backend backend by torchax. - -The users should be able to do the following: - -```python -m = model ... -m_compiled = torch.compile(m, backend='torchax_compile') # backend name TBD -result = m_compiled(*inputs) -``` - -The above should run on TPU will low overhead. - -## Challenge - -Usually the challenge of a dynamo backend is the compiler that -transforms a fx graph with torch (or Aten) ops to the compiled executable. -However, in our case, that piece is solved. - -For every `call_function` node; we lookup the corresponding implementation of -said ATen op in a dictionary for it's corresponding implementation in Jax, -and we just call it. - -This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torchax/torchax/export.py#L23 - -Now, the challenge is for dynamo to be able to 1. produce the graph; and 2. n -not incur any data copies in this process. - - -Consider this following pseudocode: - -```python -class Tensor: - _data: jax.Array - def __torch_dispatch__(...): - # do stuff with _data, get new data - return Tensor(new_data) - -def dynamo_backend(fx, sample): - compiled = compile fx into graph that manipulate jax.Array. - def returned_callable(inputs): - datas = [i._data for i in inputs] - res = compiled(*datas) - return TensorSubclass(res) - return returned_callable - -model = torch.compile(model, backend = dynamo_backend) -inputs = a list of TensorSubclass or a list of torch.Tensor? -model(*inputs) -``` - -What would be the type of inputs? -If inputs are of type `TensorSubclass`, then dynamo -will attempt to trace through the `__torch_dispatch__` method, -and throws error because it doesn't know what is `_data` and the -operations on it. - -If `inputs` is of type `torch.Tensor`, then it works: dynamo -calls the backend, the backend can produce correct result. -But, `inputs` need to be converted to `TensorSubclass` first inside of -the backend; which usually means a data copy. This happens everytime -the compiled backend is executed, therefore not desirable. - -## The Desired behavior - -When *tracing* dynamo treats TensorSubclass as if it is a regular tensor -without dispatch override; and when executing the compiled callable, -TensorSubclass is passed in as-is. We know that dynamo can do this with -some tensor subclass, namely `FakeTensor`. - - -Let's list out the possible ways we could accomplish this behavior. - - -# Option 1. Have the jax.Array object hold in C++ - -Roughly we would have a `Tensor` subclass in C++, this is very -similar to the `LazyTensor` subclass that is the current `XLATensor`. -This tensor can hold it's own states in C++. In our case, that would -be a `PyObject*` that happens to point to either `jnp.ndarray` or -jax's `Traced` during jax.jit. We might further result the -`XLA` dispatch key to route the operators to the jax implementation, -emulating what `__torch_dispatch__` does. - -This way, eager mode will continue to work, and dynamo would work -because the Python class is still `torch.Tensor` (not a subclass), and -there are no Python logic in dispatching so dynamo cannot trace through. - -## Pros: -* Very clear that this will work. -* Recommended by ezyang - -## Cons: -Now need to deal with C++ builds. In particular, `torch` becomes a source -dependency instead of a pip dependency; meaning, again we need to start -building torch first then build torchax. This might be mitigated if -that subclass can be upstreamed. - - -# Option 2. Modify dynamo to do the desired behavior - -We have one instance where a `torch.Tensor` dispatch subclass -just works with dynamo, without dynamo make a fuss when it traces -`__torch_dispatch__`. This is `FakeTensor`. (https://github.com/pytorch/pytorch/pull/100017/files) - -The idea is to make dynamo trace as-if the inputs are `FakeTensor` and -not `XLATensor`. and only after the creation of fx graph and backend, dynamo -calls the compiled callable with `XLATensor`. - -Pros: -* Likely pure python changes. - -Cons: -* We also need to design a mechanism to represent tensor subclasses that - is desirable for dynamo to trace through, and those is not. -* Likely significant amount of work. - - -# Option 3. Register All the ops as custom_ops - -So currently dynamo traces `__torch_dispatch__`, and we don't like that -because it will find the operations on Jax arrays, and doesn't understand those. - -What if we make dynamo **able** to understand what is inside? -The [Black box python functions](https://docs.google.com/document/d/1ZuCVyMfibExwvtzhd9cfMWk5zXT3Dhy1b3kuvAIkBoU/edit#heading=h.56tggsazyrkh) doc -points the possibility of registering things that we don't want dynamo -to go into as a custom op. So we could, theoretically do the following: - -1. Register the jax impl of an Aten op as a custom op. - i.e. register `jaten.add` for `aten.add`. -2. For meta kernels, just call the meta kernel of `aten.add`. -3. In `__torch_dispatch__`, we forward the call from `aten.add` to `jaten.add`. - -When dynamo attempts to go inside of `__torch_dispatch__`, it will find -`jaten.add`. Then it will record that in the `fx.Graph`. - -Our backend will see the same ops but in a different namespace (`jaten`). -That is fine as long as we know how to look up its implementation. - -Note: we probably also need to hook up gradients of custom ops via. `autograph.Function`. - - -Pros / Cons: -Haven't tried, don't know if it gonna work or not. - - - - - - -# Appendix, Failed attempts: - -## Attempt 1: move dispatch to a mode (i.e. subclass have no dispatch override) - -```python -class Subclass(torch.Tensor): - - @staticmethod - def __new__(cls, elem): - dtype = tensor.j2t_dtype(elem.dtype) - shape = list(elem.shape) - for i, s in enumerate(shape): - if not isinstance(s, int): - shape[i] = 1 - if dtype is None: - dtype = torch.float32 - - self = torch.Tensor._make_wrapper_subclass( - cls, - shape, - dtype=dtype, - device='meta', - requires_grad=False, - ) - self._meta = torch.empty( - shape, dtype=dtype, device='meta', requires_grad=False - ) - self._elem = elem - return self - - def __init__(self, elem: jax.Array): - super().__init__() - self._elem = elem - - def __str__(self): - return "Subclass({} {})".format(str(type(self._elem)), str(self._elem)) - -``` - -This fails with an error saying that exhausted subclasses and all the `__torch_dispatch__` returned `NotImplemented`. - diff --git a/torchax/docs/understand_jax_jit/jax_jit.py b/torchax/docs/understand_jax_jit/jax_jit.py deleted file mode 100644 index cea6f090a0ae..000000000000 --- a/torchax/docs/understand_jax_jit/jax_jit.py +++ /dev/null @@ -1,105 +0,0 @@ -## Please read: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html - -## Goal 1. Jax jit what it is -## Goal 2. Illustrate that changing shapes == recompile -## Goal 3. Illustrate that closures == inlined constant in graph (undesirable) - -import random -import time -import jax.numpy as jnp -import numpy as np -from jax import jit -import jax - - -def norm(X): - for i in range(10): - X = X @ X - X = X - X.mean(0) - return X / X.std(0) - - -norm_compiled = jit(norm) - -np.random.seed(1701) - -# print('---- example 1 -----') -# for i in range(5): -# X = jnp.array(np.random.rand(1000, 1000)) -# start = time.perf_counter() -# norm(X).block_until_ready() -# end = time.perf_counter() -# print(f'iteration {i}: norm: {end - start}') -# start = time.perf_counter() -# norm_compiled(X).block_until_ready() -# end = time.perf_counter() -# print(f'iteration {i}: norm_compiled: {end - start}') - -#jax.config.update("jax_explain_cache_misses", True) -#print('---- example 3 -----') -#print('---- example 2 -----') -# for i in range(5): -# shape = random.randint(1000, 2000) -# print('shape is ', shape) -# X = jnp.array(np.random.rand(shape, shape)) -# start = time.perf_counter() -# norm(X).block_until_ready() -# end = time.perf_counter() -# print(f'iteration {i}: norm: {end - start}') -# start = time.perf_counter() -# norm_compiled(X).block_until_ready() -# end = time.perf_counter() -# print(f'iteration {i}: norm_compiled: {end - start}') - -#Example 4: print out the graph -print('--- example 4 ---') - -X = jnp.array(np.random.rand(1000, 1000)) - -#print(norm_compiled.lower(X).as_text()) -#print(norm_compiled.lower(jax.ShapeDtypeStruct((1000, 1000), jnp.float32.dtype)).as_text()) - -# Example 5: What happen to closures -print('--- example 5 ---') - - -def addx(y): - #print('y is ', y) - #print('X is ', X) - # import pdb; pdb.set_trace() - #jax.debug.print(...) - #jax.debug.breakpoint() - return y + X - - -addx_jitted = jax.jit(addx).lower( - jax.ShapeDtypeStruct((1000, 1000), jnp.float32.dtype)) -#print(addx_jitted.as_text()) -#print(addx_jitted.compile()(X)) - -# ASIDE: pdb; print; -# https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html - -# Example 6: What happens with class attr - - -class Model: - - def __init__(self): - self.weight = jnp.array(np.random.randn(1000, 1000)) - self.bias = jnp.array(np.random.randn(1000)) - - def __call__(self, X): - return X @ self.weight + self.bias - - -m = Model() -print('Not jitted', m(X)) -print(jax.jit(m).lower(X).as_text()) - - -def model_pure(weight, bias, X): - m.weight = weight - m.bias = bias - res = m(X) - return res diff --git a/torchax/docs/understand_jax_jit/torch_module.py b/torchax/docs/understand_jax_jit/torch_module.py deleted file mode 100644 index 078daf8351dd..000000000000 --- a/torchax/docs/understand_jax_jit/torch_module.py +++ /dev/null @@ -1,121 +0,0 @@ -import jax -import jax.numpy as jnp -import torch -from torch.nn.functional import linear - -## Goal 1. Illustrate that class attr for torch modules -## dont play nice with jax.jit -## How we solve it. - - -class Linear(torch.nn.Module): - - def __init__(self): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(1000, 1000)) - self.bias = torch.nn.Parameter(torch.randn(1000)) - - def forward(self, X): - return linear(X, self.weight, self.bias) - - -# Running with torch native - -print('---- example 1 -----') -m = Linear() -x = torch.randn(2, 1000) -print(m(x)) -print(m.forward(x)) - -with torch.inference_mode(): - # with torch.no_grad(): - print(m.forward(x)) - -print('---- example 2 -----') - -import torchax - -env = torchax.default_env() - -with env: - m2 = Linear() - x = torch.randn(2, 1000) - print(m2.forward(x)) - -print('---- example 3 -----') -# where is the jax jit? - -# m2 is a callable that takes in Tensor and returns Tensor -# m2: (Tensor -> Tensor) - -# suppose t2j (Tensor -> jax.Array) "unwraps the XLATensor" -# suppose j2t (jax.Array -> Tensor) "wraps the XLATensor" -from torchax import tensor -import jax - - -def t2j(torch_tensor: tensor.Tensor) -> jax.Array: - return torch_tensor._elem - - -def j2t(jax_array: jax.Array) -> tensor.Tensor: - return tensor.Tensor(jax_array, env) - - -# # further notice t2j(j2t(x)) == x; j2t(t2j(x)) == x - - -def jax_m(X: jax.Array): - X_jax = j2t(X) - res = m2(X_jax) - return t2j(res) - - -jax_x = jnp.ones((10, 1000)) -print(jax_m(jax_x)) - -## Let f: Tensor -> Tensor -## There is a function g: jax.Array -> jax.Array; -## g = x |-> j2t (f (t2j(x))). OR, -## g = j2t . f . t2j (. denotes function composition) -# The correspondence f -> g is an isomorphism too. - -jitted_jax_m = jax.jit(jax_m) -print(jitted_jax_m(jax_x)) -print(jitted_jax_m.lower(jax_x).as_text()) - -from torch.utils import _pytree as pytree - - -def jax_m_functional(states, X): - states_torch = pytree.tree_map(j2t, states) - X = j2t(X) - old_state_dict = m2.state_dict() - m2.load_state_dict(states_torch, assign=True, strict=False) - res = m2(X) - m2.load_state_dict(old_state_dict, assign=True, strict=False) - return t2j(res) - - -jax_weights = { - 'weight': m2.weight._elem, - 'bias': m2.bias._elem, -} - -jitted_jax_m_functional = jax.jit(jax_m_functional) -print(jitted_jax_m_functional.lower(jax_weights, jax_x).as_text()) - -# ## interop module - -# print('---- exmaple 4 ----') -# import torchax.interop - -# def m_functional(states, x): -# return torch.func.functional_call(m2, states, x) - -# with jax.checking_leaks(): -# print(torchax.interop.jax_jit(m_functional)(m2.state_dict(), x)) - -# # Experiment if time: -# # 1. torch buffer persistence = False -# # 2. torch attr diff --git a/torchax/examples/README.md b/torchax/examples/README.md deleted file mode 100644 index f2cb4f0b66e4..000000000000 --- a/torchax/examples/README.md +++ /dev/null @@ -1,115 +0,0 @@ -## Intro - -This readme will have a subsection for every example *.py file. - -Please follow the instructions in [README.md](../README.md) to install torchax, -then install requirements for all of the examples with - -```bash -pip install -r requirements.txt -``` - - - -## basic_training.py - -This file constructed by first copy & paste code fragments from this pytorch training tutorial: -https://pytorch.org/tutorials/beginner/introyt/trainingyt.html - -Then adding few lines of code that serves the purpose of moving `torch.Tensor` into -`XLA devices`. - -Example: - -```python -state_dict = pytree.tree_map_only(torch.Tensor, - torchax.tensor.move_to_device, state_dict) -``` - -This fragment moves the state_dict to XLA devices; then the state_dict is passed -back to model via `load_state_dict`. - -Then, you can train the model. This shows what is minimum to train a model on XLA -devices. The perf is not as good because we didn't use `jax.jit`, this is intentional -as it is meant to showcase the minimum code change. - -Example run: -```bash -(xla2) hanq-macbookpro:examples hanq$ python basic_training.py -Training set has 60000 instances -Validation set has 10000 instances -Bag Dress Sneaker T-shirt/top -tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035, - 0.2279], - [0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561, - 0.5985], - [0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504, - 0.8738], - [0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086, - 0.5709]]) -tensor([1, 5, 3, 7]) -Total loss for this batch: 2.325265645980835 -EPOCH 1: - batch 1000 loss: 1.041275198560208 - batch 2000 loss: 0.6450189483696595 - batch 3000 loss: 0.5793989677671343 - batch 4000 loss: 0.5170258888280951 - batch 5000 loss: 0.4920090722264722 - batch 6000 loss: 0.48910293977567926 - batch 7000 loss: 0.48058812761632724 - batch 8000 loss: 0.47159107415075413 - batch 9000 loss: 0.4712311488997657 - batch 10000 loss: 0.4675815168160479 - batch 11000 loss: 0.43210567891132085 - batch 12000 loss: 0.445208148030797 - batch 13000 loss: 0.4119230824254337 - batch 14000 loss: 0.4190662656680215 - batch 15000 loss: 0.4094535468676477 -LOSS train 0.4094535468676477 valid XLA -``` - -## basic_training_jax.py - -This file constructed by first copy & paste code fragments from this pytorch training tutorial: -https://pytorch.org/tutorials/beginner/introyt/trainingyt.html - -Then replacing torch optimizer with `optax` optimizer; and use `jax.grad` for -gradient instead of `torch.Tensor.backward()`. - -Then, you can train the model using jax ecosystem's training loop. This is meant to -showcase how easy is to integrate with Jax. - -Example run: -```bash -(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py -Training set has 60000 instances -Validation set has 10000 instances -Pullover Ankle Boot Pullover Ankle Boot -tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872, - 0.1854], - [0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010, - 0.2766], - [0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607, - 0.6450], - [0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377, - 0.7453]]) -tensor([1, 5, 3, 7]) -Total loss for this batch: 2.4054245948791504 -EPOCH 1: - batch 1000 loss: 1.0705260595591972 - batch 2000 loss: 1.0997755021179327 - batch 3000 loss: 1.0186579653513108 - batch 4000 loss: 0.9090727646966116 - batch 5000 loss: 0.8309370622411024 - batch 6000 loss: 0.8702225417760783 - batch 7000 loss: 0.8750176187023462 - batch 8000 loss: 0.9652624803795453 - batch 9000 loss: 0.8688667197711766 - batch 10000 loss: 0.8021814124770199 - batch 11000 loss: 0.8000540231048071 - batch 12000 loss: 0.9150884484921057 - batch 13000 loss: 0.819690621060171 - batch 14000 loss: 0.8569030471532278 - batch 15000 loss: 0.8740896808278603 -LOSS train 0.8740896808278603 valid 2.3132264614105225 -``` diff --git a/torchax/examples/__init__.py b/torchax/examples/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/examples/basic_training.py b/torchax/examples/basic_training.py deleted file mode 100644 index 1af28e64b587..000000000000 --- a/torchax/examples/basic_training.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -This is the script from this tutorial: -https://pytorch.org/tutorials/beginner/introyt/trainingyt.html - -Then, it's modified to make the training loop using Jax's grad -and optimizer -""" - -import torch -import torchvision -import torchvision.transforms as transforms - -# PyTorch TensorBoard support -#from torch.utils.tensorboard import SummaryWriter -#from datetime import datetime - -# NOTE: add these lines to make it run on TPUs! -import torchax - -torchax.enable_globally() - -transform = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,))]) - -# Create datasets for training & validation, download if necessary -training_set = torchvision.datasets.FashionMNIST( - './data', train=True, transform=transform, download=True) -validation_set = torchvision.datasets.FashionMNIST( - './data', train=False, transform=transform, download=True) - -# Create data loaders for our datasets; shuffle for training, not for validation -training_loader = torch.utils.data.DataLoader( - training_set, batch_size=4, shuffle=True) -validation_loader = torch.utils.data.DataLoader( - validation_set, batch_size=4, shuffle=False) - -# Class labels -classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', - 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot') - -# Report split sizes -print('Training set has {} instances'.format(len(training_set))) -print('Validation set has {} instances'.format(len(validation_set))) - -import matplotlib.pyplot as plt -import numpy as np - - -# Helper function for inline image display -def matplotlib_imshow(img, one_channel=False): - if one_channel: - img = img.mean(dim=0) - img = img / 2 + 0.5 # unnormalize - npimg = img.numpy() - if one_channel: - plt.imshow(npimg, cmap="Greys") - else: - plt.imshow(np.transpose(npimg, (1, 2, 0))) - - -#torchax.env.config.debug_print_each_op = True -#torchax.env.config.debug_mixed_tensor = True -dataiter = iter(training_loader) -images, labels = next(dataiter) - -# Create a grid from the images and show them -img_grid = torchvision.utils.make_grid(images) -matplotlib_imshow(img_grid, one_channel=True) -print(' '.join(classes[labels[j]] for j in range(4))) - -import torch.nn as nn -import torch.nn.functional as F - - -# PyTorch models inherit from torch.nn.Module -class GarmentClassifier(nn.Module): - - def __init__(self): - super(GarmentClassifier, self).__init__() - self.fc1 = nn.Linear(28 * 28, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -model = GarmentClassifier().to('jax') - -loss_fn = torch.nn.CrossEntropyLoss() - -# NB: Loss functions expect data in batches, so we're creating batches of 4 -# Represents the model's confidence in each of the 10 classes for a given input -dummy_outputs = torch.rand(4, 10, device='jax') -# Represents the correct class among the 10 being tested -dummy_labels = torch.tensor([1, 5, 3, 7], device='jax') - -print(dummy_outputs) -print(dummy_labels) - -loss = loss_fn(dummy_outputs, dummy_labels) -print('Total loss for this batch: {}'.format(loss.item())) - -# Optimizers specified in the torch.optim package -optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) - - -def train_one_epoch(epoch_index, tb_writer=None): - running_loss = 0. - last_loss = 0. - - # Here, we use enumerate(training_loader) instead of - # iter(training_loader) so that we can track the batch - # index and do some intra-epoch reporting - for i, data in enumerate(training_loader): - # Every data instance is an input + label pair - # NEW: Move model to XLA device - inputs, labels = data - inputs = inputs.to('jax') - labels = labels.to('jax') - - # Zero your gradients for every batch! - optimizer.zero_grad() - - # Make predictions for this batch - - outputs = model(inputs) - - # Compute the loss and its gradients - loss = loss_fn(outputs, labels) - loss.backward() - - # Adjust learning weights - optimizer.step() - - # Gather data and report - running_loss += loss.item() - if i % 1000 == 999: - last_loss = running_loss / 1000 # loss per batch - print(' batch {} loss: {}'.format(i + 1, last_loss)) - tb_x = epoch_index * len(training_loader) + i + 1 - #tb_writer.add_scalar('Loss/train', last_loss, tb_x) - running_loss = 0. - - return last_loss - - -# Initializing in a separate cell so we can easily add more epochs to the same run -#timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') -#writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) -epoch_number = 0 -EPOCHS = 2 -best_vloss = 1_000_000. - -for epoch in range(EPOCHS): - print('EPOCH {}:'.format(epoch_number + 1)) - - # Make sure gradient tracking is on, and do a pass over the data - model.train(True) - - avg_loss = train_one_epoch(epoch_number) - - running_vloss = 0.0 - # Set the model to evaluation mode, disabling dropout and using population - # statistics for batch normalization. - model.eval() - - # Disable gradient computation and reduce memory consumption. - with torch.no_grad(): - for i, vdata in enumerate(validation_loader): - vinputs, vlabels = vdata - vinputs = vinputs.to('jax') - vlabels = vlabels.to('jax') - voutputs = model(vinputs) # call model's forward - vloss = loss_fn(voutputs, vlabels) - running_vloss += vloss - - avg_vloss = running_vloss / (i + 1) - print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) - - # Log the running loss averaged per batch - # for both training and validation - - # # Track best performance, and save the model's state - # if avg_vloss < best_vloss: - # best_vloss = avg_vloss - # model_path = 'model_{}_{}'.format(timestamp, epoch_number) - # torch.save(model.state_dict(), model_path) - - epoch_number += 1 diff --git a/torchax/examples/basic_training_jax.py b/torchax/examples/basic_training_jax.py deleted file mode 100644 index 96e39359cf7e..000000000000 --- a/torchax/examples/basic_training_jax.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -This is the script from this tutorial: -https://pytorch.org/tutorials/beginner/introyt/trainingyt.html -""" - -import functools -from torchax import train, interop -import torch -from torch.utils import _pytree as pytree -import torchvision -import torchvision.transforms as transforms -import torchax -import torchax.interop -import jax -import optax -import numpy as np - -# PyTorch TensorBoard support -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime - -env = torchax.enable_globally() - -transform = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,))]) - -# Create datasets for training & validation, download if necessary -training_set = torchvision.datasets.FashionMNIST( - './data', train=True, transform=transform, download=True) -validation_set = torchvision.datasets.FashionMNIST( - './data', train=False, transform=transform, download=True) - -# Create data loaders for our datasets; shuffle for training, not for validation -training_loader = torch.utils.data.DataLoader( - training_set, batch_size=4, shuffle=True) -validation_loader = torch.utils.data.DataLoader( - validation_set, batch_size=4, shuffle=False) - -# Class labels -classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', - 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot') - -# Report split sizes -print('Training set has {} instances'.format(len(training_set))) -print('Validation set has {} instances'.format(len(validation_set))) - -import numpy as np -import torch.nn as nn -import torch.nn.functional as F - - -# PyTorch models inherit from torch.nn.Module -class GarmentClassifier(nn.Module): - - def __init__(self): - super(GarmentClassifier, self).__init__() - self.fc1 = nn.Linear(28 * 28, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -model = GarmentClassifier() -loss_fn = torch.nn.CrossEntropyLoss() - -jax_optimizer = optax.adam(0.01) - -model.to('jax') # move the model to jax device -model_jittable = interop.JittableModule(model) -weights = model_jittable.params # these are trainable parameters -buffers = model_jittable.buffers # these are non-trainable parameters - -opt_state = interop.call_jax(jax_optimizer.init, weights) -model_fn = functools.partial(model_jittable.functional_call, 'forward') - -train_step = train.make_train_step(model_fn, loss_fn, jax_optimizer) - -train_step = interop.jax_jit( - train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)}) - -# NB: Loss functions expect data in batches, so we're creating batches of 4 -# Represents the model's confidence in each of the 10 classes for a given input -dummy_inputs = torch.rand(4, 28, 28).to('jax') -dummy_outputs = torch.rand(4, 10).to('jax') -# Represents the correct class among the 10 being tested -dummy_labels = torch.tensor([1, 5, 3, 7]).to('jax') - -# test train_step - - -def train_one_epoch(weights, buffers, opt_state, epoch_index, tb_writer): - running_loss = 0. - last_loss = 0. - - # Here, we use enumerate(training_loader) instead of - # iter(training_loader) so that we can track the batch - # index and do some intra-epoch reporting - for i, data in enumerate(training_loader): - inputs, labels = data - - inputs = inputs.to('jax') - labels = labels.to('jax') - - loss, weights, opt_state = train_step(weights, buffers, opt_state, inputs, - labels) - - # Gather data and report - running_loss += loss.item() - if i % 1000 == 999: - last_loss = running_loss / 1000 # loss per batch - print(' batch {} loss: {}'.format(i + 1, last_loss)) - tb_x = epoch_index * len(training_loader) + i + 1 - tb_writer.add_scalar('Loss/train', last_loss, tb_x) - running_loss = 0. - - return last_loss, weights, opt_state - - -# Initializing in a separate cell so we can easily add more epochs to the same run -timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') -writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) -epoch_number = 0 -EPOCHS = 2 -best_vloss = 1_000_000. - -for epoch in range(EPOCHS): - print('EPOCH {}:'.format(epoch_number + 1)) - - avg_loss, weights, opt_state = train_one_epoch(weights, buffers, opt_state, - epoch_number, writer) - print(avg_loss) diff --git a/torchax/examples/eager_mode.py b/torchax/examples/eager_mode.py deleted file mode 100644 index 946ecce77772..000000000000 --- a/torchax/examples/eager_mode.py +++ /dev/null @@ -1,38 +0,0 @@ -import torchax -from torch import nn -from torch.nn import functional as F -import torch - -xla_env = torchax.enable_globally() - - -class MyModel(nn.Module): - - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -m = MyModel() -m = m.to('jax') - -# Execute this model using torch -inputs = torch.randn(3, 3, 28, 28, device='jax') - -print(m(inputs)) -print('---=====') - -m_compiled = torchax.compile(m) - -print(m_compiled(inputs)) - -print('---') diff --git a/torchax/examples/lightning_training.py b/torchax/examples/lightning_training.py deleted file mode 100644 index 3f6e760bf94b..000000000000 --- a/torchax/examples/lightning_training.py +++ /dev/null @@ -1,82 +0,0 @@ -import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv -import lightning as L - -encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) -decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) - - -class LitAutoEncoder(L.LightningModule): - - def __init__(self, encoder, decoder): - super().__init__() - self.encoder, self.decoder = encoder, decoder - - def training_step(self, batch, batch_idx): - x, y = batch - x = x.view(x.size(0), -1) - z = self.encoder(x) - x_hat = self.decoder(z) - loss = nn.functional.mse_loss(x_hat, x) - self.log("train_loss", loss) - return loss - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=1e-3) - - -dataset = tv.datasets.MNIST( - ".", download=True, transform=tv.transforms.ToTensor()) - -# Lightning will automatically use all available GPUs! -trainer = L.Trainer() -# trainer.fit(LitAutoEncoder(encoder, decoder), data.DataLoader(dataset, batch_size=64)) - -# ==== above is the lightning example from -# https://lightning.ai/pytorch-lightning - -import torchax -from torchax.interop import jax_view, torch_view -import jax -import optax - - -class JaxTrainer: - - def __init__(self): - pass - - def torch_opt_to_jax_opt(self, torch_opt): - # TODO: Can convert optimizer instead of using a jax one - return optax.adam(0.001) - - def fit(self, lightning_mod, data_loader): - - xla_env = torchax.default_env() - - def lightning_mod_loss(weights: jax.Array, data: jax.Array, batch_id): - """returns loss""" - weights, data = torch_view((weights, data)) - lightning_mod.load_state_dict(weights, assign=True) - with xla_env: - loss = lightning_mod.training_step(data, batch_id) - return jax_view(loss) - - jax_weights = jax_view(xla_env.to_xla(lightning_mod.state_dict())) - jax_optimizer = self.torch_opt_to_jax_opt( - lightning_mod.configure_optimizers()) - opt_state = jax_optimizer.init(jax_weights) - grad_fn = jax.jit(jax.value_and_grad(lightning_mod_loss)) - - for bid in range(3): - for item in data_loader: - xla_data = jax_view(xla_env.to_xla(item)) - loss, grads = grad_fn(jax_weights, xla_data, bid) - updates, opt_state = jax_optimizer.update(grads, opt_state) - jax_weights = optax.apply_updates(jax_weights, updates) - print('current_loss', loss) - - -print('-----------------') -trainer_jax = JaxTrainer() -trainer_jax.fit( - LitAutoEncoder(encoder, decoder), data.DataLoader(dataset, batch_size=64)) diff --git a/torchax/examples/requirements.txt b/torchax/examples/requirements.txt deleted file mode 100644 index 69e01ff3dd07..000000000000 --- a/torchax/examples/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -torchvision -matplotlib -optax \ No newline at end of file diff --git a/torchax/examples/train_llama/README.md b/torchax/examples/train_llama/README.md deleted file mode 100644 index eb4d4e3a9c24..000000000000 --- a/torchax/examples/train_llama/README.md +++ /dev/null @@ -1,194 +0,0 @@ -Lightning Based training for llama 3 -==================================== - -# Abstract: -We train llama3 model using a the PyTorch implementation from [litgpt](). -We train it using a pytorch-lightning like setup, where user defines a `training_step` methods. The trainer itself is a custom trainer that works on CloudTPU by leverying jax. - -## Result: - -* Best v5p MFU: 63.7% (batch = 8, sequence length = 8192) -* Best v4-8 MFU: 44% (batch = 8, sequence length = 2048) - -# Setup: - -```bash -pip install 'litgpt[all]' optax fire - -litgpt download meta-llama/Meta-Llama-3-8B-Instruct -litgpt download --repo_id meta-llama/Meta-Llama-3-8B-Instruct --tokenizer_only true --access_token -``` - -Then, run with -```bash -python -m examples.train_llama.train_llama_lightning --mode=all --seqlen=2048 --checkpoint_dir= -``` - -# The training script - -The script in [train_llama_lightning.py](train_llama_training.py) is envisioned -to be what the users need to write for their training setup. In summary it -consists the following: - -### 1. Put the model in a `LightningModule` subclass: - -```python -class GPTLightningModule(lightning.LightningModule): - - def __init__(self, gpt): - super().__init__() - self.gpt = utils.FSDPv2(gpt) - - def training_step(self, batch, batch_idx): - x, y = batch - logits = self.gpt.forward(x) - num_tokens = logits.shape[-1] - logits = logits[..., :-1, :].reshape(-1, num_tokens) - y = y[..., 1:].reshape(-1) - return torch.nn.functional.cross_entropy( - logits, y) - - def configure_optimizers(self): - return None -``` - -This class is responsible of wrapping (or instantiating) a model, call it's forward, and defining the formula to compute loss. - -Next, the user need to call our trainer along with the dataloader: - -```python -def main(): - gpt = ... - light_mod = GPTLightningModule(gpt) - # data loader setup and stuff skipped - train_loader = ... - - # Train - trainer = ... - trainer.fit(light_mod, train_loader) -``` - -The actual script in train_llama_lightning.py is more complex because -we are testing out different options and optimizing strategies. - -The trainer itself, as well as the helper class for sharding strategy (FSDPv2), -is defined in [utils.py](utils.py). In the future, we hope to upstream these -into pytorch-lightning and becames one of the `Strategy` that `pl.Trainer` uses. - -### FSDPv2 - -FSDPv2 is an implementation of Fully-sharded Data Parallel training strategy using -GSPMD. To implement this, we need 2 things: - -1. Shard inputs on batch dimension (i.e. like DDP) -2. Shard all the weights in the first dimension. - -To implement this, we create a mesh with first axis called 'fsdp' and shard -everything on this. - -```python -class FSDPv2(torch.nn.Module): - - def __init__(self, mod): - super().__init__() - self.mod = mod - - num_of_partitions = jax.device_count() - self.mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh((num_of_partitions, )), - axis_names=("fsdp", ), - ) - self.sharding = jax.sharding.NamedSharding(self.mesh, P("fsdp")) - - def forward(self, *args): - args = list(args) - args[0] = self.shard(args[0]) - res = self.mod(*args) - return self.shard(res) - - def shard(self, x): - return torchax.interop.call_jax( - jax.lax.with_sharding_constraint, - x, - self.sharding, - ) -``` -We also need a similar function that shards the weights. - -### Flash attention - -Flash attention is a important optimization that enables training with large -sequence length. Jax has an implementation of flash attention located in -`jax.experimental.pallas.ops.tpu.flash_attention`. To make the model uses -this version of flash attention, we simply register a lowering for PyTorch's -`torch.nn.functional.scaled_dot_product_attention` like so: - -```python -@register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True) -def scaled_dot_product_attention( - query, key, value, attn_mask=None, - dropout_p=0.0, is_causal=False, scale=None, env=None) -> torch.Tensor: - - if env.use_flash_attention: - jquery, jkey, jvalue = env.t2j_iso((query, key, value)) - res = _tpu_flash_attention(jquery, jkey, jvalue, env) - return env.j2t_iso(res) - - return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale) -``` - -this implementation is located in [jtorch.py](../../torchax/ops/jtorch.py) in -torchax. The model itself does not need to change to use TPU version of -flash attention, because it's calling pytorch's `F.scaled_dot_product_attention`. - -## Misc optimizations - - -### Compile one layer first - -A program compiled with `jax.jit` is a straight graph of XLA operators (StableHLO ops). For the llama3 model, it consists of 32 layers of identical code. This makes compile time extremely long. We can outline one of the layers, and call that one repeatedly to get slightly faster compile time. - -To compile the regular 32-layer model, it takes 210s on v4-8; with outlining we -can reduce this to 190s. And program size (number of chars in `jax.jit(...).lowered.as_text()`) is reduced from 2.6 million to 1.9 million. - -To accomplish this, we can wrap the original model with the `GPTOutline` wrapper. - -### Use `jax.lax.scan` to iterate the layers - -Another more intrusive change is to wrap to change the loop to use -scan instead of python loop. This way we can get an even smaller program. -This change is illustrated in `GPTFori` wrapper. -With this change, we can shrink the program size to 0.49 million characters, and -compile time to 19.8s. - -Scan makes compiling faster but makes runtime slightly slower: XLA's ability -of optimizing across control flow boundaries is less than it's ability to optimize -on straight graph, so we lose a bit of runtime perf. - - -## Detailed numbers - -### v5p-8 - -seqlen = 8192 -bs = 8 - -| Batch Size | Sequence Length | Mode | Step Time (s) | Compile Time (s) | MFU | -|:-------------|:------------------|:-----------------------|:------------|:---------------|:-------------| -| 8 | 8192 | scan_layer | 4.38044 | 12.22 | 49.76% | -| 8 | 8192 | scan_manual | 4.32437 | 12.95 | 50.41% | -| 8 | 8192 | regular | 4.56214 | 1086.71 | 47.78% | -| 8 | 8192 | outlined | 3.41887 | 1079.77 | 63.76% | - - -### v4-8 - -seqlen = 2048 -bs = 8 - -| Batch Size | Sequence Length | Mode | Step Time (s) | Compile Time (s) | MFU | -|:-------------|:------------------|:------------|:-------------|:---------------|:------| -| 8 | 2048 | scan_layer | 1.80099 | 17.61 | 42% | -| 8 | 2048 | scan_manual | 1.85214 | 16.69 | 41% | -| 8 | 2048 | regular | 1.70979 | 362.32 | 44% | -| 8 | 2048 | outlined | OOM | - | - | diff --git a/torchax/examples/train_llama/__init__.py b/torchax/examples/train_llama/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/examples/train_llama/model.py b/torchax/examples/train_llama/model.py deleted file mode 100644 index 2cc545fd6662..000000000000 --- a/torchax/examples/train_llama/model.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. -"""Full definition of a decoder-only transformer-based language model, all of it in this single file. - -Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and -https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. -""" -import jax -from jax.sharding import PartitionSpec as P -import math -from typing import Any, Optional, Tuple - -import torch -import torch.nn as nn -from typing_extensions import Self - -from litgpt.config import Config - - -def reapply_sharding(x): - x._elem = jax.lax.with_sharding_constraint(x._elem, P('fsdp')) - - -class GPT(nn.Module): - - def __init__(self, config: Config, use_fori_loop=False) -> None: - super().__init__() - assert config.padded_vocab_size is not None - self.config = config - - self.lm_head = nn.Linear( - config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) - self.transformer = nn.ModuleDict( - dict( - wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), - ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), - )) - self.max_seq_length = self.config.block_size - self.mask_cache: Optional[torch.Tensor] = None - - @property - def max_seq_length(self) -> int: - return self._max_seq_length - - @max_seq_length.setter - def max_seq_length(self, value: int) -> None: - """ - When doing inference, the sequences used might be shorter than the model's context length. - This allows setting a smaller number to avoid allocating unused memory - """ - if value > self.config.block_size: - raise ValueError( - f"Cannot attend to {value}, block size is only {self.config.block_size}" - ) - self._max_seq_length = value - if not hasattr(self, "cos"): - # first call - cos, sin = self.rope_cache() - self.register_buffer("cos", cos, persistent=False) - self.register_buffer("sin", sin, persistent=False) - # override - elif value != self.cos.size(0): - self.cos, self.sin = self.rope_cache(device=self.cos.device) - # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know - # if the kv cache is expected - - def reset_parameters(self) -> None: - # Trigger resetting the rope-cache - self.cos, self.sin = self.rope_cache(device=self.cos.device) - - def _init_weights(self, module: nn.Module) -> None: - """Meant to be used with `gpt.apply(gpt._init_weights)`.""" - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def forward(self, - idx: torch.Tensor, - input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: - T = idx.size(1) - if self.max_seq_length < T: - raise ValueError( - f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." - ) - - if input_pos is not None: # use the kv cache - cos = self.cos.index_select(0, input_pos) - sin = self.sin.index_select(0, input_pos) - if self.mask_cache is None: - raise TypeError("You need to call `gpt.set_kv_cache()`") - mask = self.mask_cache.index_select(2, input_pos) - else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - reapply_sharding(x) - if self.config.scale_embeddings: - x = x * (self.config.n_embd**0.5) - - for block in self.transformer.h: - x = block(x, cos, sin, mask, input_pos) - reapply_sharding(x) - x = self.transformer.ln_f(x) - res = self.lm_head(x) # (b, t, vocab_size) - reapply_sharding(res) - return res - - @classmethod - def from_name(cls, name: str, **kwargs: Any) -> Self: - return cls(Config.from_name(name, **kwargs)) - - def rope_cache( - self, - device: Optional[torch.device] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - return build_rope_cache( - seq_len=self.max_seq_length, - n_elem=self.config.rope_n_elem, - device=device, - condense_ratio=self.config.rope_condense_ratio, - base=self.config.rope_base, - ) - - def set_kv_cache( - self, - batch_size: int, - rope_cache_length: Optional[int] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - if rope_cache_length is None: - rope_cache_length = self.cos.size(-1) - max_seq_length = self.max_seq_length - - # initialize the kv cache for all blocks - for block in self.transformer.h: - block.attn.kv_cache = block.attn.build_kv_cache(batch_size, - max_seq_length, - rope_cache_length, device, - dtype) - - if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: - # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask - # for the kv-cache support (only during inference), we only create it in that situation - self.mask_cache = build_mask_cache(max_seq_length, device) - - def clear_kv_cache(self) -> None: - self.mask_cache = None - for block in self.transformer.h: - block.attn.kv_cache = None - - -class Block(nn.Module): - - def __init__(self, config: Config) -> None: - super().__init__() - if not config.parallel_residual and config.shared_attention_norm: - raise NotImplementedError( - "No checkpoint amongst the ones we support uses this configuration" - " (non-parallel residual and shared attention norm).") - - self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) - self.attn = CausalSelfAttention(config) - self.norm_2 = None if config.shared_attention_norm else config.norm_class( - config.n_embd, eps=config.norm_eps) - self.mlp = config.mlp_class(config) - - self.config = config - - def forward( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Non-parallel residual Parallel residual - ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True, - │ ↓ │ ↓ ↓ the output from `norm_1` is reused - │ norm_1 │ norm_1 ───► norm_2 - │ ↓ │ ↓ ↓ - │ attn │ attn mlp - │ ↓ │ ↓ │ - ┌─ └► + └► + ◄───────────┘ - │ norm_2 - │ ↓ - │ mlp - │ ↓ - └───► + - """ - - x_normed = self.norm_1(x) - attention_output = self.attn(x_normed, cos, sin, mask, input_pos) - - if self.config.parallel_residual: - x_normed = x_normed if self.config.shared_attention_norm else self.norm_2( - x) - x = self.mlp(x_normed) + attention_output + x - else: - x = attention_output + x - x = self.mlp(self.norm_2(x)) + x - return x - - -class CausalSelfAttention(nn.Module): - - def __init__(self, config: Config) -> None: - super().__init__() - shape = (config.n_head + 2 * config.n_query_groups) * config.head_size - # key, query, value projections for all heads, but in a batch - self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) - # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` - self.proj = nn.Linear( - config.head_size * config.n_head, config.n_embd, bias=config.bias) - # disabled by default - self.kv_cache: Optional[KVCache] = None - - self.config = config - - def forward( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - B, T, C = x.size( - ) # batch size, sequence length, embedding dimensionality (n_embd) - - qkv = self.attn(x) - qkv._elem = jax.lax.with_sharding_constraint(qkv._elem, P('fsdp')) - - # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) - q_per_kv = self.config.n_head // self.config.n_query_groups - total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value - qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, - self.config.head_size) - qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - - # split batched computation into three - q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - - # maybe repeat k and v if for the non multi-head attention cases - # training: flash attention requires it - # inference: multi-query would require a full kv cache so avoid it to limit its memory usage - if self.config.n_query_groups != self.config.n_head and ( - input_pos is None or self.config.n_query_groups != 1): - k = k.expand(B, self.config.n_query_groups, q_per_kv, T, - self.config.head_size) - v = v.expand(B, self.config.n_query_groups, q_per_kv, T, - self.config.head_size) - - q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - - q_roped = apply_rope(q[..., :self.config.rope_n_elem], cos, sin) - k_roped = apply_rope(k[..., :self.config.rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., self.config.rope_n_elem:]), dim=-1) - k = torch.cat((k_roped, k[..., self.config.rope_n_elem:]), dim=-1) - - if input_pos is not None: - if not isinstance(self.kv_cache, KVCache): - raise TypeError("You need to call `gpt.set_kv_cache()`") - k, v = self.kv_cache(input_pos, k, v) - - y = self.scaled_dot_product_attention(q, k, v, mask) - - y = y.reshape( - B, T, self.config.head_size * - self.config.n_head) # re-assemble all head outputs side by side - - # output projection - return self.proj(y) - - def scaled_dot_product_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - scale = 1.0 / math.sqrt(self.config.head_size) - y = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=mask, - dropout_p=0.0, - scale=scale, - is_causal=mask is None) - return y.transpose(1, 2) - - def build_kv_cache( - self, - batch_size: int, - max_seq_length: int, - rope_cache_length: Optional[int] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> "KVCache": - heads = 1 if self.config.n_query_groups == 1 else self.config.n_head - v_shape = (batch_size, heads, max_seq_length, self.config.head_size) - if rope_cache_length is None: - if self.config.rotary_percentage != 1.0: - raise TypeError( - "Please pass the `rope_cache_length=gpt.cos.size(-1)` value") - k_shape = v_shape - else: - k_shape = ( - batch_size, - heads, - max_seq_length, - rope_cache_length + self.config.head_size - self.config.rope_n_elem, - ) - return KVCache(k_shape, v_shape, device=device, dtype=dtype) - - -class GptNeoxMLP(nn.Module): - - def __init__(self, config: Config) -> None: - super().__init__() - self.fc = nn.Linear( - config.n_embd, config.intermediate_size, bias=config.bias) - self.proj = nn.Linear( - config.intermediate_size, config.n_embd, bias=config.bias) - - self.config = config - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc(x) - x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate) - return self.proj(x) - - -class LLaMAMLP(nn.Module): - - def __init__(self, config: Config) -> None: - super().__init__() - self.fc_1 = nn.Linear( - config.n_embd, config.intermediate_size, bias=config.bias) - self.fc_2 = nn.Linear( - config.n_embd, config.intermediate_size, bias=config.bias) - self.proj = nn.Linear( - config.intermediate_size, config.n_embd, bias=config.bias) - - self.config = config - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_fc_1 = self.fc_1(x) - x_fc_2 = self.fc_2(x) - x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - return self.proj(x) - - -class GemmaMLP(LLaMAMLP): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_fc_1 = self.fc_1(x) - x_fc_2 = self.fc_2(x) - x = torch.nn.functional.gelu( - x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 - return self.proj(x) - - -class LLaMAMoE(nn.Module): - - def __init__(self, config: Config) -> None: - super().__init__() - self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False) - self.experts = nn.ModuleList( - LLaMAMLP(config) for _ in range(config.n_expert)) - - self.config = config - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 - See also figure 1 in https://arxiv.org/abs/2211.15841 - """ - B, T, C = x.size( - ) # batch size, sequence length, embedding dimensionality (n_embd) - x = x.view(-1, C) # (B*T, C) - router = self.gate(x) # (B*T, n_expert) - probs, indices = torch.topk( - router, self.config.n_expert_per_token) # (B*T, n_expert_per_token) - probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype) - masks = indices.unsqueeze(-1) == torch.arange( - self.config.n_expert, device=x.device) - masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token) - y = torch.zeros_like(x) # (B*T, C) - for mask, expert in zip(masks, self.experts): - token_idx, expert_idx = torch.where(mask) - y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx]) - return y.view(B, T, C) - - -def build_rope_cache( - seq_len: int, - n_elem: int, - device: Optional[torch.device] = None, - base: int = 10000, - condense_ratio: int = 1) -> Tuple[torch.Tensor, torch.Tensor]: - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / ( - base**(torch.arange(0, n_elem, 2, device=device).float() / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, device=device) / condense_ratio - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) - - return torch.cos(idx_theta), torch.sin(idx_theta) - - -def apply_rope(x: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor) -> torch.Tensor: - head_size = x.size(-1) - x1 = x[..., :head_size // 2] # (B, nh, T, hs/2) - x2 = x[..., head_size // 2:] # (B, nh, T, hs/2) - rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - roped = (x * cos) + (rotated * sin) - return roped.to(dtype=x.dtype) - - -class KVCache(nn.Module): - - def __init__( - self, - k_shape: Tuple[int, int, int, int], - v_shape: Tuple[int, int, int, int], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__() - self.register_buffer( - "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) - self.register_buffer( - "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) - - def forward(self, input_pos: torch.Tensor, k: torch.Tensor, - v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - # move the buffer to the activation dtype for when AMP is used - self.k = self.k.to(k.dtype) - self.v = self.v.to(v.dtype) - # update the cache - k = self.k.index_copy_(2, input_pos, k) - v = self.v.index_copy_(2, input_pos, v) - return k, v - - def reset_parameters(self) -> None: - torch.nn.init.zeros_(self.k) - torch.nn.init.zeros_(self.v) - - -def build_mask_cache(max_seq_length: int, - device: Optional[torch.device] = None) -> torch.Tensor: - ones = torch.ones((max_seq_length, max_seq_length), - device=device, - dtype=torch.bool) - return torch.tril(ones).unsqueeze(0).unsqueeze(0) - - -class RMSNorm(torch.nn.Module): - """Root Mean Square Layer Normalization. - - Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: - https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. - """ - - def __init__(self, - size: int, - dim: int = -1, - eps: float = 1e-6, - add_unit_offset: bool = False) -> None: - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(size)) - self.eps = eps - self.dim = dim - self.add_unit_offset = add_unit_offset - - def forward(self, x: torch.Tensor) -> torch.Tensor: - dtype = x.dtype - x = x.float() - # NOTE: the original RMSNorm paper implementation is not equivalent - norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - x_normed = x * torch.rsqrt(norm_x + self.eps) - x_normed = x_normed.to(dtype=dtype) - if self.add_unit_offset: - # Gemma model requires a unit offset - # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176 - return x_normed * (1 + self.weight) - return x_normed * self.weight - - def reset_parameters(self) -> None: - torch.nn.init.ones_(self.weight) diff --git a/torchax/examples/train_llama/train_llama_lightning.py b/torchax/examples/train_llama/train_llama_lightning.py deleted file mode 100644 index 2cc35e7cfeed..000000000000 --- a/torchax/examples/train_llama/train_llama_lightning.py +++ /dev/null @@ -1,307 +0,0 @@ -import jax -import jax.numpy as jnp -from litgpt import config -from litgpt import model -from litgpt.data import Alpaca -from litgpt.tokenizer import Tokenizer -import lightning -import torch -from collections import defaultdict -from jax.experimental import shard_map - -import torch.nn.functional -import torchax.interop - -from . import utils -from . import model as editted_model -import os - - -def _setup_default_env(): - os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') - os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') - os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') - # only need for tpu v4 - # os.environ.setdefault('TPU_MEGACORE', 'megacore_dense') - tpu_args = "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" - - os.environ.setdefault('LIBTPU_INIT_ARGS', tpu_args) - - -_setup_default_env() - -default_checkpoint_dir = '/home/hanq/litgpt/checkpoints/meta-llama/Meta-Llama-3-8B/' - - -class GPTLightningModule(lightning.LightningModule): - - def __init__(self, gpt): - super().__init__() - self.gpt = utils.FSDPv2(gpt) - - def training_step(self, batch, batch_idx): - x, y = batch - logits = self.gpt.forward(x) - num_tokens = logits.shape[-1] - logits = logits[..., :-1, :].reshape(-1, num_tokens) - y = y[..., 1:].reshape(-1) - return torch.nn.functional.cross_entropy(logits, y) - - def configure_optimizers(self): - return None - - -from jax.experimental import mesh_utils - -P = jax.sharding.PartitionSpec -mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(utils.num_partitions), - axis_names=utils.global_axis, -) - - -class GPTOutline(torch.nn.Module): - - def __init__(self, gpt_orig): - super().__init__() - self.gpt_orig = gpt_orig - - def one_layer(weights, args): - return torch.func.functional_call(self.gpt_orig.transformer.h[0], weights, - args) - - self.one_layer = torchax.interop.jax_jit(one_layer) - - def forward(self, idx: torch.Tensor, input_pos=None) -> torch.Tensor: - T = idx.size(1) - if self.gpt_orig.max_seq_length < T: - raise ValueError( - f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." - ) - - cos = self.gpt_orig.cos[:T] - sin = self.gpt_orig.sin[:T] - mask = None - - x = self.gpt_orig.transformer.wte( - idx) # token embeddings of shape (b, t, n_embd) - editted_model.reapply_sharding(x) - if self.gpt_orig.config.scale_embeddings: - x = x * (self.gpt_orig.config.n_embd**0.5) - - for block in self.gpt_orig.transformer.h: - args = (x, cos, sin, mask, input_pos) - weights = block.state_dict() - x = self.one_layer(weights, args) - editted_model.reapply_sharding(x) - - x = self.gpt_orig.transformer.ln_f(x) - editted_model.reapply_sharding(x) - res = self.gpt_orig.lm_head(x) # (b, t, vocab_size) - editted_model.reapply_sharding(res) - return res - - -class GPTFori: - - def __init__(self, gpt_orig, manual_all_gather=False): - super().__init__() - self.gpt_orig = gpt_orig - - one_block = self.gpt_orig.transformer.h[0] - self.manual_all_gather = manual_all_gather - - def one_layer(args, weights): - # inputs are jax array - orig_args = args - - x, cos, sin, mask, input_pos = args - if self.manual_all_gather: - weights, cos, sin = jax.lax.all_gather((weights, cos, sin), - 'fsdp', - tiled=True) - args = (x, cos, sin, mask, input_pos) - args, weights = torchax.default_env().j2t_iso((args, weights)) - res = torch.func.functional_call(one_block, weights, args) - res = torchax.default_env().t2j_iso(res) - return (res, *orig_args[1:]), jnp.array([0]) - - if self.manual_all_gather: - one_layer = shard_map.shard_map( - one_layer, - mesh=mesh, - in_specs=(P(*utils.global_axis), P(*utils.global_axis)), - out_specs=(P(*utils.global_axis), P()), - check_rep=False) - - one_layer = jax.checkpoint( - one_layer, - policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable) - - def compiled_block(weights, x): - x, _ = jax.lax.scan(one_layer, x, weights, unroll=4) - return x[0] - - self.compiled_block = compiled_block - - self.weights = { - 'wte': self.gpt_orig.transformer.wte.state_dict(), - 'block': self.make_weights_scan(), - 'ln_f': self.gpt_orig.transformer.ln_f.state_dict(), - 'lm_head': self.gpt_orig.lm_head.state_dict(), - 'sin': self.gpt_orig.sin, - 'cos': self.gpt_orig.cos, - } - - def make_weights_scan(self): - temp = defaultdict(list) # key to list of tensors - for block in self.gpt_orig.transformer.h: - state_dict = block.state_dict() - for k, v in state_dict.items(): - temp[k].append(v) - - temp = {k: torch.stack(v) for k, v in temp.items()} - return temp - - def forward_with_weights(self, weights, idx): - T = idx.size(1) - if self.gpt_orig.max_seq_length < T: - raise ValueError( - f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." - ) - - cos = weights['cos'][:T] - sin = weights['sin'][:T] - mask = None - - x = torch.func.functional_call( - self.gpt_orig.transformer.wte, - weights['wte'], - idx, - ) - - editted_model.reapply_sharding(x) - - if self.gpt_orig.config.scale_embeddings: - x = x * (self.config.n_embd**0.5) - editted_model.reapply_sharding(x) - - args = (x, cos, sin, mask, None) - #import pdb; pdb.set_trace() - x = torchax.interop.call_jax( - self.compiled_block, - weights['block'], - args, - ) - editted_model.reapply_sharding(x) - - x = torch.func.functional_call(self.gpt_orig.transformer.ln_f, - weights['ln_f'], x) - editted_model.reapply_sharding(x) - x = torch.func.functional_call(self.gpt_orig.lm_head, weights['lm_head'], x) - editted_model.reapply_sharding(x) - return x - - -import logging -import torchax - -# Modes: - -REGULAR = 'regular' -OUTLINED_LAYER = 'outlined' # jax.jit a block to get faster compilation -SCAN_LAYER = 'scan_layer' # jax.lax.scan for looping layers to get faster compilation -SCAN_LAYER_MANUAL = 'scan_manual' # jax.lax.scan AND shmap layers to get faster compilation - - -def main_one( - use_flash_attention=True, - seqlen=8196, - n_layers=32, - batch_size=8, - checkpoint_dir=default_checkpoint_dir, - mode='regular', - use_editted_model=False, -): - logging.getLogger("jax").setLevel(logging.DEBUG) - print(f"Running with parameters {locals()}") - utils.SEQLEN = seqlen - utils.BATCH = batch_size - env = torchax.default_env() - env.config.use_tpu_flash_attention = use_flash_attention - cfg = config.Config.from_name("Meta-Llama-3-8B") - cfg.n_layer = n_layers - #cfg.n_layer = 32 - if use_editted_model: - gpt = editted_model.GPT(cfg) - else: - gpt = model.GPT(cfg) - gpt.to(torch.bfloat16) - - env.config.shmap_flash_attention = mode != SCAN_LAYER_MANUAL - use_fori = False - - if mode in (SCAN_LAYER, SCAN_LAYER_MANUAL): - gpt = GPTFori(gpt, mode == SCAN_LAYER_MANUAL) - use_fori = True - elif mode == OUTLINED_LAYER: - gpt = GPTOutline(gpt) - - light_mod = GPTLightningModule(gpt) - tokenizer = Tokenizer(checkpoint_dir) - data = Alpaca(num_workers=1) - data.connect( - tokenizer=tokenizer, batch_size=batch_size, max_seq_length=utils.SEQLEN) - data.prepare_data() - data.setup() - train_loader = data.train_dataloader() - - with mesh: - trainer = utils.JaxTrainer(use_fori) - if use_fori: - return trainer.fit_model_fori(gpt, train_loader) - else: - return trainer.fit(light_mod, train_loader) - - -def main( - use_flash_attention=True, - seqlen=8196, - n_layers=32, - batch_size=8, - checkpoint_dir=default_checkpoint_dir, - mode='regular', - use_editted_model=False, -): - if mode == 'all': - from jaxlib.xla_extension import XlaRuntimeError - res = [] - for m, editted in ((SCAN_LAYER, False), (SCAN_LAYER_MANUAL, False), - (REGULAR, False), (REGULAR, True), (OUTLINED_LAYER, - True)): - try: - run_time, comp_time = main_one( - use_flash_attention, - seqlen, - n_layers, - batch_size, - checkpoint_dir, - m, - use_editted_model=editted, - ) - res.append((m, editted, run_time, comp_time)) - except XlaRuntimeError as e: - import traceback - traceback.print_exc() - res.append((m, editted, 'OOM', '')) - for m, e, r, c in res: - print(f'{m}-edit={e}: \t {r} \t {c} ') - - else: - main_one(use_flash_attention, seqlen, n_layers, batch_size, checkpoint_dir, - mode, use_editted_model) - - -if __name__ == '__main__': - import fire - fire.Fire(main) diff --git a/torchax/examples/train_llama/utils.py b/torchax/examples/train_llama/utils.py deleted file mode 100644 index 2d6daa74a388..000000000000 --- a/torchax/examples/train_llama/utils.py +++ /dev/null @@ -1,318 +0,0 @@ -from typing import Tuple -import time -import torchax -from torchax.interop import jax_view, torch_view, JittableModule -import jax -import optax -import jax -import jax.numpy as jnp -from jax.experimental import mesh_utils -import functools -from torch.utils import _pytree as pytree - -Mesh = jax.sharding.Mesh -P = jax.sharding.PartitionSpec - -SEQLEN = 8192 -BATCH = 8 -global_axis: Tuple[str, str] = ('fsdp',) -num_global_devices = jax.device_count() -num_local_devices = jax.local_device_count() -num_partitions = (num_global_devices,) -#SEQLEN = 512 - -import torch - - -def group_data(dataloader, block_size): - """yields tuple of inputs, label with seqlen == block_size""" - - tally = 0 - inputs = [] - labels = [] - - for line in dataloader: - x, y = line['input_ids'], line['labels'] - inputs.append(x) - labels.append(y) - batch, seqlen = x.shape - tally += seqlen - if tally > block_size: - inputs_stacked = torch.concat(inputs, dim=-1) - inputs_stacked.resize_((batch, block_size)) - labels_stacked = torch.concat(labels, dim=-1) - labels_stacked.resize_((batch, block_size)) - yield inputs_stacked, labels_stacked - tally = 0 - inputs = [] - labels = [] - - -def sharded_device_put(tensor, sharding): - if isinstance(tensor, tuple): - return tuple(sharded_device_put(t, sharding) for t in tensor) - - if num_global_devices == num_local_devices: - return jax.device_put(tensor, sharding) - - shape = tensor.shape - x_split = [ - jax.device_put(tensor[i], device) - for device, i in sharding.addressable_devices_indices_map(shape).items() - ] - return jax.make_array_from_single_device_arrays(shape, sharding, x_split) - - -class FSDPv2(torch.nn.Module): - - def __init__(self, mod): - super().__init__() - self.mod = mod - self.mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(num_partitions), - axis_names=global_axis, - ) - self.sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis)) - - def forward(self, *args): - args = list(args) - args[0] = self.shard(args[0]) - res = self.mod(*args) - return self.shard(res) - - def shard(self, x): - return torchax.interop.call_jax( - jax.lax.with_sharding_constraint, - x, - self.sharding, - ) - - -def print_shapes(pyt): - for p in pytree.tree_flatten(pyt)[0]: - if hasattr(p, 'shape'): - print(p.shape, p.dtype) - - -class JaxTrainer: - - def __init__(self, use_fori): - self.use_fori = use_fori - self.mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(num_partitions), - axis_names=global_axis, - ) - self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(global_axis)) - self.y_sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis)) - self.replicated = jax.sharding.NamedSharding(self.mesh, P()) - - def torch_opt_to_jax_opt(self, torch_opt): - # TODO: Can convert optimizer instead of using a jax one - return optax.adamw(0.01) - - def fit_model_fori(self, gpt_mod, data_loader): - xla_env = torchax.default_env() - jax.config.update('jax_enable_x64', False) - xla_env._mesh = self.mesh - xla_env.use_flash_attention = True - - weights = gpt_mod.weights - - jax_params = {} - for k, v in weights.items(): - sharding = self.y_sharding if k == 'block' else self.x_sharding - print(k, sharding) - jax_params[k] = self._shard_fsdp_style(v, sharding) - - print('ALL weights ===') - for x in jax.tree_util.tree_flatten(jax_params)[0]: - print(x.shape, x.sharding) - print(' ===') - - @jax.checkpoint - def loss(jax_params, data): - data = jax.lax.with_sharding_constraint(data, self.x_sharding) # fsdpv2 - x, y = data - res = torchax.interop.call_torch(gpt_mod.forward_with_weights, jax_params, - x) - res = jax.lax.with_sharding_constraint(res, self.x_sharding) - return jnp.mean( - optax.losses.softmax_cross_entropy_with_integer_labels(res, y)) - - grad_fn = jax.value_and_grad(loss) - jax_optimizer = optax.adamw(0.01) - opt_state = jax_optimizer.init(jax_params) - - @functools.partial(jax.jit, donate_argnums=(0, 1)) - def step(jax_weights, opt_state, data): - with jax.named_scope('compute_gradient'): - loss, gradient = grad_fn(jax_weights, data) - with jax.named_scope("optimizer_updates"): - updates, opt_state = jax_optimizer.update(gradient, opt_state, - jax_weights) - jax_weights = optax.apply_updates(jax_weights, updates) - return loss, jax_weights, opt_state - - print('Start compiling') - start = time.perf_counter() - lowered = step.lower( - jax_params, - opt_state, - (jax.ShapeDtypeStruct( - (BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding), - jax.ShapeDtypeStruct( - (BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)), - ) - # print(lowered.as_text()) - print('program size:', len(lowered.as_text()) / 1e6, 'm chars') - step_compiled = lowered.compile() - end = time.perf_counter() - print('End compiling', end - start) - compile_time = end - start - - for co in step_compiled.cost_analysis(): - print('flops counter:', co['flops']) - - s = time.perf_counter() - jax.profiler.start_trace('/tmp/tensorboard') - print('start training') - min_loop_time = 10000 - for i, item in enumerate(group_data(data_loader, SEQLEN)): - inputs, labels = sharded_device_put( - jax_view(xla_env.to_xla(item)), self.x_sharding) - print('INPUT shape', inputs.shape) - - step_start = time.perf_counter() - loss, jax_params, opt_state = step_compiled(jax_params, opt_state, - (inputs, labels)) - jax.block_until_ready((loss, jax_params)) - step_end = time.perf_counter() - print(i, 'loss', loss, 'step latency: ', step_end - step_start) - min_loop_time = min(min_loop_time, step_end - step_start) - print('======') - if i >= 3: - break - jax.profiler.stop_trace() - return min_loop_time, compile_time - - def _shard_fsdp_style(self, state_dict, sharding=None): - if sharding is None: - sharding = self.x_sharding - - def move_one_tensor(x): - env = torchax.default_env() - jval = env.t2j_copy(x) - return sharded_device_put(jval, sharding) - - if isinstance(state_dict, torch.Tensor): - return move_one_tensor(state_dict) - res = {} - for k, v in sorted(state_dict.items()): - res[k] = move_one_tensor(v) - return res - - def fit(self, lightning_mod, data_loader): - - xla_env = torchax.default_env() - jax.config.update('jax_enable_x64', False) - xla_env._mesh = self.mesh - xla_env.use_flash_attention = True - - jittable_mod = JittableModule(lightning_mod) - jax_params = self._shard_fsdp_style(jittable_mod.params) - jax_buffers = self._shard_fsdp_style(jittable_mod.buffers) - - @jax.checkpoint - def lightning_mod_loss(weights: jax.Array, buffers: jax.Array, - data: jax.Array, batch_id): - """returns loss""" - with jax.named_scope("Computing_loss"): - weights, buffers, data = torch_view((weights, buffers, data)) - # NOTE: these is needed because the original model - # did not register those as persistent buffer - with xla_env: - loss = jittable_mod.functional_call('training_step', weights, buffers, - data, batch_id) - return jax_view(loss) - - jax_optimizer = self.torch_opt_to_jax_opt( - lightning_mod.configure_optimizers()) - - opt_state = jax_optimizer.init(jax_params) - grad_fn = jax.value_and_grad(lightning_mod_loss) - - opt_state_sharding = jax.tree_util.tree_map(lambda p: p.sharding, opt_state) - - print('Begining training') - - # NOTE: explicitly set sharding so the sharding of opt_state wont change - # if it changes, it would trigger recompile - @functools.partial( - jax.jit, - donate_argnums=(0, 2), - #in_shardings=(self.x_sharding, self.x_sharding, opt_state_sharding, self.x_sharding, self.replicated), - #out_shardings=(self.replicated, self.x_sharding, opt_state_sharding), - ) - def step(jax_weights, jax_buffers, optimizer_state, xla_data, bid): - print('Tracing inside of step') - with jax.named_scope("Computing_loss_and_grad"): - loss, grads = grad_fn(jax_weights, jax_buffers, xla_data, bid) - with jax.named_scope("optimizer_updates"): - updates, opt_state = jax_optimizer.update(grads, optimizer_state, - jax_weights) - jax_weights = optax.apply_updates(jax_weights, updates) - return loss, jax_weights, opt_state - - total_param_size = 0 - for k, v in jax_params.items(): - total_param_size += v.size - - print('Total number of params: ', total_param_size) - # print(jax.jit(jax.grad(lightning_mod_loss)).lower( - # jax_params, jax_buffers, - # (jax.ShapeDtypeStruct((8, SEQLEN), jnp.dtype('int32')), - # jax.ShapeDtypeStruct((8, SEQLEN), jnp.dtype('int32'))), - # 0 - # ).as_text()) - - print('Start compiling') - start = time.perf_counter() - lowered = step.lower( - jax_params, jax_buffers, opt_state, - (jax.ShapeDtypeStruct( - (8, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding), - jax.ShapeDtypeStruct( - (8, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)), 0) - # print(lowered.as_text()) - print('program size:', len(lowered.as_text()) / 1e6, 'm chars') - step_compiled = lowered.compile() - end = time.perf_counter() - compile_time = end - start - print('End compiling', compile_time) - - for co in step_compiled.cost_analysis(): - print('flops counter:', co['flops']) - - s = time.perf_counter() - jax.profiler.start_trace('/tmp/tensorboard') - print('start training') - min_loop_time = 10000 - for i, item in enumerate(group_data(data_loader, SEQLEN)): - inputs, labels = sharded_device_put( - jax_view(xla_env.to_xla(item)), self.x_sharding) - print('INPUT shape', inputs.shape) - - step_start = time.perf_counter() - loss, jax_params, opt_state = step_compiled(jax_params, jax_buffers, - opt_state, (inputs, labels), - 0) - jax.block_until_ready((loss, jax_params)) - step_end = time.perf_counter() - print(i, 'loss', loss, 'step latency: ', step_end - step_start) - loop_time = step_end - step_start - min_loop_time = min(min_loop_time, loop_time) - print('======') - if i >= 2: - break - jax.profiler.stop_trace() - return min_loop_time, compile_time diff --git a/torchax/examples/train_llama_torchtitan/Dockerfile b/torchax/examples/train_llama_torchtitan/Dockerfile deleted file mode 100644 index f1f5575f247f..000000000000 --- a/torchax/examples/train_llama_torchtitan/Dockerfile +++ /dev/null @@ -1,34 +0,0 @@ -# syntax=docker/dockerfile:experimental -# Use Python 3.10 as the base image -FROM python:3.10-slim-bullseye - -# Install system dependencies -RUN apt-get update && apt-get upgrade -y -RUN apt-get update && apt-get install -y curl gnupg - -# Add the Google Cloud SDK package repository -RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list -RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - - -# Install the Google Cloud SDK -RUN apt-get update && apt-get install -y google-cloud-sdk git - -# Set the default Python version to 3.10 -RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 -RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -RUN pip install optax fire tensorflow tensorboard-plugin-profile -RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - -WORKDIR / -RUN git clone https://github.com/pytorch/torchtitan.git -WORKDIR /torchtitan -RUN pip install -r requirements.txt -RUN pip install . - -WORKDIR / -RUN git clone https://github.com/pytorch/xla.git -WORKDIR xla/experimental/torchax -RUN pip install -e . - -ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"] -CMD ["--batch_size=8", "--seqlen=2048"] \ No newline at end of file diff --git a/torchax/examples/train_llama_torchtitan/README.md b/torchax/examples/train_llama_torchtitan/README.md deleted file mode 100644 index 7bb58a040332..000000000000 --- a/torchax/examples/train_llama_torchtitan/README.md +++ /dev/null @@ -1,511 +0,0 @@ -Training based on torchtitan llama model -==================================== - -This examples demonstrates how we can make a model implemented for single device -run on multiple devices without modifying the model itself. - -We choose [torchtitan's llama implementation](https://github.com/pytorch/torchtitan/tree/main/torchtitan/models/llama); -because torchtitan's model implementation is a clean single device version. (Not those -sprinkled with `ColumnParallelLinear`'s from megatron). torchtitan accomplishes running -single device model code in multi-device environment through module-swaps, and we accomplishes -the same with gSPMD. - - - -## Install dependencies - -```bash -pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -pip install optax fire tensorflow tensorboard-plugin-profile -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - -cd ~ -git clone https://github.com/pytorch/torchtitan.git -cd torchtitan -pip install -r requirements.txt -pip install . - -cd ~ -git clone https://github.com/pytorch/xla.git -cd xla/experimental/torchax -pip install -e . -``` - -(Optional) Export libtpu flags that helps with performance -```bash -export LIBTPU_INIT_ARGS="--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -``` -NOTE: these flags are copied from https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/configs/trillium/llama2_70b_4096.sh -Tested locally on v6e-8 doesnt seems to make a difference. - -```bash -cd ~/xla/experimental/torchax/examples/train_llama_torchtitan -python train_llama.py --seqlen=8192 -``` - -## Detailed Code walkthrough: - -Below is the copy & paste of `train_llama.py` and annotated with what they do: - -```python -import os -import time -import logging -from typing import Tuple -from collections import defaultdict -import functools -import torch -import torch.nn.functional -from torch.utils import _pytree as pytree -import splash_attn -import helper - -import torchax as tx -import torchax.interop -import torchax.train -from torchax.interop import jax_view, torch_view, JittableModule -import jax -import jax.numpy as jnp -from jax.experimental import shard_map -from jax.experimental import mesh_utils -from jax.sharding import NamedSharding -import optax -``` -Above is just regular imports, uninteresting - -```python -from torchtitan.models.llama import llama3_configs -from torchtitan.models.llama import model as titan -``` -Above is importing the model and model config from torchtitan directly. -i.e. we don't need to modify the model code at all (note, there are caveats, keep reading). - -```python -P = jax.sharding.PartitionSpec -num_global_devices = jax.device_count() -num_local_devices = jax.local_device_count() -``` -This bit above are some aliases - - -```python -def sharded_device_put(tensor: jax.Array, sharding) -> jax.Array: - if isinstance(tensor, tuple): - return tuple(sharded_device_put(t, sharding) for t in tensor) - - if num_global_devices == num_local_devices: - return jax.device_put(tensor, sharding) - - # NOTE: at here, num_global_devices != num_local_devices - # meaning we are in multi-host setup. Each host will run the same process - # and each process only need to handle the devices accessible to this host. - shape = tensor.shape - x_split = [jax.device_put(tensor[i], device) - for device, i in sharding.addressable_devices_indices_map(shape).items()] - return jax.make_array_from_single_device_arrays(shape, sharding, x_split) -``` - -When running on single-host, `jax.device_put` suffices. Multi-host need some -extra incantations so that we split an array to only the shards corresponding -to the accessible devices in this host. - - -```python -sharding_map_original = { - "freqs_cis" : (), # torch.complex64 (2048, 64) - "tok_embeddings.weight" : ('fsdp', 'tp'), # torch.float32 (vocab_size, 4096) - "layers.*.attention.wo.weight" : ('fsdp', 'tp'), # torch.int8 (4096, 4096) - "layers.*.attention.wq.weight" : ('tp', 'fsdp'), # torch.int8 (4096, 4096) - "layers.*.attention.wk.weight" : ('tp', 'fsdp'), # torch.int8 (4096, 4096) - "layers.*.attention.wv.weight" : ('tp', 'fsdp'), # torch.int8 (4096, 4096) - "layers.*.feed_forward.w1.weight" : ('tp', 'fsdp'), # torch.float32 (11008, 4096) - "layers.*.feed_forward.w2.weight" : ('fsdp', 'tp'), # torch.float32 (4096, 11008) - "layers.*.feed_forward.w3.weight": ('tp', 'fsdp'), # torch.float32 (11008, 4096) - "layers.*.attention_norm.weight" : ('fsdp', ), # torch.float32 (4096,) - "layers.*.ffn_norm.weight" : ('fsdp', ), # torch.float32 (4096,) - "norm.weight" : ('fsdp', ), # torch.float32 (4096,) - "output.weight" : ('tp', 'fsdp'), # torch.float32 (vocab_size, 4096) -} - -sharding_map_scan = { - "freqs_cis" : (), # torch.complex64 (2048, 64) - # ParallelEmbedding for llama2; VocabParallelEmbedding for 3 - "tok_embeddings.weight" : ('tp', 'fsdp'), # torch.float32 (vocab_size, 4096) - "layers.params.attention___wo___weight" : (None, 'fsdp', 'tp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wq___weight" : (None, 'tp', 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wk___weight" : (None, 'tp', 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wv___weight" : (None, 'tp', 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.feed_forward___w1___weight" : (None, 'tp', 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.feed_forward___w2___weight" : (None, 'fsdp', 'tp'), # torch.float32 (n, 4096, 11008) - "layers.params.feed_forward___w3___weight": (None, 'tp', 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.attention_norm___weight" : (None, 'fsdp', ), # torch.float32 (n, 4096,) - "layers.params.ffn_norm___weight" : (None, 'fsdp', ), # torch.float32 (n, 4096,) - "norm.weight" : ('fsdp', ), # torch.float32 (4096,) - "output.weight" : ('tp', 'fsdp'), # torch.float32 (vocab_size, 4096) -} - -sharding_map_scan_fsdp = { - "freqs_cis" : (), # torch.complex64 (2048, 64) - # ParallelEmbedding for llama2; VocabParallelEmbedding for 3 - "tok_embeddings.weight" : ('fsdp',), # torch.float32 (vocab_size, 4096) - "layers.params.attention___wo___weight" : (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wq___weight" : (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wk___weight" : (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wv___weight" : (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.feed_forward___w1___weight" : (None, 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.feed_forward___w2___weight" : (None, 'fsdp'), # torch.float32 (n, 4096, 11008) - "layers.params.feed_forward___w3___weight": (None, 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.attention_norm___weight" : (None, 'fsdp', ), # torch.float32 (n, 4096,) - "layers.params.ffn_norm___weight" : (None, 'fsdp', ), # torch.float32 (n, 4096,) - "norm.weight" : ('fsdp', ), # torch.float32 (4096,) - "output.weight" : ('fsdp', ), # torch.float32 (vocab_size, 4096) -} -``` - -The above are different sharding schemes. Because we are using gSPMD, we need some -mechanism of sharding the weights. Because we don't (can't) modify the model code -itself, we can just use a dictionary of names to keep that information - - -```python -class Trainer: - - def __init__(self, mesh): - self.mesh = mesh - self.x_sharding = jax.sharding.NamedSharding(self.mesh, P('fsdp')) - self.replicated = jax.sharding.NamedSharding(self.mesh, P()) - - def fit(self, model, loss_fn, data_loader): - xla_env = torchax.default_env() - jax.config.update('jax_enable_x64', False) - xla_env._mesh = self.mesh - xla_env.use_flash_attention = True - - jittable_mod = JittableModule(model) - - # split the params to the n devices - - def model_fn(weights, buffers, args): - return jittable_mod.functional_call('forward', weights, buffers, args) - - - jax_optimizer = optax.sgd(0.01) - opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params))) - - train_step = torchax.train.make_train_step( - model_fn, loss_fn, jax_optimizer, - remat_policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims('device', 'pinned_host'), - mark_fsdp_sharding_axis='fsdp') - - print('Begining training') - s = time.perf_counter() - jax.profiler.start_trace('/tmp/tensorboard') - print('start training') - min_loop_time = 10000 - for i, item in enumerate(data_loader): - inputs, labels = item - # Move them to jax device - inputs = inputs.to('jax') - labels = labels.to('jax') - - # Shard them on batch dim for fsdp - inputs.apply_jax_(sharded_device_put, self.x_sharding) - labels.apply_jax_(sharded_device_put, self.x_sharding) - - if i == 0: - train_step = helper.compile_step_func( - train_step, - jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels, - self.mesh - ) - - print('INPUT shape', inputs.shape) - step_start = time.perf_counter() - loss, jittable_mod.params, opt_state = train_step( - jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels) - # wait for iteration to finish to measure time - torchax.interop.call_jax(jax.block_until_ready, (loss, jittable_mod.params)) - step_end = time.perf_counter() - print(i, 'loss', loss, 'step latency: ', step_end - step_start) - loop_time = step_end - step_start - min_loop_time = min(min_loop_time, loop_time) - print('======') - if i >= 3: - break - jax.profiler.stop_trace() - return min_loop_time -``` - -The trainer class is the training loop. -Few things to note: - -1. The training loop is something that calls a `train_step` repeatedly. - The `train_step` is function that maps (weights, buffer, optimizer_state, inputs, labels) - to (loss, updated weight, updated optimizer state). Returning loss is not needed - only there for printing out. The buffer argument is the non-trainable paramters, in - our case, it holds the `freqs_cis` variable - - The `train_step` is roughly equivalent to the follwoing: - - ```python - def train_step(weights, buffer, optimizer_state, inputs, label): - optimizer = recreate optimizer from optimizer_state - state_dict = weights + buffer - result = torch.func.functional_call(model, state_dict, inputs) - loss = loss_fn(result, label) - loss.backward() - optimizer.step() - return loss, model.paramters(), optimizer.state_dict() - ``` - -2. Here we are using a fake dataloader. - -3. We are calling `jax.block_until_ready` to measure iteration time, this is not needed - for real training jobs - -4. We use `jax.profiler` to capture profiles. Tools listed in here: https://jax.readthedocs.io/en/latest/profiling.html - all works out of the box. - -5. `interop.call_jax` API is used whenever we need something from Jax. Those API can be - wrapped and have the "jaxiness" hidden. However, I don't think we need to do such hidding. - -6. Precompile: call to `helpers.compile_step_func`. This is not needed. If not used, then - it will compile on the first invokation. However, triggering compilation manually - allows to print some stats (such as GBs accessed), also will error if the input shape - / layout / sharding changed in the future iterations. For example I got the below while developing: - ``` - ValueError: Received incompatible devices for jitted computation. Got argument args[0]['layers.params.attention___wk___weight'] of with shape bfloat16[32,1024,4096] and device ids [0, 2, 4, 6, 1, 3, 5, 7] on platform TPU and explicit output sharding with device ids [0] on platform TPU - ``` - this tells me the sharding I specified was wrong and I would go back and fix. - - - - -```python -def _process_sharding_name(name): - """Replace integers in param name with *. - - Presumably all layers should have the same sharding. - """ - - def is_integer(t): - try: - int(t) - return True - # pylint: disable-next=all - except: # noqa: E722 - return False - - tokens = name.split(".") - for i, t in enumerate(tokens): - if is_integer(t): - tokens[i] = "*" - return ".".join(tokens) -``` -This is a helper to process names in sharding map - - -```python -def create_sharded_weights(model, mesh, sharding_map): - res = {} - env = torchax.default_env() - for name, weight_meta in model.state_dict().items(): - sharding_spec = sharding_map.get(_process_sharding_name(name)) - if sharding_spec is None: - print('Skipping weight:', name) - continue - sharding = NamedSharding(mesh, P(*sharding_spec)) - with jax.default_device(jax.devices('cpu')[0]): - weight_torch = torch.randn( - weight_meta.shape, - dtype=weight_meta.dtype) - weight_jax = torchax.default_env().to_xla(weight_torch).jax() - #print(name, weight.shape, weight.dtype) - res[name] = env.j2t_iso(jax.make_array_from_callback( - weight_jax.shape, sharding, lambda a: weight_jax[a] - )) - return res -``` -The strategy of not OOMing the host on larger scale training: -allocate the model on meta device, then re-initialize weights one by one, -shard the weight immediately after creation. - - -```python -def fake_dataloader(size, seqlen, batch_size): - for _ in range(size): - x = torch.randint(0, 32000, (batch_size, seqlen), device='cpu') - yield x, (x + 1) % 32000 -``` - -Fake dataloader, just create random ints of desired shape. - - -Then the below is the `main` function. I will split it into pieces for better commenting - -```python -def main( - model_type='8B', - batch_size=8, - seqlen=2048, - override_num_layers=-1, - use_scan = True, - tp_parallelism=1, -): - torchax.enable_globally() - torchax.enable_performance_mode() - #logging.getLogger("jax").setLevel(logging.DEBUG) - print(f"Running with parameters {locals()}") - - fsdp = num_global_devices // tp_parallelism - mesh = jax.make_mesh((fsdp, tp_parallelism), ('fsdp', 'tp')) -``` -Above, the config is set to run either fsdp only or also with tensor parallelism. -If using tp (i.e. passing `tp_parallelism > 1`) then the global devices will be -split into fsdp x tp 2D array. Tensors will be sharded on those 2 axis - -```python - if use_scan: - # using scan the individial weights will have shape (num_layers, w, h) - sharding_map = sharding_map_scan_fsdp - else: - sharding_map = sharding_map_original -``` -Scan is implemented as the `TransformerWithScan` below. - -```python - env = torchax.default_env() - env.config.use_tpu_flash_attention = True - env.config.shmap_flash_attention = True - env._mesh = mesh # this is the mesh used by flash attention pallas kernel -``` -this bit tells TX to use flash_attention implemented in pallas. Because pallas is -single device by default, we apply `jax.shard_map` with a mesh. - -```python - args = llama3_configs[model_type] - # Note: torchtitan's upstream config did not specify this value - args.vocab_size = 128256 - args.max_seq_len = seqlen - if override_num_layers > 0: - args.n_layers = override_num_layers - - # Note: because a single device don't have enough HBM memory - # nor enough CPU memory to hold the parameters. We instantiate - # the model on meta then manually initialize then shard each param - torch.set_default_dtype(torch.bfloat16) - with torch.device('meta'): - gpt = titan.Transformer(args) -``` -Above, instantiate the model on meta device so no OOM. - -```python - with torch.device('cpu'): - # need actual value for freqs_cis - freqs_cis = gpt._precompute_freqs_cis() -``` -Compute freqs_cis on CPU because we actually need its value. - -```python - if use_scan: - checkpoint_policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims('device', 'pinned_host') - gpt = TransfomerWithScan(gpt, checkpoint_policy) - - state_dict = dict(gpt.state_dict()) - state_dict.pop('freqs_cis') # dont shard freqs_cis - state_dict = create_sharded_weights(gpt, mesh, sharding_map) - replicated = jax.sharding.NamedSharding(mesh, P()) - - state_dict['freqs_cis'] = freqs_cis.to('jax').apply_jax(jax.device_put, replicated) - gpt.load_state_dict(state_dict, assign=True) - - train_loader = fake_dataloader(10, seqlen, batch_size) -``` -Put the sharded arrays inside of XLATensor back to the model with `load_state_dict` - -```python - # NOTE: overriding attention to capture mesh and sharding info - partition = P('fsdp', 'tp', None, None) - attention = functools.partial( - splash_attn.tpu_splash_attention, - mesh, partition, True) - attention = jax.jit(attention) - - def custom_attention( - query, key, value, attn_mask=None, - dropout_p=0.0, is_causal=False, - scale=None, enable_gqa=False): - # batch, num of head, seq, dim - jk, jq, jv = jax_view((query, key, value)) - res = attention(jk, jq, jv, None) - return torch_view(res) - env.override_op_definition(torch.nn.functional.scaled_dot_product_attention, custom_attention) -``` -Above, this bit is to showcase the "hackability": User can override the definition -of an torch op at runtime, and user's version will be invoked. Here I am using -jax pallas implementation of `splash_attention` i.e. sparse flash attention. -Note this can be done without modifying the model at all. -All ops that are `__torch_function__` capturable or `__torch_dispatch__` capturable are -eligible to be overriden. - -```python - def loss_fn(logits, y): - num_tokens = logits.shape[-1] - logits = logits.reshape(-1, num_tokens) - y = y.reshape(-1) - return torch.nn.functional.cross_entropy( - logits, y) -``` -Standard torch loss function. Needed reshape because `cross_entropy` only work with one -batch dim (not both batch and sequence) - -```python - with mesh: - trainer = Trainer(mesh) - return trainer.fit( - gpt, - loss_fn, - train_loader - ) -``` -Invoking the traininer. - - -```python -class TransfomerWithScan(torch.nn.Module): - - def __init__(self, old_transformer, checkpoint_policy): - super().__init__() - self.tok_embeddings = old_transformer.tok_embeddings - self.norm = old_transformer.norm - self.output = old_transformer.output - self.layers = torchax.train.ScannedModule(list(old_transformer.layers.values()), checkpoint_policy) - - self.register_buffer('freqs_cis', old_transformer.freqs_cis) - - def forward(self, tokens: torch.Tensor): - """ - Perform a forward pass through the Transformer model. - - Args: - tokens (torch.Tensor): Input token indices. - - Returns: - torch.Tensor: Output logits after applying the Transformer model. - - """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - # for layer in self.layers.values(): - # h = layer(h, self.freqs_cis) - - h = self.layers(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - output = self.output(h) if self.output else h - return output -``` -The goal of this class is to replace the for loop that iterate the layers with -a loop with scan. The use of scan is encaptured in `ScanedModule`. This class -is to override the `forward` to call `ScannedModule` instead of calling it in a loop. diff --git a/torchax/examples/train_llama_torchtitan/__init__.py b/torchax/examples/train_llama_torchtitan/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/examples/train_llama_torchtitan/helper.py b/torchax/examples/train_llama_torchtitan/helper.py deleted file mode 100644 index 783e392328a9..000000000000 --- a/torchax/examples/train_llama_torchtitan/helper.py +++ /dev/null @@ -1,37 +0,0 @@ -import time -import jax -from jax.tree_util import tree_map -from jax.sharding import NamedSharding -from torchax import interop - -P = jax.sharding.PartitionSpec - - -def compile_step_func(step, weights, buffers, opt_state, args, label, mesh): - step, weights, buffers, opt_state, args, label = interop.jax_view( - (step, weights, buffers, opt_state, args, label)) - wshardings = tree_map( - lambda a: a.sharding if isinstance(a, jax.Array) else None, weights) - bshardings = tree_map( - lambda a: a.sharding if isinstance(a, jax.Array) else None, buffers) - oshardings = tree_map( - lambda a: a.sharding if isinstance(a, jax.Array) else None, opt_state) - print('Start compiling') - start = time.perf_counter() - lowered = jax.jit( - step, - donate_argnums=(0, 2), - #in_shardings=shardings, - out_shardings=(NamedSharding(mesh, P()), wshardings, oshardings), - ).lower(weights, buffers, opt_state, args, label) - #print(lowered.as_text()) - # import pdb; pdb.set_trace() - print('program size:', len(lowered.as_text()) / 1e6, 'm chars') - step_compiled = lowered.compile() - end = time.perf_counter() - print('End compiling', end - start) - compile_time = end - start - for co in step_compiled.cost_analysis(): - print('Flops', co['flops']) - print('GB accessed', co['bytes accessed'] / 1e9) - return interop.torch_view(step_compiled) diff --git a/torchax/examples/train_llama_torchtitan/splash_attn.py b/torchax/examples/train_llama_torchtitan/splash_attn.py deleted file mode 100644 index 0cbc657ef373..000000000000 --- a/torchax/examples/train_llama_torchtitan/splash_attn.py +++ /dev/null @@ -1,101 +0,0 @@ -import functools - -import jax -import jax.numpy as jnp - -from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel -from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask -from jax.experimental.shard_map import shard_map - - -def tpu_splash_attention( - mesh, - q_sharding, - # Input should be of shape (batch, length, heads, kv_dim) - apply_shard_map, - query: jax.Array, - key: jax.Array, - value: jax.Array, - decoder_segment_ids: jax.Array | None, - attn_logits_soft_cap: float | None = None, -) -> jax.Array: - """TPU Flash Attention.""" - if decoder_segment_ids is not None: - decoder_segment_ids = splash_attention_kernel.SegmentIds( - decoder_segment_ids, decoder_segment_ids) - - print('HERE', locals()) - - global_block_q = 1024 - global_block_kv = 512 - global_block_kv_compute = 512 - global_block_q_dkv = 2048 - global_block_kv_dkv = 512 - global_block_kv_dkv_compute = 512 - global_block_q_dq = 2048 - global_block_kv_dq = 512 - global_use_fused_bwd_kernel = False - global_q_layout = 'HEAD_DIM_MINOR' - global_k_layout = 'HEAD_DIM_MINOR' - global_v_layout = 'HEAD_DIM_MINOR' - - def wrap_flash_attention(query, key, value, decoder_segment_ids): - if decoder_segment_ids is not None: - assert ( - query.shape[2] == decoder_segment_ids.q.shape[1] - ), "Sharding along sequence dimension not allowed in tpu kernel attention" - block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(global_block_q, query.shape[2]), - block_kv=min(global_block_kv, key.shape[2]), - block_kv_compute=min(global_block_kv_compute, key.shape[2]), - block_q_dkv=min(global_block_q_dkv, query.shape[2]), - block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), - block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), - block_q_dq=None if global_use_fused_bwd_kernel else min( - global_block_q_dq, query.shape[2]), - block_kv_dq=None if global_use_fused_bwd_kernel else min( - global_block_kv_dq, query.shape[2]), - use_fused_bwd_kernel=global_use_fused_bwd_kernel, - q_layout=splash_attention_kernel.QKVLayout[global_q_layout], - k_layout=splash_attention_kernel.QKVLayout[global_k_layout], - v_layout=splash_attention_kernel.QKVLayout[global_v_layout], - ) - - mask = splash_attention_mask.CausalMask( - shape=(query.shape[2], query.shape[2])) - - # Create multi-head mask - multi_head_mask = splash_attention_mask.MultiHeadMask( - masks=(mask,) * query.shape[1]) - #splash_kernel = splash_attention_kernel.make_splash_mha( - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=1, - q_seq_shards=1, - block_sizes=block_sizes, - attn_logits_soft_cap=attn_logits_soft_cap, - ) - - return jax.vmap(splash_kernel)( - query, key, value, segment_ids=decoder_segment_ids) - - if apply_shard_map: - wrap_flash_attention = shard_map( - wrap_flash_attention, - mesh=mesh, - in_specs=( - q_sharding, - q_sharding, - q_sharding, - None, - ), - out_specs=q_sharding, - check_rep=False, - ) - - x = wrap_flash_attention(query, key, value, decoder_segment_ids) - return x - - -if __name__ == '__main__': - main() diff --git a/torchax/examples/train_llama_torchtitan/train_llama.py b/torchax/examples/train_llama_torchtitan/train_llama.py deleted file mode 100644 index 210b17d06f15..000000000000 --- a/torchax/examples/train_llama_torchtitan/train_llama.py +++ /dev/null @@ -1,385 +0,0 @@ -import os -import time -import logging -from typing import Tuple -from collections import defaultdict -import functools -import torch -import torch.nn.functional -from torch.utils import _pytree as pytree -import splash_attn -import helper - -import torchax as tx -import torchax.interop -import torchax.train -from torchax.interop import jax_view, torch_view, JittableModule -import jax -import jax.numpy as jnp -from jax.experimental import shard_map -from jax.experimental import mesh_utils -from jax.sharding import NamedSharding -import optax - -from torchtitan.models.llama import llama3_configs -from torchtitan.models.llama import model as titan - -P = jax.sharding.PartitionSpec - -num_global_devices = jax.device_count() -num_local_devices = jax.local_device_count() - - -def sharded_device_put(tensor: jax.Array, sharding) -> jax.Array: - if isinstance(tensor, tuple): - return tuple(sharded_device_put(t, sharding) for t in tensor) - - if num_global_devices == num_local_devices: - return jax.device_put(tensor, sharding) - - # NOTE: at here, num_global_devices != num_local_devices - # meaning we are in multi-host setup. Each host will run the same process - # and each process only need to handle the devices accessible to this host. - shape = tensor.shape - x_split = [ - jax.device_put(tensor[i], device) - for device, i in sharding.addressable_devices_indices_map(shape).items() - ] - return jax.make_array_from_single_device_arrays(shape, sharding, x_split) - - -sharding_map_original = { - "freqs_cis": (), # torch.complex64 (2048, 64) - "tok_embeddings.weight": - ('fsdp', 'tp'), # torch.float32 (vocab_size, 4096) - "layers.*.attention.wo.weight": ('fsdp', 'tp'), # torch.int8 (4096, 4096) - "layers.*.attention.wq.weight": ('tp', 'fsdp'), # torch.int8 (4096, 4096) - "layers.*.attention.wk.weight": ('tp', 'fsdp'), # torch.int8 (4096, 4096) - "layers.*.attention.wv.weight": ('tp', 'fsdp'), # torch.int8 (4096, 4096) - "layers.*.feed_forward.w1.weight": - ('tp', 'fsdp'), # torch.float32 (11008, 4096) - "layers.*.feed_forward.w2.weight": - ('fsdp', 'tp'), # torch.float32 (4096, 11008) - "layers.*.feed_forward.w3.weight": - ('tp', 'fsdp'), # torch.float32 (11008, 4096) - "layers.*.attention_norm.weight": ('fsdp',), # torch.float32 (4096,) - "layers.*.ffn_norm.weight": ('fsdp',), # torch.float32 (4096,) - "norm.weight": ('fsdp',), # torch.float32 (4096,) - "output.weight": ('tp', 'fsdp'), # torch.float32 (vocab_size, 4096) -} - -sharding_map_scan = { - "freqs_cis": (), # torch.complex64 (2048, 64) - # ParallelEmbedding for llama2; VocabParallelEmbedding for 3 - "tok_embeddings.weight": - ('tp', 'fsdp'), # torch.float32 (vocab_size, 4096) - "layers.params.attention___wo___weight": - (None, 'fsdp', 'tp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wq___weight": - (None, 'tp', 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wk___weight": - (None, 'tp', 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wv___weight": - (None, 'tp', 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.feed_forward___w1___weight": - (None, 'tp', 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.feed_forward___w2___weight": - (None, 'fsdp', 'tp'), # torch.float32 (n, 4096, 11008) - "layers.params.feed_forward___w3___weight": - (None, 'tp', 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.attention_norm___weight": ( - None, - 'fsdp', - ), # torch.float32 (n, 4096,) - "layers.params.ffn_norm___weight": ( - None, - 'fsdp', - ), # torch.float32 (n, 4096,) - "norm.weight": ('fsdp',), # torch.float32 (4096,) - "output.weight": ('tp', 'fsdp'), # torch.float32 (vocab_size, 4096) -} - -sharding_map_scan_fsdp = { - "freqs_cis": (), # torch.complex64 (2048, 64) - # ParallelEmbedding for llama2; VocabParallelEmbedding for 3 - "tok_embeddings.weight": ('fsdp',), # torch.float32 (vocab_size, 4096) - "layers.params.attention___wo___weight": - (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wq___weight": - (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wk___weight": - (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.attention___wv___weight": - (None, 'fsdp'), # torch.int8 (n, 4096, 4096) - "layers.params.feed_forward___w1___weight": - (None, 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.feed_forward___w2___weight": - (None, 'fsdp'), # torch.float32 (n, 4096, 11008) - "layers.params.feed_forward___w3___weight": - (None, 'fsdp'), # torch.float32 (n, 11008, 4096) - "layers.params.attention_norm___weight": ( - None, - 'fsdp', - ), # torch.float32 (n, 4096,) - "layers.params.ffn_norm___weight": ( - None, - 'fsdp', - ), # torch.float32 (n, 4096,) - "norm.weight": ('fsdp',), # torch.float32 (4096,) - "output.weight": ('fsdp',), # torch.float32 (vocab_size, 4096) -} - - -class Trainer: - - def __init__(self, mesh): - self.mesh = mesh - self.x_sharding = jax.sharding.NamedSharding(self.mesh, P('fsdp')) - self.replicated = jax.sharding.NamedSharding(self.mesh, P()) - - def fit(self, model, loss_fn, data_loader): - xla_env = torchax.default_env() - jax.config.update('jax_enable_x64', False) - xla_env._mesh = self.mesh - xla_env.use_flash_attention = True - - jittable_mod = JittableModule(model) - - # split the params to the n devices - - # model_fn is responsible to shard if needed - # to do FSDP one shards the first input args and output - # on the batch dimension - def model_fn(weights, buffers, args): - return jittable_mod.functional_call('forward', weights, buffers, args) - - jax_optimizer = optax.sgd(0.01) - opt_state = torch_view(jax_optimizer.init(jax_view(jittable_mod.params))) - - #opt_state = torchax.interop.call_jax(jax_optimizer.init, jittable_mod.params) - - train_step = torchax.train.make_train_step( - model_fn, - loss_fn, - jax_optimizer, - remat_policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims( - 'device', 'pinned_host')) - - print('Begining training') - s = time.perf_counter() - jax.profiler.start_trace('/tmp/tensorboard') - print('start training') - min_loop_time = 10000 - for i, item in enumerate(data_loader): - inputs, labels = item - # Move them to jax device - inputs = inputs.to('jax') - labels = labels.to('jax') - - # Shard them on batch dim for fsdp - inputs.apply_jax_(sharded_device_put, self.x_sharding) - labels.apply_jax_(sharded_device_put, self.x_sharding) - - if i == 0: - train_step = helper.compile_step_func(train_step, jittable_mod.params, - jittable_mod.buffers, opt_state, - inputs, labels, self.mesh) - - print('INPUT shape', inputs.shape) - step_start = time.perf_counter() - loss, jittable_mod.params, opt_state = train_step(jittable_mod.params, - jittable_mod.buffers, - opt_state, inputs, - labels) - # wait for iteration to finish to measure time - torchax.interop.call_jax(jax.block_until_ready, - (loss, jittable_mod.params)) - step_end = time.perf_counter() - print(i, 'loss', loss, 'step latency: ', step_end - step_start) - loop_time = step_end - step_start - min_loop_time = min(min_loop_time, loop_time) - print('======') - if i >= 3: - break - jax.profiler.stop_trace() - return min_loop_time - - -def _process_sharding_name(name): - """Replace integers in param name with *. - - Presumably all layers should have the same sharding. - """ - - def is_integer(t): - try: - int(t) - return True - # pylint: disable-next=all - except: # noqa: E722 - return False - - tokens = name.split(".") - for i, t in enumerate(tokens): - if is_integer(t): - tokens[i] = "*" - return ".".join(tokens) - - -def create_sharded_weights(model, mesh, sharding_map): - res = {} - env = torchax.default_env() - for name, weight_meta in model.state_dict().items(): - sharding_spec = sharding_map.get(_process_sharding_name(name)) - if sharding_spec is None: - print('Skipping weight:', name) - continue - sharding = NamedSharding(mesh, P(*sharding_spec)) - with jax.default_device(jax.devices('cpu')[0]): - weight_torch = torch.randn(weight_meta.shape, dtype=weight_meta.dtype) - weight_jax = torchax.default_env().to_xla(weight_torch).jax() - #print(name, weight.shape, weight.dtype) - res[name] = env.j2t_iso( - jax.make_array_from_callback(weight_jax.shape, sharding, - lambda a: weight_jax[a])) - return res - - -def fake_dataloader(size, seqlen, batch_size): - for _ in range(size): - x = torch.randint(0, 32000, (batch_size, seqlen), device='cpu') - yield x, (x + 1) % 32000 - - -def main( - model_type='8B', - batch_size=8, - seqlen=2048, - override_num_layers=-1, - use_scan=True, - tp_parallelism=1, -): - torchax.enable_globally() - torchax.enable_performance_mode() - #logging.getLogger("jax").setLevel(logging.DEBUG) - print(f"Running with parameters {locals()}") - - fsdp = num_global_devices // tp_parallelism - mesh = jax.make_mesh((fsdp, tp_parallelism), ('fsdp', 'tp')) - if use_scan: - # using scan the individial weights will have shape (num_layers, w, h) - sharding_map = sharding_map_scan_fsdp - else: - sharding_map = sharding_map_original - - env = torchax.default_env() - env.config.use_tpu_flash_attention = True - env.config.shmap_flash_attention = True - env._mesh = mesh # this is the mesh used by flash attention pallas kernel - - args = llama3_configs[model_type] - # Note: torchtitan's upstream config did not specify this value - args.vocab_size = 128256 - args.max_seq_len = seqlen - if override_num_layers > 0: - args.n_layers = override_num_layers - - # Note: because a single device don't have enough HBM memory - # nor enough CPU memory to hold the parameters. We instantiate - # the model on meta then manually initialize then shard each param - torch.set_default_dtype(torch.bfloat16) - with torch.device('meta'): - gpt = titan.Transformer(args) - - with torch.device('cpu'): - # need actual value for freqs_cis - freqs_cis = gpt._precompute_freqs_cis() - - if use_scan: - checkpoint_policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( - 'device', 'pinned_host') - gpt = TransfomerWithScan(gpt, checkpoint_policy) - - state_dict = dict(gpt.state_dict()) - state_dict.pop('freqs_cis') # dont shard freqs_cis - state_dict = create_sharded_weights(gpt, mesh, sharding_map) - replicated = jax.sharding.NamedSharding(mesh, P()) - - state_dict['freqs_cis'] = freqs_cis.to('jax').apply_jax( - jax.device_put, replicated) - gpt.load_state_dict(state_dict, assign=True) - - train_loader = fake_dataloader(10, seqlen, batch_size) - - # NOTE: overriding attention to capture mesh and sharding info - partition = P('fsdp', 'tp', None, None) - attention = functools.partial(splash_attn.tpu_splash_attention, mesh, - partition, True) - attention = jax.jit(attention) - - def custom_attention(query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False): - # batch, num of head, seq, dim - jk, jq, jv = jax_view((query, key, value)) - res = attention(jk, jq, jv, None) - return torch_view(res) - - env.override_op_definition(torch.nn.functional.scaled_dot_product_attention, - custom_attention) - - def loss_fn(logits, y): - num_tokens = logits.shape[-1] - logits = logits.reshape(-1, num_tokens) - y = y.reshape(-1) - return torch.nn.functional.cross_entropy(logits, y) - - with mesh: - trainer = Trainer(mesh) - return trainer.fit(gpt, loss_fn, train_loader) - - -class TransfomerWithScan(torch.nn.Module): - - def __init__(self, old_transformer, checkpoint_policy): - super().__init__() - self.tok_embeddings = old_transformer.tok_embeddings - self.norm = old_transformer.norm - self.output = old_transformer.output - self.layers = torchax.train.ScannedModule( - list(old_transformer.layers.values()), checkpoint_policy) - - self.register_buffer('freqs_cis', old_transformer.freqs_cis) - - def forward(self, tokens: torch.Tensor): - """ - Perform a forward pass through the Transformer model. - - Args: - tokens (torch.Tensor): Input token indices. - - Returns: - torch.Tensor: Output logits after applying the Transformer model. - - """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - - # for layer in self.layers.values(): - # h = layer(h, self.freqs_cis) - - h = self.layers(h, self.freqs_cis) - - h = self.norm(h) if self.norm else h - output = self.output(h) if self.output else h - return output - - -if __name__ == '__main__': - import fire - fire.Fire(main) diff --git a/torchax/format.sh b/torchax/format.sh deleted file mode 100755 index 9b9663294ca4..000000000000 --- a/torchax/format.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env bash -set -ex - -yapf --recursive -i *.py test torchax \ No newline at end of file diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml deleted file mode 100644 index 2f30f30e7c68..000000000000 --- a/torchax/pyproject.toml +++ /dev/null @@ -1,50 +0,0 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[project] -name = "torchax" -dependencies = [] -requires-python = ">=3.10" -license = {file = "LICENSE"} -dynamic = ["version"] -authors = [ - {name = "Han Qi", email = "qihan.dev@gmail.com"}, - {name = "Pytorch/XLA team", email = "pytorchxla-dev@google.com"}, -] -description = "torchax is a library for running Jax and PyTorch together" -readme = "README.md" -classifiers = [ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: BSD License", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", -] - -[project.urls] -"Homepage" = "https://github.com/pytorch/xla/tree/master/torchax" - - -[tool.hatch.version] -path = "torchax/__init__.py" - -[project.optional-dependencies] -cpu = ["jax[cpu]>=0.6.2", "jax[cpu]"] -# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html` -tpu = ["jax[cpu]>=0.6.2", "jax[tpu]"] -cuda = ["jax[cpu]>=0.6.2", "jax[cuda12]"] -odml = ["jax[cpu]>=0.6.2", "jax[cpu]"] - -[tool.hatch.build.targets.wheel] -packages = ["torchax"] diff --git a/torchax/test-requirements.txt b/torchax/test-requirements.txt deleted file mode 100644 index 677912bbd04d..000000000000 --- a/torchax/test-requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ --r dev-requirements.txt -absl-py==2.2.2 -immutabledict==4.2.1 -pytest==8.3.5 -sentencepiece -expecttest==0.3.0 -optax==0.2.4 -pytest -pytest-xdist diff --git a/torchax/test/__init__.py b/torchax/test/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/test/base_test_util.py b/torchax/test/base_test_util.py deleted file mode 100644 index f155b78f67a9..000000000000 --- a/torchax/test/base_test_util.py +++ /dev/null @@ -1,55 +0,0 @@ -import unittest -import torch -from torch.utils import _pytree as pytree - -from torchax import tensor - -TestCase = unittest.TestCase -main = unittest.main - - -def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True): - if isinstance(output1, torch.Tensor): - testcase.assertIsInstance(output2, torch.Tensor) - output2_cpu = output2.detach().cpu() - if output2_cpu.dtype != output1.dtype: - output2_cpu = output2_cpu.to(output1.dtype) - testcase.assertTrue( - torch.allclose( - output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan)) - elif isinstance(output1, (tuple, list)): - testcase.assertIsInstance(output2, (tuple, list)) - testcase.assertEqual(len(output1), len(output2)) - for o1, o2 in zip(output1, output2): - diff_output(testcase, o1, o2, rtol, atol) - else: - testcase.assertEqual(output1, output2) - - -def run_function_and_compare(testcase, - func, - args, - kwargs, - atol=1e-3, - rtol=1e-5, - equal_nan=True, - ignore_indices=False): - with testcase.subTest("torch_eval"): - res = func(*args, **kwargs) - with testcase.subTest("torchax_eval"): - args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, - (args, kwargs)) - res2 = func(*args2, **kwargs2) - res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) - with testcase.subTest("torchax_diff:" + str(atol)): - if ignore_indices and isinstance(res, tuple) and len(res) == 2: - diff_output( - testcase, - res[0], - res2[0], - atol=atol, - rtol=rtol, - equal_nan=equal_nan) - else: - diff_output( - testcase, res, res2, atol=atol, rtol=rtol, equal_nan=equal_nan) \ No newline at end of file diff --git a/torchax/test/gemma/__init__.py b/torchax/test/gemma/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/test/gemma/config.py b/torchax/test/gemma/config.py deleted file mode 100644 index 423c9a5b0d0b..000000000000 --- a/torchax/test/gemma/config.py +++ /dev/null @@ -1,83 +0,0 @@ -# From: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py - -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Gemma model config.""" - -import dataclasses -import immutabledict -import torch -from typing import Optional - -# Keep a mapping from dtype strings to the supported torch dtypes. -_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict({ - 'float16': torch.float16, - 'float': torch.float32, - 'float32': torch.float32, - 'bfloat16': torch.bfloat16, -}) - - -@dataclasses.dataclass -class GemmaConfig: - # The number of tokens in the vocabulary. - vocab_size: int = 256000 - # The maximum sequence length that this model might ever be used with. - max_position_embeddings: int = 8192 - # The number of blocks in the model. - num_hidden_layers: int = 28 - # The number of attention heads used in the attention layers of the model. - num_attention_heads: int = 16 - # The number of key-value heads for implementing attention. - num_key_value_heads: int = 16 - # The hidden size of the model. - hidden_size: int = 3072 - # The dimension of the MLP representations. - intermediate_size: int = 24576 - # The number of head dimensions. - head_dim: int = 256 - # The epsilon used by the rms normalization layers. - rms_norm_eps: float = 1e-6 - # The dtype of the weights. - dtype: str = 'bfloat16' - # Whether a quantized version of the model is used. - quant: bool = False - # The path to the model tokenizer. - tokenizer: Optional[str] = 'tokenizer/tokenizer.model' - - def get_dtype(self) -> Optional[torch.dtype]: - """Gets the torch dtype from the config dtype string.""" - return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None) - - -def get_config_for_7b() -> GemmaConfig: - return GemmaConfig() - - -def get_config_for_2b() -> GemmaConfig: - return GemmaConfig( - num_hidden_layers=18, - num_attention_heads=8, - num_key_value_heads=1, - hidden_size=2048, - intermediate_size=16384) - - -def get_model_config(variant: str) -> GemmaConfig: - if variant == '7b': - return get_config_for_7b() - elif variant == '2b': - return get_config_for_2b() - return ValueError(f'Invalid variant {variant}. Supported variants are "2b"' - 'and "7b"') diff --git a/torchax/test/gemma/model.py b/torchax/test/gemma/model.py deleted file mode 100644 index 520a221ef632..000000000000 --- a/torchax/test/gemma/model.py +++ /dev/null @@ -1,549 +0,0 @@ -# From: https://raw.githubusercontent.com/google/gemma_pytorch/main/gemma/model.py - -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Gemma model implementation.""" - -import re -import torch -from torch import nn -import torch.nn.functional as F -from typing import Any, List, Optional, Sequence, Tuple, Union - -from . import config as gemma_config -from . import tokenizer - - -class Sampler(nn.Module): - - def __init__(self, vocab_size: int): - super().__init__() - self.vocab_size = vocab_size - - @torch.no_grad() - def forward( - self, - embedding: torch.Tensor, - hidden_states: torch.Tensor, - output_positions: torch.Tensor, - temperatures: torch.Tensor, - top_ps: torch.Tensor, - top_ks: torch.Tensor, - embedding_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Select the last element for each sequence. - # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size) - hidden_states = hidden_states.index_select(1, - output_positions).squeeze(dim=1) - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - - if temperatures is None: - return torch.argmax(logits, dim=-1).squeeze(dim=-1) - - # Apply temperature scaling. - logits.div_(temperatures.unsqueeze(dim=1)) - - # Calculate probabilities with softmax. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - - # Apply top-p, top-k. - probs_sum = torch.cumsum(probs_sort, dim=-1) - top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) - probs_sort = torch.where(top_ps_mask, 0, probs_sort) - - top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) - top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) - top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) - probs_sort = torch.where(top_ks_mask, 0, probs_sort) - - # Re-normalization. - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - probs = torch.gather( - probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1)) - next_token_ids = torch.multinomial( - probs, num_samples=1, replacement=True).squeeze(dim=-1) - return next_token_ids - - -def precompute_freqs_cis(dim: int, - end: int, - theta: float = 10000.0) -> torch.Tensor: - """Precomputes the frequency cis.""" - freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Applies the rotary embedding to the query and key tensors.""" - x_ = torch.view_as_complex( - torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1)) - x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) - x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) - x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], - -1).transpose(1, 2) - return x_out - - -class Linear(nn.Module): - - def __init__(self, in_features: int, out_features: int, quant: bool): - super().__init__() - if quant: - self.weight = nn.Parameter( - torch.empty((out_features, in_features), dtype=torch.int8), - requires_grad=False, - ) - self.weight_scaler = nn.Parameter(torch.Tensor(out_features)) - else: - self.weight = nn.Parameter( - torch.empty((out_features, in_features)), - requires_grad=False, - ) - self.quant = quant - - def forward(self, x): - weight = self.weight - if self.quant: - weight = weight * self.weight_scaler.unsqueeze(-1) - output = F.linear(x, weight) - return output - - -class Embedding(nn.Module): - - def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool): - super().__init__() - if quant: - self.weight = nn.Parameter( - torch.empty((num_embeddings, embedding_dim), dtype=torch.int8), - requires_grad=False, - ) - self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings)) - else: - self.weight = nn.Parameter( - torch.empty((num_embeddings, embedding_dim)), - requires_grad=False, - ) - self.quant = quant - - def forward(self, x): - weight = self.weight - if self.quant: - weight = weight * self.weight_scaler.unsqueeze(-1) - output = F.embedding(x, weight) - return output - - -class RMSNorm(torch.nn.Module): - - def __init__( - self, - dim: int, - eps: float = 1e-6, - add_unit_offset: bool = True, - ): - super().__init__() - self.eps = eps - self.add_unit_offset = add_unit_offset - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - x = self._norm(x.float()).type_as(x) - if self.add_unit_offset: - output = x * (1 + self.weight) - else: - output = x * self.weight - return output - - -class GemmaMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - quant: bool, - ): - super().__init__() - self.gate_proj = Linear(hidden_size, intermediate_size, quant) - self.up_proj = Linear(hidden_size, intermediate_size, quant) - self.down_proj = Linear(intermediate_size, hidden_size, quant) - - def forward(self, x): - gate = self.gate_proj(x) - gate = F.gelu(gate) - up = self.up_proj(x) - fuse = gate * up - outputs = self.down_proj(fuse) - return outputs - - -class GemmaAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - quant: bool, - ): - super().__init__() - - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.hidden_size = hidden_size - self.head_dim = head_dim - - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = Linear( - self.hidden_size, - (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, - quant=quant) - self.o_proj = Linear( - self.num_heads * self.head_dim, self.hidden_size, quant=quant) - - def forward( - self, - hidden_states: torch.Tensor, - freqs_cis: torch.Tensor, - kv_write_indices: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - mask: torch.Tensor, - ) -> torch.Tensor: - hidden_states_shape = hidden_states.shape - assert len(hidden_states_shape) == 3 - - batch_size, input_len, _ = hidden_states_shape - - qkv = self.qkv_proj(hidden_states) - xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) - xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) - xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) - - # Positional embedding. - xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) - - # Write new kv cache. - # [batch_size, input_len, n_local_kv_heads, head_dim] - k_cache, v_cache = kv_cache - k_cache.index_copy_(1, kv_write_indices, xk) - v_cache.index_copy_(1, kv_write_indices, xv) - - key = k_cache - value = v_cache - if self.num_kv_heads != self.num_heads: - # [batch_size, max_seq_len, n_local_heads, head_dim] - key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) - value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) - - # [batch_size, n_local_heads, input_len, head_dim] - q = xq.transpose(1, 2) - # [batch_size, n_local_heads, max_seq_len, head_dim] - k = key.transpose(1, 2) - v = value.transpose(1, 2) - - # [batch_size, n_local_heads, input_len, max_seq_len] - scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling - scores = scores + mask - scores = F.softmax(scores.float(), dim=-1).type_as(q) - - # [batch_size, n_local_heads, input_len, head_dim] - output = torch.matmul(scores, v) - - # [batch_size, input_len, hidden_dim] - output = ( - output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)) - output = self.o_proj(output) - return output - - -class GemmaDecoderLayer(nn.Module): - - def __init__( - self, - config: gemma_config.GemmaConfig, - ): - super().__init__() - self.self_attn = GemmaAttention( - hidden_size=config.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - head_dim=config.head_dim, - quant=config.quant, - ) - self.mlp = GemmaMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant=config.quant, - ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - freqs_cis: torch.Tensor, - kv_write_indices: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - mask: torch.Tensor, - ) -> torch.Tensor: - # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - freqs_cis=freqs_cis, - kv_write_indices=kv_write_indices, - kv_cache=kv_cache, - mask=mask, - ) - hidden_states = residual + hidden_states - - # MLP - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - - -class GemmaModel(nn.Module): - - def __init__(self, config: gemma_config.GemmaConfig): - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - - self.layers = nn.ModuleList() - for _ in range(config.num_hidden_layers): - self.layers.append(GemmaDecoderLayer(config)) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - freqs_cis: torch.Tensor, - kv_write_indices: torch.Tensor, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - mask: torch.Tensor, - ) -> torch.Tensor: - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states = layer( - hidden_states=hidden_states, - freqs_cis=freqs_cis, - kv_write_indices=kv_write_indices, - kv_cache=kv_caches[i], - mask=mask, - ) - hidden_states = self.norm(hidden_states) - return hidden_states - - -class GemmaForCausalLM(nn.Module): - - def __init__( - self, - config: gemma_config.GemmaConfig, - ): - super().__init__() - self.config = config - assert config.hidden_size % config.num_attention_heads == 0 - - max_seq_len = config.max_position_embeddings - head_dim = config.head_dim - vocab_size = config.vocab_size - - self.tokenizer = None #tokenizer.Tokenizer(config.tokenizer) - self.embedder = Embedding(vocab_size, config.hidden_size, config.quant) - self.model = GemmaModel(config) - self.sampler = Sampler(vocab_size) - - # Pre-compute rotary embedding table. - rope_theta = getattr(config, 'rope_theta', 10000) - freqs_cis = precompute_freqs_cis( - head_dim, max_seq_len * 2, theta=rope_theta) - self.register_buffer('freqs_cis', freqs_cis) - - @torch.no_grad() - def forward( - self, - input_token_ids: torch.Tensor, - input_positions: torch.Tensor, - kv_write_indices: torch.Tensor, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - mask: torch.Tensor, - output_positions: torch.Tensor, - temperatures: torch.Tensor, - top_ps: torch.Tensor, - top_ks: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - freqs_cis = self.freqs_cis.index_select(0, input_positions) - kv_write_indices = input_positions - - # [batch_size, input_len, hidden_size] - hidden_states = self.embedder(input_token_ids) - # Gemma normalizes the embedding by sqrt(hidden_size). - hidden_states = hidden_states * (self.config.hidden_size**0.5) - - hidden_states = self.model( - hidden_states=hidden_states, - freqs_cis=freqs_cis, - kv_write_indices=kv_write_indices, - kv_caches=kv_caches, - mask=mask, - ) - embedder_weight = self.embedder.weight - if self.config.quant: - embedder_weight = ( - embedder_weight * self.embedder.weight_scaler.unsqueeze(-1)) - #next_tokens = self.sampler( - return hidden_states - # return next_tokens - - def generate( - self, - prompts: Union[str, Sequence[str]], - device: Any, - output_len: int = 100, - temperature: float = 0.95, - top_p: float = 1.0, - top_k: int = 100, - ) -> Union[str, Sequence[str]]: - """Generates responses for given prompts using Gemma model.""" - # If a single prompt is provided, treat it as a batch of 1. - is_str_prompt = isinstance(prompts, str) - if is_str_prompt: - prompts = [prompts] - - batch_size = len(prompts) - prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts] - min_prompt_len = min(len(p) for p in prompt_tokens) - max_prompt_len = max(len(p) for p in prompt_tokens) - max_seq_len = max_prompt_len + output_len - assert max_seq_len <= self.config.max_position_embeddings - - # build KV caches - kv_caches = [] - for _ in range(self.config.num_hidden_layers): - size = (batch_size, max_seq_len, self.config.num_key_value_heads, - self.config.head_dim) - dtype = self.config.get_dtype() - k_cache = torch.zeros(size=size, dtype=dtype, device=device) - v_cache = torch.zeros(size=size, dtype=dtype, device=device) - kv_caches.append((k_cache, v_cache)) - - # prepare inputs - token_ids_tensor = torch.full((batch_size, max_seq_len), - self.tokenizer.pad_id, - dtype=torch.int64) - input_token_ids_tensor = torch.full((batch_size, min_prompt_len), - self.tokenizer.pad_id, - dtype=torch.int64) - for i, p in enumerate(prompt_tokens): - token_ids_tensor[i, :len(p)] = torch.tensor(p) - input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( - p[:min_prompt_len]) - token_ids_tensor = token_ids_tensor.to(device) - input_token_ids_tensor = input_token_ids_tensor.to(device) - prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id - input_positions_tensor = torch.arange( - 0, min_prompt_len, dtype=torch.int64).to(device) - mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), - -2.3819763e38).to(torch.float) - mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) - curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) - output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) - temperatures_tensor = torch.FloatTensor([temperature] * - batch_size).to(device) - top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) - top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) - output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) - - # Prefill up to min_prompt_len tokens, then treat other prefill as - # decode and ignore output. - for i in range(max_seq_len - min_prompt_len): - next_token_ids = self( - input_token_ids=input_token_ids_tensor, - input_positions=input_positions_tensor, - kv_write_indices=None, - kv_caches=kv_caches, - mask=curr_mask_tensor, - output_positions=output_positions_tensor, - temperatures=temperatures_tensor, - top_ps=top_ps_tensor, - top_ks=top_ks_tensor, - ) - - curr_prompt_mask = prompt_mask_tensor.index_select( - 1, output_index).squeeze(dim=1) - curr_token_ids = token_ids_tensor.index_select( - 1, output_index).squeeze(dim=1) - output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, - next_token_ids).unsqueeze(dim=1) - token_ids_tensor.index_copy_(1, output_index, output_token_ids) - - input_token_ids_tensor = output_token_ids - input_positions_tensor = output_index.unsqueeze(dim=-1) - curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) - output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device) - output_index = output_index + 1 - - # Detokenization. - token_ids = token_ids_tensor.tolist() - results = [] - for i, tokens in enumerate(token_ids): - trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) + - output_len] - if self.tokenizer.eos_id in trimmed_output: - eos_index = trimmed_output.index(self.tokenizer.eos_id) - trimmed_output = trimmed_output[:eos_index] - results.append(self.tokenizer.decode(trimmed_output)) - - # If a string was provided as input, return a string as output. - return results[0] if is_str_prompt else results - - def load_weights(self, model_path: str): - self.load_state_dict( - torch.load( - model_path, - mmap=True, - weights_only=True, - )['model_state_dict'], - strict=False, - ) diff --git a/torchax/test/gemma/test_gemma.py b/torchax/test/gemma/test_gemma.py deleted file mode 100644 index 9160ad7aa7ee..000000000000 --- a/torchax/test/gemma/test_gemma.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import unittest -import torchax -from torch.utils import _pytree as pytree -from . import config -from . import model as gemma - - -class GemmaTest(unittest.TestCase): - - def setup(self): - torch.manual_seed(0) - - def test_gemma(self): - mconfig = config.GemmaConfig( - num_hidden_layers=3, - num_attention_heads=8, - num_key_value_heads=1, - hidden_size=256, - intermediate_size=16384, - dtype=torch.float32) - model = gemma.GemmaForCausalLM(mconfig) - batch_size = 1 - max_seq_len = 1000 - min_prompt_len = 1000 - device = 'cpu' - pad_id = -1 - temperature = 0.8 - top_k = 100 - top_p = 1.0 - - # prepare inputs - token_ids_tensor = torch.randint( - 0, max_seq_len, (batch_size, max_seq_len), dtype=torch.int64) - - # build KV caches - kv_caches = [] - for _ in range(model.config.num_hidden_layers): - size = (batch_size, max_seq_len, model.config.num_key_value_heads, - model.config.head_dim) - dtype = model.config.get_dtype() - k_cache = torch.zeros(size=size, dtype=dtype, device=device) - v_cache = torch.zeros(size=size, dtype=dtype, device=device) - kv_caches.append((k_cache, v_cache)) - - token_ids_tensor = token_ids_tensor.to(device) - prompt_mask_tensor = torch.ones_like(token_ids_tensor) - input_positions_tensor = torch.arange( - 0, min_prompt_len, dtype=torch.int64).to(device) - mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), - -2.3819763e38).to(torch.float) - mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) - curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) - output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) - temperatures_tensor = torch.FloatTensor([temperature] * - batch_size).to(device) - top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) - top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) - output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) - - inputs = ( - token_ids_tensor, - input_positions_tensor, - None, # kv_write_indexes - kv_caches, - mask_tensor, - output_positions_tensor, - temperatures_tensor, - top_ps_tensor, - top_ks_tensor, - ) - - weights, jax_func = torchax.extract_jax(model) - env = torchax.default_env() - inputs_jax = env.t2j_copy(inputs) - - import jax - print(jax.jit(jax_func)(weights, inputs_jax)) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/gemma/tokenizer.py b/torchax/test/gemma/tokenizer.py deleted file mode 100644 index ba4c7f539044..000000000000 --- a/torchax/test/gemma/tokenizer.py +++ /dev/null @@ -1,48 +0,0 @@ -# From: https://github.com/google/gemma_pytorch/blob/main/gemma/tokenizer.py - -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from typing import List, Optional - -from sentencepiece import SentencePieceProcessor - - -class Tokenizer: - - def __init__(self, model_path: Optional[str]): - # Reload tokenizer. - assert os.path.isfile(model_path), model_path - self.sp_model = SentencePieceProcessor(model_file=model_path) - - # BOS / EOS token IDs. - self.n_words: int = self.sp_model.vocab_size() - self.bos_id: int = self.sp_model.bos_id() - self.eos_id: int = self.sp_model.eos_id() - self.pad_id: int = self.sp_model.pad_id() - assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() - - def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: - """Converts a string into a list of tokens.""" - assert isinstance(s, str) - t = self.sp_model.encode(s) - if bos: - t = [self.bos_id] + t - if eos: - t = t + [self.eos_id] - return t - - def decode(self, t: List[int]) -> str: - """Converts a list of tokens into a string.""" - return self.sp_model.decode(t) diff --git a/torchax/test/llama/BUILD b/torchax/test/llama/BUILD deleted file mode 100644 index 5fd0fdf4b966..000000000000 --- a/torchax/test/llama/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -# TODO(hanq): describe this package. -load( - "//third_party/py/torch/google/bazel_rules/rules_python/python:defs.bzl", - "py_test", -) - -package( - default_applicable_licenses = ["//devtools/compliance/licenses:no_external_contributions"], - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -py_test( - name = "test_llama", - srcs = [ - "llama_model.py", - "test_llama.py", - ], - deps = [ - "//third_party/py/jax", - "//third_party/py/torch:pytorch", - "//third_party/py/torch/google/_torx", - "//third_party/py/torch/google/_torx/test:test_base", - ], -) diff --git a/torchax/test/llama/__init__.py b/torchax/test/llama/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/test/llama/llama_model.py b/torchax/test/llama/llama_model.py deleted file mode 100644 index 2aa3566ae0b1..000000000000 --- a/torchax/test/llama/llama_model.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# This file is copied from https://github.com/meta-pytorch/gpt-fast -# This is used for unit test purposes -from dataclasses import dataclass -import math -from typing import Optional - -import torch -from torch import Tensor -import torch.nn as nn -from torch.nn import functional as F - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -@dataclass -class ModelArgs: - block_size: int = 2048 - vocab_size: int = 32000 - n_layer: int = 32 - n_head: int = 32 - dim: int = 4096 - intermediate_size: int = None - n_local_heads: int = -1 - head_dim: int = 64 - rope_base: float = 10000 - norm_eps: float = 1e-5 - - def __post_init__(self): - if self.n_local_heads == -1: - self.n_local_heads = self.n_head - if self.intermediate_size is None: - hidden_dim = 4 * self.dim - n_hidden = int(2 * hidden_dim / 3) - self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head - - @classmethod - def from_name(cls, name: str): - if name in transformer_configs: - return cls(**transformer_configs[name]) - # fuzzy search - config = [ - config for config in transformer_configs - if config in str(name).upper() or config in str(name) - ] - assert len(config) == 1, name - return cls(**transformer_configs[config[0]]) - - -transformer_configs = { - "CodeLlama-7b-Python-hf": - dict( - block_size=16384, - vocab_size=32000, - n_layer=32, - dim=4096, - rope_base=1000000, - ), - "7B": - dict(n_layer=32, n_head=32, dim=4096), - "13B": - dict(n_layer=40, n_head=40, dim=5120), - "30B": - dict(n_layer=60, n_head=52, dim=6656), - "34B": - dict( - n_layer=48, - n_head=64, - dim=8192, - vocab_size=32000, - n_local_heads=8, - intermediate_size=22016, - rope_base=1000000, - ), # CodeLlama-34B-Python-hf - "70B": - dict( - n_layer=80, - n_head=64, - dim=8192, - n_local_heads=8, - intermediate_size=28672, - ), -} - - -class KVCache(nn.Module): - - def __init__( - self, - max_batch_size, - max_seq_length, - n_heads, - head_dim, - dtype=torch.bfloat16, - ): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - - -class Transformer(nn.Module): - - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - - self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) - self.layers = nn.ModuleList( - TransformerBlock(config) for _ in range(config.n_layer)) - self.norm = RMSNorm(config.dim, eps=config.norm_eps) - self.output = nn.Linear(config.dim, config.vocab_size, bias=False) - - self.max_batch_size = -1 - self.max_seq_length = -1 - - def setup_caches(self, max_batch_size, max_seq_length): - if (self.max_seq_length >= max_seq_length and - self.max_batch_size >= max_batch_size): - return - head_dim = self.config.dim // self.config.n_head - max_seq_length = find_multiple(max_seq_length, 8) - self.max_seq_length = max_seq_length - self.max_batch_size = max_batch_size - for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, - self.config.n_local_heads, head_dim) - - freqs_cis = precompute_freqs_cis( - self.config.block_size, - self.config.dim // self.config.n_head, - self.config.rope_base, - ) - self.register_buffer('freqs_cis', freqs_cis) - causal_mask = torch.tril( - torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) - self.register_buffer('causal_mask', causal_mask) - - def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: - assert self.freqs_cis is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] - freqs_cis = self.freqs_cis[input_pos] - x = self.tok_embeddings(idx) - - for i, layer in enumerate(self.layers): - x = layer(x, input_pos, freqs_cis, mask) - x = self.norm(x) - logits = self.output(x) - return logits - - @classmethod - def from_name(cls, name: str): - return cls(ModelArgs.from_name(name)) - - -class TransformerBlock(nn.Module): - - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.attention = Attention(config) - self.feed_forward = FeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) - - def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, - mask: Tensor) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.feed_forward(self.ffn_norm(h)) - return out - - -class Attention(nn.Module): - - def __init__(self, config: ModelArgs): - super().__init__() - assert config.dim % config.n_head == 0 - - total_head_dim = (config.n_head + - 2 * config.n_local_heads) * config.head_dim - # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) - self.wo = nn.Linear(config.dim, config.dim, bias=False) - self.kv_cache = None - - self.n_head = config.n_head - self.head_dim = config.head_dim - self.n_local_heads = config.n_local_heads - self.dim = config.dim - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook(self, state_dict, prefix, *args): - if prefix + "wq.weight" in state_dict: - wq = state_dict.pop(prefix + "wq.weight") - wk = state_dict.pop(prefix + "wk.weight") - wv = state_dict.pop(prefix + "wv.weight") - state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - - def forward( - self, - x: Tensor, - freqs_cis: Tensor, - mask: Tensor, - input_pos: Optional[Tensor] = None, - ) -> Tensor: - bsz, seqlen, _ = x.shape - - kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) - - q = q.view(bsz, seqlen, self.n_head, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) - - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) - - q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - print('q=', q.shape) - print('k=', k.shape) - print('v=', v.shape) - print('mask=', mask.shape) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - y = self.wo(y) - return y - - -class FeedForward(nn.Module): - - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) - self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) - - def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class RMSNorm(nn.Module): - - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor) -> Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(seq_len: int, - n_elem: int, - base: int = 10000) -> Tensor: - freqs = 1.0 / ( - base**(torch.arange(0, n_elem, 2)[:(n_elem // 2)].float() / n_elem)) - t = torch.arange(seq_len, device=freqs.device) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=torch.bfloat16) - - -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * freqs_cis[..., 0] - - xshaped[..., 1] * freqs_cis[..., 1], - xshaped[..., 1] * freqs_cis[..., 0] + - xshaped[..., 0] * freqs_cis[..., 1], - ], - -1, - ) - - x_out2 = x_out2.flatten(3) - return x_out2.type_as(x) diff --git a/torchax/test/llama/model_exportable.py b/torchax/test/llama/model_exportable.py deleted file mode 100644 index a3ab9358e9b5..000000000000 --- a/torchax/test/llama/model_exportable.py +++ /dev/null @@ -1,304 +0,0 @@ -# this version contains modification to make it easier to trace - -import math -from dataclasses import dataclass -from typing import Any, Optional, Tuple, List - -import torch -import torch.nn.functional as F -from torch import nn - - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - max_batch_size: int = 32 - max_seq_len: int = 2048 - bf16_enable = True - - -class RMSNorm(torch.nn.Module): - - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[-3], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - # bs, seqlen, heads, dim - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return (x[:, :, :, - None, :].expand(bs, slen, n_kv_heads, n_rep, - head_dim).reshape(bs, slen, n_kv_heads * n_rep, - head_dim)) - - -class Attention(nn.Module): - - def __init__(self, args: ModelArgs): - super().__init__() - - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.n_local_heads = args.n_heads - self.n_local_kv_heads = self.n_kv_heads - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads - - init_method = lambda x: x - - self.wq = nn.Linear( - args.dim, - args.n_heads * self.head_dim, - bias=False, - ) - self.wk = nn.Linear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - ) - self.wv = nn.Linear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - ) - self.wo = nn.Linear( - args.n_heads * self.head_dim, - args.dim, - bias=False, - ) - - def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], prefill: bool, - input_indexes: torch.Tensor, cache_indexes: torch.Tensor, cache_k, - cache_v): - # bsz, seqlen, _ = x.shape - bsz, seqlen = x.shape[0], x.shape[-2] - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - input_indexes = input_indexes.to(torch.int64) - - cache_k = cache_k.index_copy(1, input_indexes, xk) - cache_v = cache_v.index_copy(1, input_indexes, xv) - - #keys = cache_k.index_select(1, cache_indexes) - #values = cache_v.index_select(1, cache_indexes) - keys = cache_k - values = cache_v - # repeat k/v heads if n_kv_heads < n_heads - keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - values = repeat_kv(values, - self.n_rep) # (bs, seqlen, n_local_heads, head_dim) - - #xq = xq.transpose(-3, -2) # (bs, n_local_heads, seqlen, head_dim) - #keys = keys.transpose(-3,-2) - #values = values.transpose(-3,-2) - xq_new = torch.clone(xq) - #scores = torch.matmul(xq_new, keys.transpose(-3,-2).transpose(-2, -1)) / math.sqrt(self.head_dim) - scores = torch.einsum('ijkl,imkl->ikjm', xq_new, keys) / math.sqrt( - self.head_dim) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - scores = F.softmax(scores.float(), dim=-1).type_as(xq_new) - #output = torch.matmul(scores, values.transpose(-3,-2)) # (bs, n_local_heads, seqlen, head_dim) - output = torch.einsum('ikjm,imkl->ikjl', scores, values) - output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) - - return self.wo(output), cache_k, cache_v - - -class FeedForward(nn.Module): - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - init_method = lambda x: x - - self.w1 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.w2 = nn.Linear( - hidden_dim, - dim, - bias=False, - ) - self.w3 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class TransformerBlock(nn.Module): - - def __init__(self, - layer_id: int, - args: ModelArgs, - groups: Optional[List] = None): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - - self.attention = Attention(args,) - self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - ) - self.layer_id = layer_id - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], - prefill: bool, - input_indexes: torch.Tensor, - cache_indexes, - cache_k, - cache_v, - ): - attn, xk, xv = self.attention.forward( - self.attention_norm(x), freqs_cis, mask, prefill, input_indexes, - cache_indexes, cache_k, cache_v) - h = x + attn - out = h + self.feed_forward.forward(self.ffn_norm(h)) - return out, xk, xv - - -class Transformer(nn.Module): - - def __init__(self, - params: ModelArgs, - world_size: Optional[int] = None, - rank: Optional[int] = None, - groups: Optional[List] = None): - super().__init__() - self.params = params - self.vocab_size = params.vocab_size - self.n_layers = params.n_layers - - init_method = lambda x: x - - self.tok_embeddings = nn.Embedding( - params.vocab_size, - params.dim, - ) - - self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock( - layer_id, - params, - )) - - self.norm = RMSNorm(params.dim, eps=params.norm_eps) - self.output = nn.Linear( - params.dim, - params.vocab_size, - bias=False, - ) - - freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, - self.params.max_seq_len * 2) - # self.register_buffer("freqs_cis", freqs_cis) - self.freqs_cis = freqs_cis - mask = torch.full( - (1, 1, self.params.max_seq_len, self.params.max_seq_len), - float("-inf")).to( - torch.bfloat16 if self.params.bf16_enable else torch.float) - mask = torch.triu(mask, diagonal=1) - self.mask = mask - # self.register_buffer("mask", mask) - - @torch.no_grad() - def forward(self, tokens: torch.Tensor, input_indexes: torch.Tensor, - cache_indexes, caches: List[Tuple[torch.tensor, ...]], prefill): - seqlen = tokens.shape[-1] - h = self.tok_embeddings(tokens) - freqs_cis = self.freqs_cis.index_select(0, input_indexes) - mask = None - if prefill: - mask = torch.full((1, 1, seqlen, seqlen), float("-inf")).to( - torch.bfloat16 if self.params.bf16_enable else torch.float) - mask = torch.triu(mask, diagonal=1) - - new_caches = [] - for layer, (cache_k, cache_v) in zip(self.layers, caches): - h, new_k, new_v = layer(h, freqs_cis, mask, prefill, input_indexes, - cache_indexes, cache_k, cache_v) - new_caches.append((new_k, new_v)) - h = self.norm(h) - output = self.output(h).float() - return output, new_caches diff --git a/torchax/test/llama/test_llama.py b/torchax/test/llama/test_llama.py deleted file mode 100644 index b06f8160b0f4..000000000000 --- a/torchax/test/llama/test_llama.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -from torchax import tensor # pylint: disable=unused-import -import torchax -import torchax.export - -from .. import base_test_util -from . import llama_model -from . import model_exportable -from torch.utils import _pytree as pytree - - -class LlamaTest(base_test_util.TestCase): - - def test_can_run(self): - with torchax.default_env(): - sample_args = ( - torch.randint(0, 32000, (1, 2048), device='jax:0'), - torch.arange(0, 2048, device='jax:0'), - ) - - model_args = llama_model.ModelArgs( - block_size=2048, - vocab_size=32000, - n_layer=2, - n_head=4, - dim=256, - ) - m = llama_model.Transformer(model_args) - m.to(torch.bfloat16) - m.setup_caches(1, 2048) - m = m.to('jax') - - print(m(*sample_args)) - - def test_can_run_exportable(self): - model_args = model_exportable.ModelArgs( - vocab_size=32000, - n_layers=2, - n_heads=4, - dim=256, - ) - m = model_exportable.Transformer(model_args) - context_length = 2048 - input_shape_prefill = (1, context_length) - input_shape_decode = (1, 1) - - def make_cache(args, batch_size): - n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - n_local_heads = args.n_heads - n_local_kv_heads = n_kv_heads - n_rep = n_local_heads // n_local_kv_heads - head_dim = args.dim // args.n_heads - res = [] - for i in range(args.n_layers): - if batch_size is None: - size = ( - args.max_seq_len, - n_local_kv_heads, - head_dim, - ) - else: - size = ( - batch_size, - args.max_seq_len, - n_local_kv_heads, - head_dim, - ) - res.append( - (torch.zeros( - size, - dtype=torch.bfloat16 if args.bf16_enable else torch.float), - torch.zeros( - size, - dtype=torch.bfloat16 if args.bf16_enable else torch.float))) - return res - - prefill_caches = make_cache(model_args, 1) - - sample_input_prefill = ( - torch.randint(0, 1000, input_shape_prefill, - dtype=torch.int32), # len seq length - torch.arange(0, context_length, dtype=torch.int32), # input indexes - torch.arange(0, context_length, dtype=torch.int32), # context indexes - prefill_caches, - True, # prefil - ) - with torch.no_grad(): - m_prefill = torch.export.export(m, sample_input_prefill) - - weights, mj_prefill = torchax.export.exported_program_to_jax(m_prefill) - env = torchax.default_env() - sample_inputs = env.t2j_copy(sample_input_prefill) - print('Prefill', mj_prefill(weights, sample_inputs)) - - sample_input_decode = ( - torch.randint(0, 1000, input_shape_decode, - dtype=torch.int32), # len = 1 - torch.tensor([0], dtype=torch.int32), - torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), - prefill_caches, - False # prefill - ) - with torch.no_grad(): - m_decode = torch.export.export(m, sample_input_decode) - weights, mj_decode = torchax.export.exported_program_to_jax(m_decode) - sample_inputs = env.t2j_copy(sample_input_decode) - print('Decode', mj_decode(weights, sample_inputs)) - - -if __name__ == "__main__": - base_test_util.main() diff --git a/torchax/test/moe/__init__.py b/torchax/test/moe/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torchax/test/moe/model.py b/torchax/test/moe/model.py deleted file mode 100644 index 83189e70aeda..000000000000 --- a/torchax/test/moe/model.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn as nn -from torch import Tensor -from torch.nn import functional as F - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -@dataclass -class ModelArgs: - block_size: int = 2048 - vocab_size: int = 32000 - n_layer: int = 32 - n_head: int = 32 - dim: int = 4096 - intermediate_size: int = None - n_local_heads: int = -1 - head_dim: int = 64 - rope_base: float = 10000 - norm_eps: float = 1e-5 - num_experts: int = 8 - num_activated_experts: int = 2 - - def __post_init__(self): - if self.n_local_heads == -1: - self.n_local_heads = self.n_head - if self.intermediate_size is None: - hidden_dim = 4 * self.dim - n_hidden = int(2 * hidden_dim / 3) - self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head - - @classmethod - def from_name(cls, name: str): - if name in transformer_configs: - return cls(**transformer_configs[name]) - # fuzzy search - config = [ - config for config in transformer_configs - if config in str(name).upper() or config in str(name) - ] - assert len(config) == 1, name - return cls(**transformer_configs[config[0]]) - - -transformer_configs = { - "Mixtral-8x7B-v0.1": - dict( - block_size=32768, - n_layer=32, - n_head=32, - n_local_heads=8, - dim=4096, - intermediate_size=14336, - rope_base=1000000.0, - num_experts=8, - num_activated_experts=2), -} - - -class KVCache(nn.Module): - - def __init__(self, - max_batch_size, - max_seq_length, - n_heads, - head_dim, - dtype=torch.bfloat16): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - - -class Transformer(nn.Module): - - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - - self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) - self.layers = nn.ModuleList( - TransformerBlock(config) for _ in range(config.n_layer)) - self.norm = RMSNorm(config.dim, eps=config.norm_eps) - self.output = nn.Linear(config.dim, config.vocab_size, bias=False) - - self.freqs_cis: Optional[Tensor] = None - self.mask_cache: Optional[Tensor] = None - self.max_batch_size = -1 - self.max_seq_length = -1 - - def setup_caches(self, max_batch_size, max_seq_length): - if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: - return - head_dim = self.config.dim // self.config.n_head - max_seq_length = find_multiple(max_seq_length, 8) - self.max_seq_length = max_seq_length - self.max_batch_size = max_batch_size - for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, - self.config.n_local_heads, head_dim) - - self.freqs_cis = precompute_freqs_cis(self.config.block_size, - self.config.dim // self.config.n_head, - self.config.rope_base) - self.causal_mask = torch.tril( - torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) - - def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: - assert self.freqs_cis is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] - freqs_cis = self.freqs_cis[input_pos] - x = self.tok_embeddings(idx) - - for i, layer in enumerate(self.layers): - x = layer(x, input_pos, freqs_cis, mask) - x = self.norm(x) - logits = self.output(x) - return logits - - @classmethod - def from_name(cls, name: str): - return cls(ModelArgs.from_name(name)) - - -class TransformerBlock(nn.Module): - - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.attention = Attention(config) - self.block_sparse_moe = MOEFeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) - - def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, - mask: Tensor) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.block_sparse_moe(self.ffn_norm(h)) - return out - - -class Attention(nn.Module): - - def __init__(self, config: ModelArgs): - super().__init__() - assert config.dim % config.n_head == 0 - - total_head_dim = (config.n_head + - 2 * config.n_local_heads) * config.head_dim - # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) - self.wo = nn.Linear(config.dim, config.dim, bias=False) - self.kv_cache = None - - self.n_head = config.n_head - self.head_dim = config.head_dim - self.n_local_heads = config.n_local_heads - self.dim = config.dim - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook(self, state_dict, prefix, *args): - if prefix + "wq.weight" in state_dict: - wq = state_dict.pop(prefix + "wq.weight") - wk = state_dict.pop(prefix + "wk.weight") - wv = state_dict.pop(prefix + "wv.weight") - state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - - def forward(self, - x: Tensor, - freqs_cis: Tensor, - mask: Tensor, - input_pos: Optional[Tensor] = None) -> Tensor: - bsz, seqlen, _ = x.shape - - kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) - - q = q.view(bsz, seqlen, self.n_head, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) - - q = apply_rotary_emb(q, freqs_cis) - k = apply_rotary_emb(k, freqs_cis) - - q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - y = self.wo(y) - return y - - -class ConditionalFeedForward(nn.Module): - - def __init__(self, config): - super().__init__() - self.w1 = nn.Parameter( - torch.empty(config.num_experts, config.intermediate_size, config.dim)) - self.w2 = nn.Parameter( - torch.empty(config.num_experts, config.dim, config.intermediate_size)) - self.w3 = nn.Parameter( - torch.empty(config.num_experts, config.intermediate_size, config.dim)) - - def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: - w1_weights = self.w1[expert_indices] # [T, A, D, D] - w3_weights = self.w3[expert_indices] # [T, A, D, D] - w2_weights = self.w2[expert_indices] # [T, A, D, D] - x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) - x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) - expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) - return expert_outs - - -class MOEFeedForward(nn.Module): - - def __init__(self, config) -> None: - super().__init__() - self.gate = nn.Linear(config.dim, config.num_experts, bias=False) - self.cond_ffn = ConditionalFeedForward(config) - self.dim = config.dim - self.num_activated_experts = config.num_activated_experts - - def forward(self, x: Tensor) -> Tensor: - x = x.view(-1, self.dim) - # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts - # x: [T, D] - scores = self.gate(x) # [T, E] - expert_weights = F.softmax(scores, dim=-1) - expert_weights, expert_indices = torch.topk( - expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] - expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] - expert_outs = self.cond_ffn(x, expert_indices) - return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) - - -class RMSNorm(nn.Module): - - def __init__(self, dim: int, eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) - - def forward(self, x: Tensor) -> Tensor: - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -def precompute_freqs_cis(seq_len: int, - n_elem: int, - base: int = 10000) -> Tensor: - freqs = 1.0 / ( - base**(torch.arange(0, n_elem, 2)[:(n_elem // 2)].float() / n_elem)) - t = torch.arange(seq_len, device=freqs.device) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=torch.bfloat16) - - -def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: - xshaped = x.float().reshape(*x.shape[:-1], -1, 2) - freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * freqs_cis[..., 0] - - xshaped[..., 1] * freqs_cis[..., 1], - xshaped[..., 1] * freqs_cis[..., 0] + - xshaped[..., 0] * freqs_cis[..., 1], - ], - -1, - ) - - x_out2 = x_out2.flatten(3) - return x_out2.type_as(x) diff --git a/torchax/test/moe/moe_test.py b/torchax/test/moe/moe_test.py deleted file mode 100644 index ca841f084a58..000000000000 --- a/torchax/test/moe/moe_test.py +++ /dev/null @@ -1,68 +0,0 @@ -import torchax -import torchax.interop -import torch -import unittest -import jax - -from test.moe import model - - -class TestMoe(unittest.TestCase): - - def _make_tiny_config(self): - return model.ModelArgs( - block_size=128, - vocab_size=32000, - n_layer=4, - n_head=4, - dim=128, - intermediate_size=None, - n_local_heads=-1, - head_dim=32, - rope_base=10000, - norm_eps=1e-5, - num_experts=8, - num_activated_experts=2, - ) - - def _random_init(self, model): - new_state_dict = {} - - for k, v in model.state_dict().items(): - new_state_dict[k] = torch.randn_like(v) - - model.load_state_dict(new_state_dict, assign=True) - return model - - def test_moe_layer(self): - model_args = self._make_tiny_config() - - moe_layer = model.MOEFeedForward(model_args) - moe_layer = self._random_init(moe_layer) - seqlen = 32 - x = torch.randn((seqlen, model_args.dim)) - res = moe_layer(x) - - env = torchax.default_env() - model_xla = env.to_xla(moe_layer) - x_xla = env.to_xla(x) - with jax.default_matmul_precision('float32'): - res_xla = model_xla(x_xla) - res2 = res_xla.to('cpu') - print('max diff', torch.max((res - res2).abs())) - - self.assertTrue(torch.allclose(res2, res, atol=1e-2)) - - # test can jit - - def f(weights, x): - return torch.func.functional_call(moe_layer, weights, (x,)) - - fjitted = torchax.interop.jax_jit(f) - weights_xla = env.to_xla(moe_layer.state_dict()) - - print(fjitted(weights_xla, x_xla)) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_amp.py b/torchax/test/test_amp.py deleted file mode 100644 index 608e5deb92a1..000000000000 --- a/torchax/test/test_amp.py +++ /dev/null @@ -1,38 +0,0 @@ -import unittest -import jax -import jax.numpy as jnp -import torchax -from torchax import interop -import torch - - -class AutocastTest(unittest.TestCase): - - def setUp(self): - self.env = torchax.default_env() - - def test_auto_cast_ir(self): - with self.env: - with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=self.env): - a = jax.ShapeDtypeStruct((2, 2), jnp.float32) - b = jax.ShapeDtypeStruct((2, 2), jnp.float32) - ir_text = jax.jit(interop.jax_view(torch.matmul)).lower(a, b).as_text() - self.assertIn('tensor<2x2xbf16>', ir_text) - - def test_auto_cast_matmul(self): - with self.env: - a = torch.randn(2, 2, device='jax') - b = torch.randn(2, 2, device='jax') - with torchax.amp.autocast('jax', dtype=torch.bfloat16, env=self.env): - c = a @ b - - self.assertEqual(c.dtype, torch.bfloat16) - - with torch.autocast('cpu', dtype=torch.bfloat16): - c_cpu = a.cpu() @ b.cpu() - - self.assertTrue(torch.allclose(c.cpu(), c_cpu)) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_checkpoint.py b/torchax/test/test_checkpoint.py deleted file mode 100644 index 4867d44b1eb8..000000000000 --- a/torchax/test/test_checkpoint.py +++ /dev/null @@ -1,102 +0,0 @@ -import unittest -import torch -import torch.nn as nn -import torchax -from torchax.checkpoint import _to_torch, _to_jax -import optax -import tempfile -import os -import jax -import jax.numpy as jnp -import shutil - - -class CheckpointTest(unittest.TestCase): - - def test_save_and_load_jax_style_checkpoint(self): - model = torch.nn.Linear(10, 20) - optimizer = optax.adam(1e-3) - - torchax.enable_globally() - params_jax, _ = torchax.extract_jax(model) - opt_state = optimizer.init(params_jax) - torchax.disable_globally() - - epoch = 1 - state = { - 'model': model.state_dict(), - 'opt_state': opt_state, - 'epoch': epoch, - } - - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, 'checkpoint') - torchax.save_checkpoint(state, path, step=epoch) - loaded_state_jax = torchax.load_checkpoint(path) - loaded_state = _to_torch(loaded_state_jax) - - self.assertEqual(state['epoch'], loaded_state['epoch']) - - # Compare model state_dict - for key in state['model']: - self.assertTrue( - torch.allclose(state['model'][key], loaded_state['model'][key])) - - # Compare optimizer state - original_leaves = jax.tree_util.tree_leaves(state['opt_state']) - loaded_leaves = jax.tree_util.tree_leaves(loaded_state['opt_state']) - for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves): - if isinstance(original_leaf, (jnp.ndarray, jax.Array)): - # Convert loaded leaf to numpy array for comparison if it is a DeviceArray - self.assertTrue(jnp.allclose(original_leaf, jnp.asarray(loaded_leaf))) - else: - self.assertEqual(original_leaf, loaded_leaf) - - def test_load_pytorch_style_checkpoint(self): - model = torch.nn.Linear(10, 20) - optimizer = optax.adam(1e-3) - - torchax.enable_globally() - params_jax, _ = torchax.extract_jax(model) - opt_state = optimizer.init(params_jax) - torchax.disable_globally() - - epoch = 1 - state = { - 'model': model.state_dict(), - 'opt_state': opt_state, - 'epoch': epoch, - } - - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, 'checkpoint.pt') - torch.save(state, path) - loaded_state_jax = torchax.load_checkpoint(path) - - # convert original state to jax for comparison - state_jax = _to_jax(state) - - self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch']) - - # Compare model state_dict - for key in state_jax['model']: - self.assertTrue( - jnp.allclose(state_jax['model'][key], - loaded_state_jax['model'][key])) - - # Compare optimizer state - original_leaves = jax.tree_util.tree_leaves(state_jax['opt_state']) - loaded_leaves = jax.tree_util.tree_leaves(loaded_state_jax['opt_state']) - for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves): - if isinstance(original_leaf, (jnp.ndarray, jax.Array)): - self.assertTrue(jnp.allclose(original_leaf, loaded_leaf)) - else: - self.assertEqual(original_leaf, loaded_leaf) - - def test_load_non_existent_checkpoint(self): - with self.assertRaises(FileNotFoundError): - torchax.load_checkpoint('/path/to/non_existent_checkpoint') - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_context.py b/torchax/test/test_context.py deleted file mode 100644 index ace28eeb4265..000000000000 --- a/torchax/test/test_context.py +++ /dev/null @@ -1,107 +0,0 @@ -import unittest - -import torch -import torchax -from torchax import tensor -import torchax.interop - -xla_env = torchax.default_env() - - -class TestContext(unittest.TestCase): - - def test_mode_context_manager(self): - with xla_env: - x = torch.full((3, 3), -1, device='jax') - self.assertIsInstance(x, tensor.Tensor) - y = x.abs() - self.assertIsInstance(y, tensor.Tensor) - - @staticmethod - @xla_env - def _test_mode_decorator(): - x = torch.full((3, 3), -1).to('jax') - y = x.abs() - - return x, y - - def test_mode_decorator(self): - x, y = self._test_mode_decorator() - self.assertIsInstance(x, tensor.Tensor) - self.assertIsInstance(y, tensor.Tensor) - - def test_same_manual_seed(self): - with xla_env: - xla_env.manual_seed(1234) - x = torch.randn((3, 3), device='jax') - self.assertIsInstance(x, tensor.Tensor) - - xla_env.manual_seed(1234) - y = torch.randn((3, 3), device='jax') - self.assertIsInstance(y, tensor.Tensor) - - self.assertTrue(torch.allclose(x, y)) - - def test_different_manual_seed(self): - with xla_env: - xla_env.manual_seed(1234) - x = torch.randn((3, 3), device='jax') - self.assertIsInstance(x, tensor.Tensor) - - xla_env.manual_seed(12345) - y = torch.randn((3, 3), device='jax') - self.assertIsInstance(y, tensor.Tensor) - - self.assertFalse(torch.allclose(x, y)) - - def test_jit_with_rng(self): - - with xla_env: - - def random_op(): - x = torch.randn(3, 3, device='jax') - y = torch.randn(3, 3, device='jax') - return x @ y - - random_jit = torchax.interop.jax_jit(random_op) - self.assertIsInstance(random_jit(), tensor.Tensor) - - # If we run the JIT twice, the random values should be different. - # TODO(qihqi): think about API for passing down seed - # with self.assertRaises(AssertionError): - # torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0) - - def test_generator_seed(self): - with xla_env: - x = torch.randn( - 2, 3, generator=torch.Generator().manual_seed(0), device='jax') - y = torch.randn( - 2, 3, generator=torch.Generator().manual_seed(0), device='jax') - - # Values will be the same given the same seed. - torch.testing.assert_close(x, y) - - def test_buffer(self): - - class M(torch.nn.Module): - - def __init__(self): - super().__init__() - c = torch.rand(2) - self.register_buffer('c', c) - self.register_buffer('c2', c, persistent=False) - - # Test context manager. - with xla_env: - m = M().to('jax') - self.assertIsInstance(m.c, tensor.Tensor) - self.assertIsInstance(m.c2, tensor.Tensor) - # Test `to_xla`. - m = M() - m = xla_env.to_xla(m) - self.assertIsInstance(m.c, tensor.Tensor) - self.assertIsInstance(m.c2, tensor.Tensor) - - -if __name__ == "__main__": - unittest.main() diff --git a/torchax/test/test_conv.py b/torchax/test/test_conv.py deleted file mode 100644 index 2e3f2b9156fa..000000000000 --- a/torchax/test/test_conv.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from torch import nn -import torchax -from . import base_test_util - - -class CustomConv1(torch.nn.Module): - - def __init__( - self, - channels_conv1=3, - width_conv1=3, - channels_conv2=5, - width_conv2=5, - hidden_layer_size=50, - ): - super(CustomConv1, self).__init__() - self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1) - self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2) - self.fc1 = nn.Linear(hidden_layer_size, 2) - - def forward(self, x): - x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2) - x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2) - x = torch.flatten(x, 1) - x = nn.functional.softmax(self.fc1(x), dim=1) - return x - - -class CustomConv2(nn.Module): - - def __init__(self): - super().__init__() - inp = 4 - out = 16 - - self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1) - - # This is supposed to be a squeeze and excitation block. - self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) - - self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid()) - - def forward(self, x): - x = self.conv(x) - - b = x.shape[0] - ap = self.avg_pool(x).view(b, -1) - ap = self.scale(ap) - ap = ap.view(b, -1, 1, 1) - - return x * ap - - -class ConvTest(base_test_util.TestCase): - - def test_conv1(self): - env = torchax.default_env() - m = CustomConv1() - arg = torch.randn((20, 1, 50)) - res = m(arg) - - jax_weights, jax_func = torchax.extract_jax(m) - arg = env.t2j_copy(arg) - res2 = jax_func(jax_weights, (arg,)) - res2_torch = env.j2t_copy(res2) - self.assertTrue(torch.allclose(res, res2_torch)) - - def test_conv2(self): - env = torchax.default_env() - m = CustomConv2() - arg = torch.randn((20, 4, 50, 100)) - res = m(arg) - jax_weights, jax_func = torchax.extract_jax(m) - arg = env.t2j_copy(arg) - res2 = jax_func(jax_weights, (arg,)) - res2_torch = env.j2t_copy(res2) - self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4)) - - -if __name__ == '__main__': - base_test_util.main() diff --git a/torchax/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py deleted file mode 100644 index 7a24c8bac408..000000000000 --- a/torchax/test/test_core_aten_ops.py +++ /dev/null @@ -1,4534 +0,0 @@ -import math -import unittest - -import torch -from torchax import tensor - -from . import base_test_util -from torch.utils import _pytree as pytree - - -def diff_output(testcase, - output1, - output2, - rtol, - atol, - equal_nan=True, - check_dtype=False): - if isinstance(output1, torch.Tensor): - testcase.assertIsInstance(output2, torch.Tensor) - output2_cpu = output2.detach().cpu() - torch.testing.assert_close( - output1, - output2_cpu, - atol=atol, - rtol=rtol, - equal_nan=equal_nan, - check_dtype=check_dtype) - elif isinstance(output1, (tuple, list)): - testcase.assertIsInstance(output2, (tuple, list)) - testcase.assertEqual(len(output1), len(output2)) - for o1, o2 in zip(output1, output2): - diff_output( - testcase, - o1, - o2, - rtol, - atol, - equal_nan=equal_nan, - check_dtype=check_dtype) - else: - testcase.assertEqual(output1, output2) - - -def run_export_and_compare(testcase, - func, - args, - kwargs, - atol=1e-3, - rtol=1e-5, - equal_nan=True, - check_dtype=False, - ignore_indices=False): - - with testcase.subTest("torch_eval"): - res = func(*args, **kwargs) - with testcase.subTest("torchax_eval"): - args2, kwargs2 = testcase.env.to_xla((args, kwargs)) - with testcase.env: - res2 = func(*args2, **kwargs2) - res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) - # import pdb; pdb.set_trace() - with testcase.subTest("torchax_diff:" + str(atol)): - if ignore_indices and isinstance(res, tuple) and len(res) == 2: - diff_output( - testcase, - res[0], - res2[0], - atol=atol, - rtol=rtol, - equal_nan=equal_nan, - check_dtype=check_dtype) - else: - diff_output( - testcase, - res, - res2, - atol=atol, - rtol=rtol, - equal_nan=equal_nan, - check_dtype=check_dtype) - - -class TestCoreAtenOps(unittest.TestCase): - - @classmethod - def setUpClass(cls): - super().setUpClass() - - def setUp(self): - super().setUp() - torch.manual_seed(0) - self.env = tensor.Environment() - - def test_aten_abs_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.abs, args, kwargs) - - def test_aten_abs_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.abs, args, kwargs) - - def test_aten_abs_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.abs, args, kwargs) - - def test_aten_acos_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.acos, args, kwargs) - - def test_aten_acos_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.acos, args, kwargs, atol=0.005) - - def test_aten_acos_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.acos, args, kwargs) - - def test_aten_acosh_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) - - def test_aten_acosh_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) - - def test_aten_acosh_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.acosh, args, kwargs) - - def test_aten_unsqueeze_0(self): - args = ( - torch.randn((1, 3, 10)).to(torch.float32), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_1(self): - args = ( - torch.randn((1, 3, 10)).to(torch.float16), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_2(self): - args = ( - torch.randint(0, 10, (1, 3, 10)).to(torch.int32), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_3(self): - args = ( - torch.randn((1, 3, 10)).to(torch.float32), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_4(self): - args = ( - torch.randn((1, 3, 10)).to(torch.float16), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_5(self): - args = ( - torch.randint(0, 10, (1, 3, 10)).to(torch.int32), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_6(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_7(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten_unsqueeze_8(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze, args, kwargs) - - def test_aten__adaptive_avg_pool2d_0(self): - args = ( - torch.randn((1, 3, 1, 10)).to(torch.float32), - [ - 1, - 5, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool2d, args, - kwargs) - - def test_aten__adaptive_avg_pool2d_1(self): - args = ( - torch.randn((1, 3, 10, 10)).to(torch.float32), - [ - 5, - 5, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool2d, args, - kwargs) - - def test_aten_avg_pool2d_2(self): - args = ( - torch.randn((1, 3, 6, 6)).to(torch.float32), - [3, 3], - [1, 1], - [1, 1], - True, - True, - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) - - def test_aten_squeeze_dim_0(self): - args = ( - torch.randn((1, 3, 1, 5)).to(torch.float32), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - - def test_aten_squeeze_dim_1(self): - args = ( - torch.randn((1, 3, 1, 5)).to(torch.float32), - -2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - - def test_aten_squeeze_dim_2(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - - def test_aten_squeeze_dim_3(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - - def test_aten_squeeze_dim_4(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dim, args, kwargs) - - def test_aten__adaptive_avg_pool3d_0(self): - args = ( - torch.randn((1, 3, 10, 10, 10)).to(torch.float32), - [ - 5, - 5, - 5, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool3d, args, - kwargs) - - def test_aten__adaptive_avg_pool3d_1(self): - args = ( - torch.randn((1, 3, 10, 10, 10)).to(torch.float16), - [ - 5, - 5, - 5, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool3d, args, - kwargs) - - def test_aten_add_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.add.Scalar, args, kwargs) - - def test_aten_add_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.add.Scalar, args, kwargs) - - def test_aten_add_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.add.Scalar, args, kwargs) - - def test_aten_add_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) - - def test_aten_add_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) - - def test_aten_add_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.add.Tensor, args, kwargs) - - def test_aten_addmm_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) - - def test_aten_addmm_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.addmm, args, kwargs, atol=0.001, rtol=0.001) - - def test_aten_addmm_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.addmm, args, kwargs) - - def test_aten_alias_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.alias, args, kwargs) - - def test_aten_alias_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.alias, args, kwargs) - - def test_aten_alias_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.alias, args, kwargs) - - def test_aten_amax_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.amax, args, kwargs) - - def test_aten_amax_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.amax, args, kwargs) - - def test_aten_amax_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.amax, args, kwargs) - - def test_aten_amin_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.amin, args, kwargs) - - def test_aten_amin_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.amin, args, kwargs) - - def test_aten_amin_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.amin, args, kwargs) - - def test_aten_any_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any, args, kwargs) - - def test_aten_any_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any, args, kwargs) - - def test_aten_any_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any, args, kwargs) - - def test_aten_any_dim_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) - - def test_aten_any_dim_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) - - def test_aten_any_dim_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) - - def test_aten_any_dims_0(self): - args = (torch.randn((10, 10)).to(torch.float32), 0) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) - - def test_aten_any_dims_1(self): - args = (torch.randn((10, 10)).to(torch.float16), 0) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) - - def test_aten_any_dims_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32), 0) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.any.dim, args, kwargs) - - def test_aten_argmax_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.argmax, args, kwargs) - - def test_aten_argmax_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.argmax, args, kwargs) - - def test_aten_argmax_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.argmax, args, kwargs) - - def test_aten_argmin_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) - - def test_aten_argmin_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) - - def test_aten_argmin_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.argmin, args, kwargs) - - def test_aten_as_strided_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [2, 2, 2], - [ - 8, - 4, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) - - def test_aten_as_strided_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) - - def test_aten_as_strided_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.as_strided, args, kwargs) - - def test_aten_as_strided_copy_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 5, - 5, - ], - [ - 2, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) - - def test_aten_as_strided_copy_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 5, - 5, - ], - [ - 2, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) - - def test_aten_as_strided_copy_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 5, - 5, - ], - [ - 2, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.as_strided_copy, args, kwargs) - - def test_aten_asin_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.asin, args, kwargs) - - def test_aten_asin_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.asin, args, kwargs) - - def test_aten_asin_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.asin, args, kwargs) - - def test_aten_asinh_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs) - - def test_aten_asinh_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs) - - def test_aten_asinh_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.asinh, args, kwargs) - - def test_aten_atan_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) - - def test_aten_atan_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) - - def test_aten_atan_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atan, args, kwargs) - - def test_aten_atan2_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atan2, args, kwargs) - - def test_aten_atan2_1(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atan2, args, kwargs) - - def test_aten_atanh_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) - - def test_aten_atanh_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) - - def test_aten_atanh_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.atanh, args, kwargs) - - def test_aten_avg_pool2d_0(self): - args = ( - torch.randn((1, 3, 1, 10)).to(torch.float32), - [ - 1, - 2, - ], - [ - 1, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) - - def test_aten_avg_pool2d_1(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float32), - [ - 2, - 2, - ], - [ - 1, - 1, - ], - [ - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs) - - def test_aten_avg_pool3d_0(self): - args = ( - torch.randn((1, 3, 10, 10, 10)).to(torch.float32), - [ - 2, - 2, - 2, - ], - [ - 2, - 2, - 2, - ], - [ - 0, - 0, - 0, - ], - False, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.avg_pool3d, args, kwargs) - - def test_aten_bitwise_and_Scalar_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_and.Scalar, args, - kwargs) - - def test_aten_bitwise_and_Tensor_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_and.Tensor, args, - kwargs) - - def test_aten_bitwise_and_Tensor_1(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_and.Tensor, args, - kwargs) - - def test_aten_bitwise_and_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_and.Tensor, args, - kwargs) - - def test_aten_bitwise_and_Tensor_3(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_and.Tensor, args, - kwargs) - - def test_aten_bitwise_or_Scalar_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_or.Scalar, args, kwargs) - - def test_aten_bitwise_xor_Scalar_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bitwise_xor.Scalar, args, - kwargs) - - def test_aten_bmm_0(self): - args = ( - torch.randn((10, 10, 10)).to(torch.float32), - torch.randn((10, 10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) - - def test_aten_bmm_1(self): - args = ( - torch.randn((10, 10, 10)).to(torch.float16), - torch.randn((10, 10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) - - def test_aten_bmm_2(self): - args = ( - torch.randint(0, 10, (10, 10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.bmm, args, kwargs) - - def test_aten_cat_0(self): - args = ( - [ - torch.randn((10, 10)).to(torch.float32), - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) - - def test_aten_cat_1(self): - args = ( - [ - torch.randn((10, 10)).to(torch.float32), - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) - - def test_aten_cat_2(self): - args = ( - [ - torch.randn((10, 10)).to(torch.float32), - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cat, args, kwargs) - - def test_aten__cdist_forward_0(self): - args = ( - torch.randn((5, 7, 10)).to(torch.float32), - torch.randn((5, 8, 10)).to(torch.float32), - 1.0, - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._cdist_forward, args, kwargs) - - def test_aten_ceil_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ceil, args, kwargs) - - def test_aten_ceil_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ceil, args, kwargs) - - def test_aten_ceil_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ceil, args, kwargs) - - def test_aten_clamp_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clamp, args, kwargs) - - def test_aten_clamp_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clamp, args, kwargs) - - def test_aten_clamp_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clamp, args, kwargs) - - def test_aten_clamp_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((1,)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clamp.Tensor, args, kwargs) - - def test_aten_clamp_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((1,)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clamp.Tensor, args, kwargs) - - def test_aten_clamp_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (1,)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clamp.Tensor, args, kwargs) - - def test_aten_clone_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clone, args, kwargs) - - def test_aten_clone_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clone, args, kwargs) - - def test_aten_clone_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.clone, args, kwargs) - - def test_aten_constant_pad_nd_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.constant_pad_nd, args, kwargs) - - def test_aten_constant_pad_nd_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.constant_pad_nd, args, kwargs) - - def test_aten_constant_pad_nd_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.constant_pad_nd, args, kwargs) - - def test_aten_convolution_0(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float32), - torch.randn((2, 2, 2)).to(torch.float32), - None, - [ - 2, - ], - [ - 0, - ], - [ - 1, - ], - False, - [ - 0, - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) - - def test_aten_convolution_1(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float16), - torch.randn((2, 2, 2)).to(torch.float16), - None, - [ - 2, - ], - [ - 0, - ], - [ - 1, - ], - False, - [ - 0, - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) - - def test_aten_convolution_2(self): - args = ( - torch.randint(0, 10, (3, 2, 10)).to(torch.int32), - torch.randint(0, 10, (2, 2, 2)).to(torch.int32), - None, - [ - 2, - ], - [ - 0, - ], - [ - 1, - ], - False, - [ - 0, - ], - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs) - - def test_aten_copy_0(self): - args = (torch.randn((10, 10)).to(torch.float32), torch.randn( - (10, 10)).to(torch.float32)) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.copy, args, kwargs) - - def test_aten_copy_broadcast(self): - args = (torch.randn( - (10, 10)).to(torch.float32), torch.tensor(1.0, dtype=torch.float32)) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.copy, args, kwargs) - - def test_aten_copy_cast_dtype(self): - args = (torch.randn((10, 10)).to(torch.float32), torch.randn( - (10, 10)).to(torch.int64)) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.copy, args, kwargs) - - def test_aten_cos_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cos, args, kwargs) - - def test_aten_cos_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cos, args, kwargs) - - def test_aten_cos_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cos, args, kwargs) - - def test_aten_cosh_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) - - def test_aten_cosh_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) - - def test_aten_cosh_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cosh, args, kwargs) - - def test_aten_cumsum_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) - - def test_aten_cumsum_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.cumsum, args, kwargs, atol=1e-2, rtol=1e-3) - - def test_aten_cumsum_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) - - def test_aten_cumsum_3(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict(dtype=torch.float32) - run_export_and_compare(self, torch.ops.aten.cumsum, args, kwargs) - - def test_aten_diagonal_0(self): - args = (torch.randn((10, 20)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) - - def test_aten_diagonal_1(self): - args = (torch.randn((10, 20)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) - - def test_aten_diagonal_2(self): - args = (torch.randint(0, 10, (10, 20)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.diagonal, args, kwargs) - - def test_aten_div_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs) - - def test_aten_div_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs) - - def test_aten_div_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.div.Scalar, args, kwargs) - - def test_aten_div_Scalar_mode_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = { - "rounding_mode": "trunc", - } - run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) - - def test_aten_div_Scalar_mode_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = { - "rounding_mode": "trunc", - } - run_export_and_compare( - self, torch.ops.aten.div.Scalar_mode, args, kwargs, rtol=0.1) - - def test_aten_div_Scalar_mode_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = { - "rounding_mode": "trunc", - } - run_export_and_compare(self, torch.ops.aten.div.Scalar_mode, args, kwargs) - - def test_aten_div_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) - - def test_aten_div_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) - - def test_aten_div_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.div.Tensor, args, kwargs) - - def test_aten_div_Tensor_mode_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = { - "rounding_mode": "trunc", - } - run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs) - - def test_aten_div_Tensor_mode_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = { - "rounding_mode": "trunc", - } - run_export_and_compare(self, torch.ops.aten.div.Tensor_mode, args, kwargs) - - def test_aten_embedding_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randint(0, 10, (10,)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.embedding, args, kwargs) - - def test_aten_embedding_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randint(0, 10, (10,)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.embedding, args, kwargs) - - def test_aten_embedding_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10,)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.embedding, args, kwargs) - - def test_aten_eq_Scalar_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.eq.Scalar, args, kwargs) - - def test_aten_eq_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.eq.Scalar, args, kwargs) - - def test_aten_eq_Scalar_2(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.eq.Scalar, args, kwargs) - - def test_aten_eq_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) - - def test_aten_eq_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) - - def test_aten_eq_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.eq.Tensor, args, kwargs) - - def test_aten_erf_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) - - def test_aten_erf_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) - - def test_aten_erf_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.erf, args, kwargs) - - def test_aten_exp_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.exp, args, kwargs) - - def test_aten_exp_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.exp, args, kwargs) - - def test_aten_exp_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.exp, args, kwargs) - - def test_aten_expand_0(self): - args = ( - torch.randn((10, 1)).to(torch.float32), - [ - 10, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) - - def test_aten_expand_1(self): - args = ( - torch.randn((10, 1)).to(torch.float16), - [ - 10, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) - - def test_aten_expand_2(self): - args = ( - torch.randint(0, 10, (10, 1)).to(torch.int32), - [ - 10, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expand, args, kwargs) - - def test_aten_expand_copy_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 10, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) - - def test_aten_expand_copy_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 10, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) - - def test_aten_expand_copy_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 10, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expand_copy, args, kwargs) - - def test_aten_expm1_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) - - def test_aten_expm1_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs, rtol=1e-3) - - def test_aten_expm1_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) - - def test_aten_fill_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) - - def test_aten_fill_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) - - def test_aten_fill_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fill.Scalar, args, kwargs) - - def test_aten_flip_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.flip, args, kwargs) - - def test_aten_flip_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.flip, args, kwargs) - - def test_aten_flip_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.flip, args, kwargs) - - def test_aten_floor_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) - - def test_aten_floor_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) - - def test_aten_floor_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.floor, args, kwargs) - - def test_aten_fmod_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) - - def test_aten_fmod_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.fmod.Scalar, args, kwargs, rtol=0.1, atol=0.2) - - def test_aten_fmod_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fmod.Scalar, args, kwargs) - - def test_aten_fmod_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fmod.Tensor, args, kwargs) - - def test_aten_fmod_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.fmod.Tensor, args, kwargs) - - def test_aten_full_like_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) - - def test_aten_full_like_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) - - def test_aten_full_like_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.full_like, args, kwargs) - - def test_aten_gather_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - torch.randint(0, 10, (2, 2)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) - - def test_aten_gather_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - torch.randint(0, 10, (2, 2)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) - - def test_aten_gather_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (2, 2)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gather, args, kwargs) - - def test_aten_ge_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) - - def test_aten_ge_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) - - def test_aten_ge_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ge.Scalar, args, kwargs) - - def test_aten_ge_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) - - def test_aten_ge_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) - - def test_aten_ge_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ge.Tensor, args, kwargs) - - def test_aten_gelu_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs) - - def test_aten_gelu_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.gelu, args, kwargs, atol=0.01, rtol=0.01) - - def test_aten_glu_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) - - def test_aten_glu_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) - - def test_aten_grid_sampler_2d_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - torch.randn((1, 2, 2, 2)).to(torch.float32), - 0, - 0, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.grid_sampler_2d, args, kwargs) - - def test_aten_gt_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gt.Scalar, args, kwargs) - - def test_aten_gt_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gt.Scalar, args, kwargs) - - def test_aten_gt_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gt.Scalar, args, kwargs) - - def test_aten_gt_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gt.Tensor, args, kwargs) - - def test_aten_gt_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gt.Tensor, args, kwargs) - - def test_aten_gt_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.gt.Tensor, args, kwargs) - - def test_aten_hardtanh_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.hardtanh, args, kwargs) - - def test_aten_hardtanh_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.hardtanh, args, kwargs) - - def test_aten_hardtanh_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.hardtanh, args, kwargs) - - def test_aten_index_put_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - torch.randint(0, 10, (1,)).to(torch.int64), - ], - torch.randn((10,)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) - - def test_aten_index_put_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - torch.randint(0, 10, (1,)).to(torch.int64), - ], - torch.randn((10,)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) - - def test_aten_index_put_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - torch.randint(0, 10, (1,)).to(torch.int64), - ], - torch.randint(0, 10, (10,)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) - - def test_aten_index_put_3(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - torch.randint(0, 10, (1,)).to(torch.int64), - ], - torch.randint(0, 10, (10,)).to(torch.int32), - True, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs) - - def test_aten_index_select_0(self): - args = ( - torch.randn((2, 10)).to(torch.float32), - 1, - torch.randint(0, 10, (2,)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) - - def test_aten_index_select_1(self): - args = ( - torch.randn((2, 10)).to(torch.float16), - 1, - torch.randint(0, 10, (2,)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) - - def test_aten_index_select_2(self): - args = ( - torch.randint(0, 10, (2, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (2,)).to(torch.int64), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) - - def test_aten_index_select_int32_index(self): - args = ( - torch.randint(0, 10, (2, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (2,)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs) - - def test_aten_index_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - torch.randint(0, 10, (2,)).to(torch.int64), - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) - - def test_aten_index_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - torch.randint(0, 10, (2,)).to(torch.int64), - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) - - def test_aten_index_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - torch.randint(0, 10, (2,)).to(torch.int64), - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.index.Tensor, args, kwargs) - - def test_aten_isinf_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.isinf, args, kwargs) - - def test_aten_isinf_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.isinf, args, kwargs) - - def test_aten_isinf_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.isinf, args, kwargs) - - def test_aten_isnan_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.isnan, args, kwargs) - - def test_aten_isnan_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.isnan, args, kwargs) - - def test_aten_isnan_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.isnan, args, kwargs) - - def test_aten_le_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.le.Scalar, args, kwargs) - - def test_aten_le_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.le.Scalar, args, kwargs) - - def test_aten_le_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.le.Scalar, args, kwargs) - - def test_aten_le_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) - - def test_aten_le_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) - - def test_aten_le_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.le.Tensor, args, kwargs) - - def test_aten_leaky_relu_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) - - def test_aten_leaky_relu_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.leaky_relu, args, kwargs) - - def test_aten_lift_fresh_copy_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - - def test_aten_lift_fresh_copy_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - - def test_aten_lift_fresh_copy_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lift_fresh_copy, args, kwargs) - - def test_aten_log_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log, args, kwargs) - - def test_aten_log_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log, args, kwargs) - - def test_aten_log_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log, args, kwargs) - - def test_aten_log10_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log10, args, kwargs) - - def test_aten_log10_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.log10, args, kwargs, atol=0.001, rtol=0.001) - - def test_aten_log10_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log10, args, kwargs) - - def test_aten_log1p_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log1p, args, kwargs) - - def test_aten_log1p_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.log1p, args, kwargs, atol=0.001, rtol=0.001) - - def test_aten_log1p_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log1p, args, kwargs) - - def test_aten_log2_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log2, args, kwargs) - - def test_aten_log2_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.log2, args, kwargs, atol=0.001, rtol=0.001) - - def test_aten_log2_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.log2, args, kwargs) - - def test_aten__log_softmax_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._log_softmax, args, kwargs) - - def test_aten_logical_and_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) - - def test_aten_logical_and_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) - - def test_aten_logical_and_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs) - - def test_aten_logical_not_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) - - def test_aten_logical_not_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) - - def test_aten_logical_not_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_not, args, kwargs) - - def test_aten_logical_or_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) - - def test_aten_logical_or_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) - - def test_aten_logical_or_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) - - def test_aten_logical_xor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) - - def test_aten_logical_xor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) - - def test_aten_logical_xor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logical_xor, args, kwargs) - - def test_aten_logit_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) - - def test_aten_logit_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.logit, - args, - kwargs, - atol=0.01, - ) - - def test_aten_logit_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.logit, args, kwargs) - - def test_aten_lt_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lt.Scalar, args, kwargs) - - def test_aten_lt_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lt.Scalar, args, kwargs) - - def test_aten_lt_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lt.Scalar, args, kwargs) - - def test_aten_lt_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) - - def test_aten_lt_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) - - def test_aten_lt_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.lt.Tensor, args, kwargs) - - def test_aten_masked_fill_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.bool), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.masked_fill.Scalar, args, - kwargs) - - def test_aten_masked_fill_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.bool), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.masked_fill.Scalar, args, - kwargs) - - def test_aten_masked_fill_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randn((10, 10)).to(torch.bool), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.masked_fill.Scalar, args, - kwargs) - - def test_aten_max_dim_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) - - def test_aten_max_dim_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) - - def test_aten_max_dim_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.max.dim, args, kwargs) - - def test_aten_max_pool2d_with_indices_0(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float32), - [ - 2, - 2, - ], - [ - 1, - 1, - ], - [ - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.max_pool2d_with_indices, - args, - kwargs, - ignore_indices=True) - - def test_aten_max_pool2d_with_indices_1(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float16), - [ - 2, - 2, - ], - [ - 1, - 1, - ], - [ - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.max_pool2d_with_indices, - args, - kwargs, - ignore_indices=True) - - def test_aten_max_pool2d_with_indices_2(self): - args = ( - torch.arange(0, 60).reshape(3, 2, 10), - [ - 2, - 2, - ], - [ - 1, - 1, - ], - [ - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.max_pool2d_with_indices, - args, - kwargs, - ignore_indices=True) - - def test_aten_max_pool3d_with_indices_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - [ - 2, - 2, - 2, - ], - [ - 1, - 1, - 1, - ], - [ - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.max_pool3d_with_indices, args, - kwargs) - - def test_aten_max_pool3d_with_indices_1(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float16), - [ - 2, - 2, - 2, - ], - [ - 1, - 1, - 1, - ], - [ - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.max_pool3d_with_indices, args, - kwargs) - - def test_aten_max_pool3d_with_indices_2(self): - args = ( - torch.arange(0, 60).reshape(1, 3, 2, 10), - [ - 2, - 2, - 2, - ], - [ - 1, - 1, - 1, - ], - [ - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.max_pool3d_with_indices, args, - kwargs) - - def test_aten_maximum_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) - - def test_aten_maximum_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) - - def test_aten_maximum_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.maximum, args, kwargs) - - def test_aten_mean_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mean, args, kwargs) - - def test_aten_mean_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mean, args, kwargs) - - def test_aten_mean_dim_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mean.dim, args, kwargs) - - def test_aten_mean_dim_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mean.dim, args, kwargs) - - def test_aten_min_dim_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.min.dim, args, kwargs) - - def test_aten_min_dim_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.min.dim, args, kwargs) - - def test_aten_min_dim_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.min.dim, args, kwargs) - - def test_aten_minimum_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) - - def test_aten_minimum_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) - - def test_aten_minimum_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.minimum, args, kwargs) - - def test_aten_mm_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) - - def test_aten_mm_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) - - def test_aten_mm_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mm, args, kwargs) - - def test_aten_mul_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mul.Scalar, args, kwargs) - - def test_aten_mul_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mul.Scalar, args, kwargs) - - def test_aten_mul_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mul.Scalar, args, kwargs) - - def test_aten_mul_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) - - def test_aten_mul_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) - - def test_aten_mul_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.mul.Tensor, args, kwargs) - - def test_aten__native_batch_norm_legit_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - None, - None, - torch.randn((10,)).to(torch.float32), - torch.randn((10,)).to(torch.float32), - False, - 1.0, - 1.0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, - kwargs) - - def test_aten__native_batch_norm_legit_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - None, - None, - torch.randn((10,)).to(torch.float16), - torch.randn((10,)).to(torch.float16), - False, - 1.0, - 1.0, - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten._native_batch_norm_legit, - args, - kwargs, - atol=0.01, - rtol=0.01, - ) - - def test_aten__native_batch_norm_legit_no_stats_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - torch.randn((1, 3, 1, 1)).to(torch.float32), - torch.randn((1, 3, 1, 1)).to(torch.float32), - True, - 0.0, - 1.0, - ) - kwargs = dict() - run_export_and_compare(self, - torch.ops.aten._native_batch_norm_legit.no_stats, - args, kwargs) - - def test_aten__native_batch_norm_legit_no_stats_1(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float16), - torch.randn((1, 3, 1, 1)).to(torch.float32), - torch.randn((1, 3, 1, 1)).to(torch.float32), - True, - 0.0, - 1.0, - ) - kwargs = dict() - run_export_and_compare(self, - torch.ops.aten._native_batch_norm_legit.no_stats, - args, kwargs) - - def test_aten__native_batch_norm_legit_no_training_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - None, - None, - torch.randn((10,)).to(torch.float32), - torch.randn((10,)).to(torch.float32), - 1.0, - 1.0, - ) - kwargs = dict() - run_export_and_compare(self, - torch.ops.aten._native_batch_norm_legit_no_training, - args, kwargs) - - def test_aten__native_batch_norm_legit_no_training_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - None, - None, - torch.randn((10,)).to(torch.float16), - torch.randn((10,)).to(torch.float16), - 1.0, - 1.0, - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten._native_batch_norm_legit_no_training, - args, - kwargs, - atol=0.01, - rtol=0.01, - ) - - def test_aten_native_dropout_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1.0, - True, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_dropout, args, kwargs) - - def test_aten_native_dropout_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1.0, - False, - ) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.native_dropout, args, kwargs, atol=0.01, rtol=0.01) - - def test_aten_native_group_norm_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - None, - None, - 1, - 3, - 20, - 1, - 0.0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs) - - def test_aten_native_group_norm_1(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float16), - None, - None, - 1, - 3, - 20, - 1, - 0.0, - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.native_group_norm, - args, - kwargs, - atol=0.01, - rtol=0.01) - - def test_aten_native_layer_norm_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - [ - 1, - 3, - 2, - 10, - ], - None, - None, - 0.0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs) - - def test_aten_native_layer_norm_1(self): - args = ( - torch.randn((1, 10, 10, 10)).to(torch.float32), - [10], - torch.randn((10,)).to(torch.float32), - torch.randn((10,)).to(torch.float32), - 0.0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs) - - def test_aten_native_batch_norm_legit(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 2, 2)).to(torch.float32), - torch.ones(channel), - torch.zeros(channel), - torch.zeros(channel), - torch.ones(channel), - False, - 0.5, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, - kwargs) - - def test_aten_native_batch_norm_legit_none(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 4, 4)).to(torch.float32), - None, - None, - torch.ones(channel), - torch.zeros(channel), - False, - 0.5, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, - kwargs) - - def test_aten_native_batch_norm_legit_training_none(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 4, 3)).to(torch.float32), - None, - None, - torch.zeros(channel), - torch.ones(channel), - True, - 0.2, - 2e-5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, - kwargs) - - def test_aten_native_batch_norm_legit_no_training(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 4, 3)).to(torch.float32), - torch.ones(channel), - torch.zeros(channel), - torch.zeros(channel), - torch.ones(channel), - 0.2, - 2e-5, - ) - kwargs = dict() - run_export_and_compare(self, - torch.ops.aten._native_batch_norm_legit_no_training, - args, kwargs) - - def test_aten_native_batch_norm_training(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 4, 3)).to(torch.float32), - torch.ones(channel), - torch.zeros(channel), - torch.zeros(channel), - torch.ones(channel), - True, - 0.1, - 1e-5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs) - - def test_aten_native_batch_norm_training_none(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 4, 3)).to(torch.float32), - None, - None, - torch.zeros(channel), - torch.ones(channel), - True, - 0.1, - 1e-5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs) - - def test_aten_native_batch_norm_eval(self): - batch = 3 - channel = 2 - args = ( - torch.randn((batch, channel, 4, 3)).to(torch.float32), - torch.ones(channel), - torch.zeros(channel), - torch.zeros(channel), - torch.ones(channel), - False, - 0.2, - 2e-5, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs) - - def test_aten_ne_Scalar_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) - - def test_aten_ne_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) - - def test_aten_ne_Scalar_2(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ne.Scalar, args, kwargs) - - def test_aten_ne_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) - - def test_aten_ne_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) - - def test_aten_ne_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.ne.Tensor, args, kwargs) - - def test_aten_neg_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) - - def test_aten_neg_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) - - def test_aten_neg_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.neg, args, kwargs) - - def test_aten_nonzero_0(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) - - def test_aten_nonzero_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) - - def test_aten_nonzero_2(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.nonzero, args, kwargs) - - def test_aten__pdist_forward_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1.0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._pdist_forward, args, kwargs) - - def test_aten_permute_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) - - def test_aten_permute_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) - - def test_aten_permute_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.permute, args, kwargs) - - def test_aten_permute_copy_0(self): - args = ( - torch.randn((2, 2, 2)).to(torch.float32), - [ - 1, - 2, - 0, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) - - def test_aten_permute_copy_1(self): - args = ( - torch.randn((2, 2, 2)).to(torch.float16), - [ - 1, - 2, - 0, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) - - def test_aten_permute_copy_2(self): - args = ( - torch.randint(0, 10, (2, 2, 2)).to(torch.int32), - [ - 1, - 2, - 0, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.permute_copy, args, kwargs) - - def test_aten_pixel_shuffle_0(self): - args = ( - torch.randn((1, 3, 10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) - - def test_aten_pixel_shuffle_1(self): - args = ( - torch.randn((1, 3, 10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) - - def test_aten_pixel_shuffle_2(self): - args = ( - torch.randint(0, 10, (1, 3, 10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pixel_shuffle, args, kwargs) - - def test_aten_pow_Scalar_0(self): - args = ( - 1.123, - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs) - - def test_aten_pow_Tensor_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1.2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) - - def test_aten_pow_Tensor_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1.2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) - - def test_aten_pow_Tensor_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1.2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs) - - def test_aten_pow_Scalar_1(self): - args = (10000, torch.randn(16 * 8)) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs) - - def test_aten_pow_Tensor_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) - - def test_aten_pow_Tensor_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) - - def test_aten_pow_Tensor_Tensor_2(self): - args = ( - torch.randint(0, 5, (10, 10)).to(torch.int32), - torch.randint(0, 5, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs) - - def test_aten_prod_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) - - def test_aten_prod_1(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) - - def test_aten_prod_dim_int_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs) - - def test_aten_prod_dim_int_1(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs) - - def test_aten_reciprocal_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) - - def test_aten_reciprocal_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) - - def test_aten_reciprocal_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) - - def test_aten_reflection_pad1d_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) - - def test_aten_reflection_pad1d_1(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) - - def test_aten_reflection_pad2d_0(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float32), - [ - 1, - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) - - def test_aten_reflection_pad2d_1(self): - args = ( - torch.randint(0, 10, (3, 2, 10)).to(torch.int32), - [ - 1, - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) - - def test_aten_reflection_pad3d_0(self): - args = ( - torch.randn((3, 3, 3, 3)).to(torch.float32), - [ - 1, - 2, - 1, - 2, - 1, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - - def test_aten_reflection_pad3d_1(self): - args = ( - torch.randn((3, 3, 3, 3, 3)).to(torch.float16), - [ - 1, - 2, - 1, - 2, - 1, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - - def test_aten_reflection_pad3d_2(self): - args = ( - torch.randint(0, 10, (3, 3, 3, 3)).to(torch.int32), - [ - 1, - 2, - 1, - 2, - 1, - 2, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) - - def test_aten_relu_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) - - def test_aten_relu_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) - - def test_aten_relu_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.relu, args, kwargs) - - def test_aten_remainder_Scalar_0(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) - - def test_aten_remainder_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) - - def test_aten_remainder_Scalar_2(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.remainder.Scalar, args, kwargs) - - def test_aten_remainder_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.remainder.Tensor, args, kwargs) - - def test_aten_remainder_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.remainder.Tensor, args, kwargs) - - def test_aten_replication_pad2d_0(self): - args = ( - torch.randn((3, 2, 10)).to(torch.float32), - [ - 1, - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.replication_pad2d, args, kwargs) - - def test_aten_replication_pad2d_1(self): - args = ( - torch.randint(0, 10, (3, 2, 10)).to(torch.int32), - [ - 1, - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.replication_pad2d, args, kwargs) - - def test_aten_replication_pad3d_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - [ - 1, - 1, - 1, - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.replication_pad3d, args, kwargs) - - def test_aten_replication_pad3d_1(self): - args = ( - torch.randint(0, 10, (1, 3, 2, 10)).to(torch.int32), - [ - 1, - 1, - 1, - 1, - 1, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.replication_pad3d, args, kwargs) - - def test_aten_roll_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) - - def test_aten_roll_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) - - def test_aten_roll_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.roll, args, kwargs) - - def test_aten_round_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.round, args, kwargs) - - def test_aten_round_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.round, args, kwargs) - - def test_aten_round_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.round, args, kwargs) - - def test_aten_rsqrt_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) - - def test_aten_rsqrt_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.rsqrt, args, kwargs, atol=0.01, rtol=0.01) - - def test_aten_rsqrt_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.rsqrt, args, kwargs) - - def test_aten_rsub_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) - - def test_aten_rsub_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) - - def test_aten_rsub_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.rsub.Scalar, args, kwargs) - - def test_aten_scatter_add_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - torch.randint(0, 10, (2, 2)).to(torch.int64), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) - - def test_aten_scatter_add_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - torch.randint(0, 10, (2, 2)).to(torch.int64), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) - - def test_aten_scatter_add_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (2, 2)).to(torch.int64), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_add, args, kwargs) - - def test_aten_scatter_reduce_two_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - torch.randn((10, 10)).to(torch.float32), - "sum", - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args, - kwargs) - - def test_aten_scatter_reduce_two_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - torch.randn((10, 10)).to(torch.float16), - "amin", - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args, - kwargs) - - def test_aten_scatter_reduce_two_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - torch.randint(0, 10, (10, 10)).to(torch.int32), - "amax", - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter_reduce.two, args, - kwargs) - - def test_aten_scatter_src_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) - - def test_aten_scatter_src_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) - - def test_aten_scatter_src_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter.src, args, kwargs) - - def test_aten_scatter_value_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) - - def test_aten_scatter_value_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) - - def test_aten_scatter_value_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - torch.randint(0, 10, (10, 10)).to(torch.int64), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.scatter.value, args, kwargs) - - def test_aten_select_copy_int_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) - - def test_aten_select_copy_int_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) - - def test_aten_select_copy_int_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_copy.int, args, kwargs) - - def test_aten_select_int_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) - - def test_aten_select_int_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) - - def test_aten_select_int_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select.int, args, kwargs) - - def test_aten_select_scatter_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randint(0, 10, (10,)).to(torch.int64), - 1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - - def test_aten_select_scatter_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randint(0, 10, (10,)).to(torch.int64), - 1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - - def test_aten_select_scatter_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10,)).to(torch.int64), - 1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - - def test_aten_select_scatter_3(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randint(0, 10, (10,)).to(torch.int64), - -1, - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.select_scatter, args, kwargs) - - def test_aten_sigmoid_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) - - def test_aten_sigmoid_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) - - def test_aten_sigmoid_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs) - - def test_aten_sign_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) - - def test_aten_sign_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) - - def test_aten_sign_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sign, args, kwargs) - - def test_aten_sin_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sin, args, kwargs) - - def test_aten_sin_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sin, args, kwargs) - - def test_aten_sin_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sin, args, kwargs) - - def test_aten_sinh_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) - - def test_aten_sinh_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) - - def test_aten_sinh_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) - - def test_aten_slice_copy_Tensor_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) - - def test_aten_slice_copy_Tensor_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) - - def test_aten_slice_copy_Tensor_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice_copy.Tensor, args, kwargs) - - def test_aten_slice_scatter_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) - - def test_aten_slice_scatter_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) - - def test_aten_slice_scatter_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice_scatter, args, kwargs) - - def test_aten_slice_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice.Tensor, args, kwargs) - - def test_aten_slice_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice.Tensor, args, kwargs) - - def test_aten_slice_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.slice.Tensor, args, kwargs) - - def test_aten__softmax_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._softmax, args, kwargs) - - def test_aten__softmax_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten._softmax, args, kwargs) - - def test_aten_softmax(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.softmax, args, kwargs) - - def _compare_sorted_result(self, args): - res = torch.ops.aten.sort(*args) - with self.subTest("torchax_eval"): - args2 = self.env.to_xla(args) - with self.env: - res2 = torch.ops.aten.sort(*args2) - - # The second argument is the sorted index. These might not be - # identical from torch vs. jax; but both can be correct - diff_output(self, res[0], res2[0].torch(), 1e-3, 1e-5) - - def test_aten_sort_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - ) - self._compare_sorted_result(args) - - def test_aten_sort_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 1, - ) - self._compare_sorted_result(args) - - def test_aten_sort_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 1, - ) - self._compare_sorted_result(args) - - def test_aten_split_copy_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) - - def test_aten_split_copy_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) - - def test_aten_split_copy_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 2, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.split_copy.Tensor, args, kwargs) - - def test_aten_split_with_sizes_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 1, - 2, - 3, - 4, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) - - def test_aten_split_with_sizes_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 1, - 2, - 3, - 4, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) - - def test_aten_split_with_sizes_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 1, - 2, - 3, - 4, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) - - def test_aten_sqrt_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) - - def test_aten_sqrt_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) - - def test_aten_sqrt_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) - - def test_aten_squeeze_copy_dim_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) - - def test_aten_squeeze_copy_dim_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) - - def test_aten_squeeze_copy_dim_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze_copy.dim, args, kwargs) - - def test_aten_squeeze_dims_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) - - def test_aten_squeeze_dims_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) - - def test_aten_squeeze_dims_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 0, - 1, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.squeeze.dims, args, kwargs) - - def test_aten_stack_0(self): - args = ([ - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ],) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.stack, args, kwargs) - - def test_aten_stack_1(self): - args = ([ - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ],) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.stack, args, kwargs) - - def test_aten_stack_2(self): - args = ([ - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ],) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.stack, args, kwargs) - - def test_aten_sub_Scalar_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sub.Scalar, args, kwargs) - - def test_aten_sub_Scalar_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sub.Scalar, args, kwargs) - - def test_aten_sub_Scalar_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0.123, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sub.Scalar, args, kwargs) - - def test_aten_sub_Tensor_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sub.Tensor, args, kwargs) - - def test_aten_sub_Tensor_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - torch.randn((10, 10)).to(torch.float16), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sub.Tensor, args, kwargs) - - def test_aten_sub_Tensor_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - torch.randint(0, 10, (10, 10)).to(torch.int32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sub.Tensor, args, kwargs) - - def test_aten_sum_dim_IntList_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) - - def test_aten_sum_dim_IntList_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) - - def test_aten_sum_dim_IntList_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - None, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.sum.dim_IntList, args, kwargs) - - def test_aten_tan_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tan, args, kwargs) - - def test_aten_tan_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.tan, - args, - kwargs, - rtol=0.001, - atol=0.01, - ) - - def test_aten_tan_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tan, args, kwargs) - - def test_aten_tanh_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs) - - def test_aten_tanh_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs) - - def test_aten_tanh_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tanh, args, kwargs) - - def test_aten_topk_0(self): - args = ( - torch.arange(0, 100).reshape(10, 10).to(torch.float32), - 1, - 1, - False, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) - - def test_aten_topk_1(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 1, - 1, - True, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) - - def test_aten_topk_2(self): - args = ( - torch.arange(0, 100).reshape(10, 10).to(torch.int32), - 1, - 1, - False, - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) - - def test_aten_topk_3(self): - args = ( - torch.arange(0, 100).reshape(10, 10).to(torch.int32), - 3, - 0, - False, - True, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) - - def test_aten_topk_4(self): - args = ( - torch.arange(0, 100).reshape(10, 10).to(torch.int32), - 3, - 0, - True, - True, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.topk, args, kwargs) - - def test_aten_transpose_copy_int_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - 0, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.transpose_copy.int, args, - kwargs) - - def test_aten_transpose_copy_int_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - 0, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.transpose_copy.int, args, - kwargs) - - def test_aten_transpose_copy_int_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - 0, - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.transpose_copy.int, args, - kwargs) - - def test_aten_tril_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) - - def test_aten_tril_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) - - def test_aten_tril_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.tril, args, kwargs) - - def test_aten_trunc_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) - - def test_aten_trunc_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) - - def test_aten_trunc_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.trunc, args, kwargs) - - def test_aten_unbind_copy_int_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) - - def test_aten_unbind_copy_int_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) - - def test_aten_unbind_copy_int_2(self): - args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unbind_copy.int, args, kwargs) - - def test_aten_unsqueeze_copy_0(self): - args = ( - torch.randn((2, 0, 2)).to(torch.float32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) - - def test_aten_unsqueeze_copy_1(self): - args = ( - torch.randn((2, 0, 2)).to(torch.float16), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) - - def test_aten_unsqueeze_copy_2(self): - args = ( - torch.randint(0, 10, (2, 0, 2)).to(torch.int32), - 1, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.unsqueeze_copy, args, kwargs) - - def test_aten_upsample_bilinear2d_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - [ - 3, - 20, - ], - False, - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.upsample_bilinear2d, args, - kwargs) - - def test_aten_upsample_nearest2d_0(self): - args = ( - torch.randn((1, 3, 2, 10)).to(torch.float32), - [ - 3, - 20, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.upsample_nearest2d, args, - kwargs) - - def test_aten_var_correction_0(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - - def test_aten_var_correction_1(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - - def test_aten_var_correction_2(self): - args = (torch.randn((10, 10)).to(torch.float32),) - kwargs = dict(correction=0) - run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - - def test_aten_var_correction_3(self): - args = (torch.randn((10, 10)).to(torch.float16),) - kwargs = dict(correction=0) - run_export_and_compare(self, torch.ops.aten.var.correction, args, kwargs) - - def test_aten_view_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 1, - 100, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.view, args, kwargs) - - def test_aten_view_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 1, - 100, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.view, args, kwargs) - - def test_aten_view_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 1, - 100, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.view, args, kwargs) - - def test_aten_view_copy_0(self): - args = ( - torch.randn((10, 10)).to(torch.float32), - [ - 2, - 5, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) - - def test_aten_view_copy_1(self): - args = ( - torch.randn((10, 10)).to(torch.float16), - [ - 2, - 5, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) - - def test_aten_view_copy_2(self): - args = ( - torch.randint(0, 10, (10, 10)).to(torch.int32), - [ - 2, - 5, - 10, - ], - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.view_copy, args, kwargs) - - def test_aten_where_self_0(self): - args = ( - torch.randn((10, 10)).to(torch.bool), - torch.randn((10, 10)).to(torch.float32), - torch.randn((10, 10)).to(torch.float32), - ) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs) - - def test_aten_where(self): - args = (torch.randn((10, 10)).to(torch.bool),) - kwargs = dict() - run_export_and_compare(self, torch.ops.aten.where, args, kwargs) - - def test_aten_copy_dtype(self): - args = ( - torch.ones((3, 3), dtype=torch.int32), - torch.zeros((3, 3), dtype=torch.float32), - ) - kwargs = dict() - run_export_and_compare( - self, torch.ops.aten.copy_, args, kwargs, check_dtype=True) - - def test_aten_rand_like(self): - args = (torch.ones((3, 3), dtype=torch.bfloat16),) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.rand_like, - args, - kwargs, - atol=math.inf, - check_dtype=True) - - def test_einsum(self): - args = ( - "bshd,bthd->bsht", - torch.randn((1, 2, 4, 8), dtype=torch.float16), - torch.randn((1, 2, 4, 8), dtype=torch.float16), - ) - kwargs = dict() - run_export_and_compare( - self, - torch.einsum, - args, - kwargs, - atol=1e-2, - rtol=1e-2, - check_dtype=True) - - def test_aten_einsum(self): - args = ("bshd,bthd->bsht", ( - torch.randn((1, 2, 4, 8), dtype=torch.float16), - torch.randn((1, 2, 4, 8), dtype=torch.float16), - )) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.einsum, - args, - kwargs, - atol=1e-2, - rtol=1e-2, - check_dtype=True) - - def test_aten_linear(self): - # with bias - args = ( - torch.randn((2, 4), dtype=torch.float16), - torch.randn((2, 4), dtype=torch.float16), - torch.randn((2,), dtype=torch.float16), - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.linear, - args, - kwargs, - atol=1e-2, - rtol=1e-2, - check_dtype=True) - - # without bias - args = ( - torch.randn((2, 4), dtype=torch.float16), - torch.randn((2, 4), dtype=torch.float16), - ) - kwargs = dict() - run_export_and_compare( - self, - torch.ops.aten.linear, - args, - kwargs, - atol=1e-2, - rtol=1e-2, - check_dtype=True) - - def test_aten_copy_different_device(self): - cpu_tensor = torch.tensor([1, 2, 3]) - - with self.env: - xla_tensor = torch.tensor([0, 0, 0], device='jax') - xla_tensor.copy_(cpu_tensor) - self.assertEqual(xla_tensor.tolist(), cpu_tensor.tolist()) - self.assertIsInstance(xla_tensor, tensor.Tensor) - self.assertEqual(xla_tensor.device.type, 'jax') - - -if __name__ == "__main__": - base_test_util.main() diff --git a/torchax/test/test_exports.py b/torchax/test/test_exports.py deleted file mode 100644 index 1a056c68e366..000000000000 --- a/torchax/test/test_exports.py +++ /dev/null @@ -1,172 +0,0 @@ -import unittest -import torch -import torch.nn.functional as F -import jax -import jax.export -import torchax -import torchax.export -from torchax import tensor -from torchax.ops import mappings - - -class Interpolate(torch.nn.Module): - - def forward(self, masks: torch.Tensor) -> torch.Tensor: - masks = F.interpolate( - masks, - size=(500, 500), - mode="bilinear", - align_corners=False, - ) - return masks - - -class TensorConstant(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, a): - return a / torch.tensor(3) - - -class ExportTest(unittest.TestCase): - - def setUp(self): - torch.manual_seed(0) - torchax.enable_accuracy_mode() - - def test_interpolate(self): - - # Check Accuracy - arg = (torch.randn(3, 3, 200, 200),) - model = Interpolate() - ans = model(*arg) - - env = torchax.default_env() - - with torch.no_grad(): - exported = torch.export.export(model, arg) - weights, func = torchax.export.exported_program_to_jax(exported) - argj = env.t2j_copy(arg[0]) - ans2 = jax.jit(func)(weights, (argj,))[0] - ans2 = env.j2t_copy(ans2) - self.assertTrue(torch.allclose(ans, ans2, atol=1e-3)) - - # Convert to StableHLO - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - self.assertIn("func.func public @main", module_str) - self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str) - self.assertIn("stablehlo.minimum", module_str) - - def test_constant(self): - - # Check Accuracy - arg = (torch.randn(10, 10),) - model = TensorConstant() - ans = model(*arg) - - with torch.no_grad(): - exported = torch.export.export(model, arg) - env = torchax.default_env() - weights, func = torchax.export.exported_program_to_jax(exported) - argj = env.t2j_copy(arg[0]) - ans2 = jax.jit(func)(weights, (argj,))[0] - ans2 = env.j2t_copy(ans2) - self.assertTrue(torch.allclose(ans, ans2, atol=1e-5)) - - # Convert to StableHLO - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - self.assertIn("func.func public @main", module_str) - self.assertIn("stablehlo.divide", module_str) - - def test_interpolate_dynamic(self): - # Export with dynamic dimension constraints on both min and max - arg = (torch.randn(3, 3, 200, 200),) - model = Interpolate() - ans = model(*arg) - dynamic_shapes = ({0: torch.export.Dim("b", min=3, max=10)},) - - with torch.no_grad(): - exported = torch.export.export(model, arg, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - - # Look for dynamic shape artifacts - self.assertIn("func.func public @main(%arg0: tensor", - module_str) - self.assertIn("stablehlo.dynamic_broadcast_in_dim", module_str) - self.assertIn("stablehlo.dynamic_gather", module_str) - - def test_export_dtypes(self): - DTYPE_TO_MLIR_STR = { - # NO_MAPPING : jnp.float0 (signless scalar int) - torch.bool: - "i1", - # NO_MAPPING : "i4" - torch.int8: - "i8", - torch.int16: - "i16", - torch.int32: - "i32", - torch.int64: - "i64", - torch.long: - "i64", - # NO_MAPPING : "ui4" - torch.uint8: - "ui8", - # NOTE(qihqi): torch export for uint16 seems broken at torch 2.4 - # torch.uint16 : "ui16", - torch.uint32: - "ui32", - torch.uint64: - "ui64", - # NO_MAPPING : "f8E4M3B11FNUZ" - torch.float8_e4m3fn: - "f8E4M3FN", - # NO_MAPPING : f8E4M3FNUZ - torch.float8_e5m2: - "f8E5M2", - # NO_MAPPING : f8E5M2FNUZ - torch.bfloat16: - "bf16", - torch.half: - "f16", - torch.float16: - "f16", - torch.float32: - "f32", - torch.float64: - "f64", - torch.double: - "f64", - torch.complex64: - "complex", - torch.complex128: - "complex", - None: - None, - } - - model = TensorConstant() - for torch_dtype in DTYPE_TO_MLIR_STR.keys(): - if torch_dtype == None: - ## TODO: Figure out what the None mapping should be, seems like: - ## torch.tensor(dtype=None) maps to f32 - ## jnp.tensor(dtype=None) maps to f64 - continue - arg = (torch.randn(10).to(torch_dtype),) - with torch.no_grad(): - exported = torch.export.export(model, arg) - weights, stablehlo = torchax.export.exported_program_to_stablehlo( - exported) - module_str = str(stablehlo.mlir_module()) - self.assertIn(DTYPE_TO_MLIR_STR[torch_dtype], module_str) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_flax.py b/torchax/test/test_flax.py deleted file mode 100644 index bc5b7f219786..000000000000 --- a/torchax/test/test_flax.py +++ /dev/null @@ -1,97 +0,0 @@ -import unittest -import torch -import torchax -from flax import linen as nn -from torchax.flax import FlaxNNModule -from torchax.interop import jax_jit -import jax.numpy as jnp -import jax - - -class CNN(nn.Module): - """A simple CNN model.""" - - @nn.compact - def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=256)(x) - x = nn.relu(x) - x = nn.Dense(features=10)(x) - return x - - -class FlaxTest(unittest.TestCase): - - def test_flax_simple(self): - flax_model = CNN() - - inputs = jnp.ones((1, 28, 28, 1)) - env = torchax.default_env() - state = flax_model.init(env.prng_key, inputs) - expected = flax_model.apply(state, inputs) - - env = torchax.default_env() - nn_module = FlaxNNModule(env, flax_model, (inputs,), {}) - res = nn_module.forward(inputs) - - self.assertTrue(jnp.allclose(res.jax(), expected)) - - def test_flax_functional_call(self): - flax_model = CNN() - - inputs = jnp.ones((1, 28, 28, 1)) - env = torchax.default_env() - state = flax_model.init(env.prng_key, inputs) - expected = flax_model.apply(state, inputs) - - env = torchax.default_env() - nn_module = FlaxNNModule(env, flax_model, (inputs,), {}) - - @jax_jit - def jitted(weights, args): - return torch.func.functional_call(nn_module, weights, args) - - with env: - inputs_torch = torch.ones((1, 28, 28, 1), device='jax') - state_dict = nn_module.state_dict() - res = jitted(state_dict, inputs_torch) - self.assertTrue(jnp.allclose(res.jax(), expected)) - - def test_flax_module_nested(self): - env = torchax.default_env() - - class Parent(torch.nn.Module): - - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(28, 28) - sample_cnn_inputs = torch.ones((1, 28, 28, 1), device='jax') - self.cnn = FlaxNNModule(env, CNN(), (sample_cnn_inputs,), {}) - - def forward(self, x): - y = self.a(x) - y = y.reshape((-1, 28, 28, 1)) - res = self.cnn(y) - return res - - with env: - nn_module = Parent().to('jax') - - @jax_jit - def jitted(weights, args): - return torch.func.functional_call(nn_module, weights, args) - - inputs_torch = torch.ones((1, 28, 28), device='jax') - state_dict = nn_module.state_dict() - res = jitted(state_dict, inputs_torch) - print(res) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_functions.py b/torchax/test/test_functions.py deleted file mode 100644 index 03c3778bb00e..000000000000 --- a/torchax/test/test_functions.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Callable -from absl.testing import absltest -from absl.testing import parameterized -import torch -import torchax -import torchax.tensor - - -class SeqModel(torch.nn.Module): - """ Architecture is LLM generated """ - - def __init__(self): - super().__init__() - self.gru = torch.nn.GRU(20, 30, batch_first=True) - self.linear = torch.nn.Linear(30, 1) - - def forward(self, x: torch.Tensor): - output, _ = self.gru(x) #output, hidden state - output = self.linear(output) - return output - - -class TestTorchFunctions(parameterized.TestCase): - - def setUp(self): - torchax.enable_globally() - torchax.enable_accuracy_mode() - self.env = torchax.default_env() - - @parameterized.named_parameters( - ('tensor_2d', [[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]), - ('tensor_1d', [0, 1]), ('tensor_scalar', 3.14159), ('tensor_empty', []), - ('tensor_dtype', [[0.11111, 0.222222, 0.3333333]], { - 'dtype': torch.float64 - })) - def test_tensor_constructor(self, arg, kwargs=None): - kwargs = kwargs or {} - expected = torch.tensor(arg, **kwargs) - - actual = torch.tensor(arg, device='jax', **kwargs) - self.assertIsInstance(actual, torchax.tensor.Tensor) - - torch.testing.assert_close(actual.to('cpu'), expected) - - def test_dont_capture_conversion(self): - t = torch.tensor([1, 2, 3]) - with self.env: - t2 = self.env.to_xla(t) - # assert no exceptions - - def test_brackets(self): - with self.env: - a = torch.randn((2, 3)) - a[1] = 9 - self.assertEqual(a[1, 0].item(), 9) - - def test_bernoulli_inplace(self): - with self.env: - a = torch.randn((2, 3)) - a.bernoulli_(0.4) - - def test_flatten(self): - with self.env: - a = torch.randn((2, 3, 4)) - a = a.flatten(0, 1) - self.assertEqual(tuple(a.shape), (6, 4)) - - def test_rnn(self): - model = SeqModel() - x = torch.randn((2, 100, 20)) - res = model(x) - with self.env: - model.to('jax') - x = x.to('jax') - res2 = model(x) - print(res.shape, res2.shape) - - self.assertEqual(res.shape, res2.shape) - - def test_rms_norm(self): - model = torch.nn.RMSNorm((100, 20)) - x = torch.randn((2, 100, 20)) - res = model(x) - - with self.env: - model.to('jax') - x = x.to('jax') - res2 = model(x) - self.assertTrue(torch.allclose(res, res2.to('cpu'))) - - @parameterized.named_parameters( - ('ones', torch.ones, ((2, 2),)), ('zeros', torch.zeros, ((2, 2),)), - ('empty', torch.empty, - ((2, 2),)), ('empty_strided', torch.empty_strided, - ((2, 2), (2, 1))), ('tensor', torch.tensor, ([2.0, 2.0],)), - ('eye', torch.eye, (2,)), ('randn', torch.randn, ((2, 2),)), - ('rand', torch.rand, ((2, 2),)), ('full', torch.full, ((2, 2), 0))) - def test_requires_grad(self, func, args): - x = func(*args, requires_grad=True, device='jax') - self.assertEqual(x.requires_grad, True) - - -if __name__ == '__main__': - absltest.main() diff --git a/torchax/test/test_image.py b/torchax/test/test_image.py deleted file mode 100644 index 08e35816d80e..000000000000 --- a/torchax/test/test_image.py +++ /dev/null @@ -1,68 +0,0 @@ -from absl.testing import parameterized -import unittest -from typing import Tuple -import itertools -from functools import partial -import jax -import torch - -import torchax -import torchax.interop - - -@partial(jax.jit, static_argnums=(1, 2, 3, 4)) -def upsample_jit(tensor, output_size: Tuple[int, int], align_corners: bool, - antialias: bool, method: str): - tensor = torchax.interop.torch_view(tensor) - tensor = torch.nn.functional.interpolate( - tensor, - size=output_size, - mode=method, - align_corners=align_corners, - antialias=antialias) - return torchax.interop.jax_view(tensor) - - -class TestResampling(parameterized.TestCase): - - @parameterized.product( - antialias=[ - True, - False, - ], align_corners=[ - False, - True, - ]) - def test_resampling_combinations_bicubic(self, antialias, align_corners): - method = 'bicubic' - input_tensor = torch.rand((1, 1, 256, 512), dtype=torch.float32) - output_size = (128, 64) - - upsampled_tensor = torch.nn.functional.interpolate( - input_tensor, - size=output_size, - mode=method, - align_corners=align_corners, - antialias=antialias) - - env = torchax.default_env() - with env: - input_tensor_xla = env.to_xla(input_tensor) - input_tensor_xla = torchax.interop.jax_view(input_tensor_xla) - upsampled_tensor_xla = upsample_jit( - input_tensor_xla, - output_size, - align_corners, - antialias=antialias, - method=method) - - upsampled_tensor_xla = env.j2t_copy(upsampled_tensor_xla) - abs_err = torch.abs(upsampled_tensor - upsampled_tensor_xla) - - assert torch.allclose( - upsampled_tensor, upsampled_tensor_xla, atol=1e-4, - rtol=1e-5), f"{method} upsampling failed with error {abs_err.max()}" - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_interop.py b/torchax/test/test_interop.py deleted file mode 100644 index fe17c95292a7..000000000000 --- a/torchax/test/test_interop.py +++ /dev/null @@ -1,172 +0,0 @@ -import functools -import torch -import unittest -import torchax -from torchax import interop -import torchax -import jax -import jax.numpy as jnp - - -def is_tpu_available(): - """Checks if any TPU devices are available to JAX.""" - try: - # jax.devices('tpu') will return a list of TPU devices if available. - # If no TPUs are found or JAX is not configured for TPU, - # it will raise a RuntimeError. - tpu_devices = jax.devices('tpu') - return len(tpu_devices) > 0 - except RuntimeError: - return False - - -class InteropTest(unittest.TestCase): - - def setUp(self): - torchax.enable_globally() - - def test_mod_attr(self): - - class Child(torch.nn.Module): - - def __init__(self): - super().__init__() - self.x = torch.ones(10, 10) - - class ModuleWithUnregisteredTensor(torch.nn.Module): - - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(100, 100) - self.b = torch.nn.Parameter(torch.ones(10, 10)) - c = torch.ones(10, 10) - self.register_buffer('c', c) - self.register_buffer('c2', c, persistent=False) - self.d = torch.ones(10, 10) - self.m1 = Child() - - m = ModuleWithUnregisteredTensor() - params, buffers = interop.extract_all_buffers(m) - self.assertEqual(set(params.keys()), {'a.weight', 'a.bias', 'b'}) - self.assertEqual(set(buffers.keys()), {'c', 'c2', 'd', 'm1.x'}) - - interop.set_all_buffers(m, {'a.weight': torch.tensor([0.0])}, - {'m1.x': torch.tensor([0.0])}) - self.assertEqual(m.a.weight.item(), 0) - self.assertEqual(m.m1.x.item(), 0) - - def test_j2t_autograd_forward(self): - with torchax.default_env(): - # Setup - def fn(x): - return x + 1 - - j2t_fn = interop.j2t_autograd(fn) - x = torch.ones(2, 2, requires_grad=True, device='jax') - - # Act - actual = j2t_fn(x) - - # Assert - expected = torch.ones(2, 2) + 1 - torch.testing.assert_close(actual, expected, check_device=False) - - def test_j2t_autograd_backward(self): - with torchax.default_env(): - # Setup - def fn(x): - return x * 2 - - j2t_fn = interop.j2t_autograd(fn) - x = torch.ones(2, 2, device='jax').requires_grad_() - - # Act - actual = j2t_fn(x) - actual.sum().backward() - - # Assert - expected = torch.ones(2, 2) * 2 - torch.testing.assert_close(x.grad, expected, check_device=False) - - def test_module_with_shared_weights(self): - - # arrange - class ModuleWithSharedWeights(torch.nn.Module): - - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(10, 10) - self.b = self.a - - def forward(self, x): - return self.a(self.b(x)) - - m = ModuleWithSharedWeights().to('jax') - - m_jitted = interop.JittableModule(m, dedup_parameters=True) - - # a's weights and bias and b's weights and bias - self.assertEqual(len(m.state_dict()), 4) - - # b's weights and bias are deduped - self.assertEqual(len(m_jitted.params), 2) - x = torch.randn(10, 10).to('jax') - expected = m(x) - - # act - actual = m_jitted(x) - - # assert - torch.testing.assert_allclose(actual, expected) - - # arrange - # make sure buffer donation works - functional_forward = interop.jax_jit( - functools.partial(m_jitted.functional_call, 'forward'), - kwargs_for_jax_jit={'donate_argnums': (0,)}) - - # act - actual = functional_forward(m_jitted.params, m_jitted.buffers, x) - # assert - torch.testing.assert_allclose(actual, expected) - - def test_to_jax_device(self): - a = torch.ones(3, 3) - - if is_tpu_available(): - # by default if tpu is available, to jax will be to tpu - e = a.to("jax") - self.assertEqual(e.jax_device.platform, "tpu") - self.assertEqual(e.device.type, "jax") - else: - e = a.to("jax") - self.assertEqual(e.jax_device.platform, "cpu") - self.assertEqual(e.device.type, "jax") - - with jax.default_device(jax.devices("cpu")[0]): - # move torch.tensor to torchax.tensor CPU - b = a.to("jax") - self.assertEqual(b.jax_device.platform, "cpu") - self.assertEqual(b.device.type, "jax") - - if is_tpu_available(): - # move torch.tensor to torchax.tensor TPU - with jax.default_device(jax.local_devices("tpu")[0]): - c = a.to("jax") - self.assertEqual(c.jax_device.platform, "tpu") - self.assertEqual(c.device.type, "jax") - - def test_torch_jax_view_dtype(self): - dtype = torch.float32 - self.assertEqual(interop.jax_view(dtype), jnp.float32.dtype) - self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype) - dtype = torch.bfloat16 - self.assertEqual(interop.jax_view(dtype), jnp.bfloat16.dtype) - self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype) - dtype = torch.int32 - self.assertEqual(interop.jax_view(dtype), jnp.int32.dtype) - self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_jittable_module.py b/torchax/test/test_jittable_module.py deleted file mode 100644 index 3a44450cbe8f..000000000000 --- a/torchax/test/test_jittable_module.py +++ /dev/null @@ -1,58 +0,0 @@ -import unittest -from torchax import interop -import torch - - -class MyAwesomeModel(torch.nn.Module): - pass - - -class EvenMoreAwesomeModel(torch.nn.Module): - pass - - -class JittableModuleTest(unittest.TestCase): - - def test_isinstance_works(self): - - # Export and check for composite operations - model = MyAwesomeModel() - jittable_module = interop.JittableModule(model) - - # jittable_module should remain an instance of MyAwesomeModel logicailly - assert isinstance(jittable_module, MyAwesomeModel) - - def test_isinstance_does_not_mix(self): - - # Export and check for composite operations - JittableAwesomeModel = interop.JittableModule(MyAwesomeModel()) - JittableMoreAwesomeModel = interop.JittableModule(EvenMoreAwesomeModel()) - - # jittable_module should remain an instance of MyAwesomeModel logicailly - assert isinstance(JittableAwesomeModel, MyAwesomeModel) - assert not isinstance(JittableAwesomeModel, EvenMoreAwesomeModel) - assert isinstance(JittableMoreAwesomeModel, EvenMoreAwesomeModel) - assert not isinstance(JittableMoreAwesomeModel, MyAwesomeModel) - - def test_functional_call_callable(self): - - def outer_function(model, x): - return x + 1 - - model = MyAwesomeModel() - jittable_module = interop.JittableModule(model) - - # Check if the jittable module can be called like a function - input_tensor = torch.randn(1, 3, 224, 224) - expected_output = input_tensor + 1 - - output = jittable_module.functional_call(outer_function, - jittable_module.params, - jittable_module.buffers, - input_tensor) - - assert torch.equal(output, expected_output) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_libraries.py b/torchax/test/test_libraries.py deleted file mode 100644 index 69ed3c77e53b..000000000000 --- a/torchax/test/test_libraries.py +++ /dev/null @@ -1,87 +0,0 @@ -import unittest -import torch -import torch.nn.functional as F -from torch.library import Library, impl, impl_abstract -import torchax -import torchax.export -from torchax.ops import jaten -from torchax.ops import jlibrary -# Create a `mylib` library which has a basic SDPA op. -m = Library("mylib", "DEF") -m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor") - - -@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd") -def _mylib_scaled_dot_product_attention(q, k, v): - """Basic scaled dot product attention without all the flags/features.""" - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - y = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=0, - is_causal=False, - scale=None, - ) - return y.transpose(1, 2) - - -@impl_abstract("mylib::scaled_dot_product_attention") -def _mylib_scaled_dot_product_attention_meta(q, k, v): - return torch.empty_like(q) - - -# Register library op as a composite for export using the `@impl` method -# for a torch decomposition. -jlibrary.register_torch_composite( - "mylib.scaled_dot_product_attention", _mylib_scaled_dot_product_attention, - torch.ops.mylib.scaled_dot_product_attention, - torch.ops.mylib.scaled_dot_product_attention.default) - -# Also register ATen softmax as a composite for export in the `mylib` library -# using the JAX ATen decomposition from `jaten`. -jlibrary.register_jax_composite( - "mylib.softmax", - jaten._aten_softmax, - torch.ops.aten._softmax, - static_argnums=1 # Required by JAX jit -) - - -class LibraryTest(unittest.TestCase): - - def setUp(self): - torch.manual_seed(0) - - def test_basic_sdpa_library(self): - - class CustomOpExample(torch.nn.Module): - - def forward(self, q, k, v): - x = torch.ops.mylib.scaled_dot_product_attention(q, k, v) - x = x + 1 - return x - - # Export and check for composite operations - model = CustomOpExample() - arg = torch.rand(32, 8, 128, 64) - args = ( - arg, - arg, - arg, - ) - - exported = torch.export.export(model, args=args) - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - - ## TODO Update this machinery from producing function calls to producing - ## stablehlo.composite ops. - self.assertIn("call @mylib.scaled_dot_product_attention", module_str) - self.assertIn("call @mylib.softmax", module_str) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_misc.py b/torchax/test/test_misc.py deleted file mode 100644 index 9214c5b1eac6..000000000000 --- a/torchax/test/test_misc.py +++ /dev/null @@ -1,59 +0,0 @@ -"""If you don't know which file a test should go, and don't want to make a new file -for a small test. PUt it here -""" -import torch -import unittest -import torchax -import jax -import jax.numpy as jnp - - -class MiscTest(unittest.TestCase): - - def test_extract_jax_kwargs(self): - - class M(torch.nn.Module): - - def forward(self, a, b): - return torch.sin(a) + torch.cos(b) - - weights, func = torchax.extract_jax(M()) - res = func( - weights, - args=(), - kwargs={ - 'a': jnp.array([1, 2, 3]), - 'b': jnp.array([3, 4, 5]) - }) - self.assertTrue( - jnp.allclose( - res, - jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5])))) - - def test_to_device(self): - env = torchax.default_env() - with env: - step1 = torch.ones( - 100, - 100, - ) - step2 = torch.triu(step1, diagonal=1) - step3 = step2.to(dtype=torch.bool, device='jax') - self.assertEqual(step3.device.type, 'jax') - - def test_to_device_twice(self): - env = torchax.default_env() - env.config.debug_print_each_op = True - with env: - step1 = torch.ones( - 100, - 100, - ) - step2 = torch.triu(step1, diagonal=1) - step3 = step2.to(dtype=torch.bool, device='jax') - step3.to('jax') - self.assertEqual(step3.device.type, 'jax') - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_mutations.py b/torchax/test/test_mutations.py deleted file mode 100644 index ccbc359485c8..000000000000 --- a/torchax/test/test_mutations.py +++ /dev/null @@ -1,52 +0,0 @@ -import unittest -import torchax -import torch -from torch.testing._internal.common_utils import TestCase - - -class TestMutations(TestCase): - - def setUp(self): - self.env = torchax.tensor.Environment() - self.env.config.debug_print_each_op = True - - def test_add(self): - with self.env: - x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32) - y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32) - x.add_(y) - torch.testing.assert_close(x.cpu(), - torch.tensor([5, 7, 9], dtype=torch.int32)) - - def test_sub(self): - with self.env: - x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32) - y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32) - x.sub_(y) - torch.testing.assert_close(x.cpu(), - torch.tensor([-3, -3, -3], dtype=torch.int32)) - - def test_mul(self): - with self.env: - x = torch.tensor([1, 2, 3], device='jax', dtype=torch.int32) - y = torch.tensor([4, 5, 6], device='jax', dtype=torch.int32) - - x.mul_(y) - torch.testing.assert_close(x.cpu(), - torch.tensor([4, 10, 18], dtype=torch.int32)) - - def test_index_copy(self): - with self.env: - x = torch.zeros(5, 3, device='jax') - t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], - device='jax', - dtype=torch.float) - index = torch.tensor([0, 4, 2], device='jax') - x.index_copy_(0, index, t) - expected = torch.tensor([[1., 2., 3.], [0., 0., 0.], [7., 8., 9.], - [0., 0., 0.], [4., 5., 6.]]) - torch.testing.assert_close(x.cpu(), expected) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py deleted file mode 100644 index 54ef1c30b5a3..000000000000 --- a/torchax/test/test_ops.py +++ /dev/null @@ -1,235 +0,0 @@ -import unittest - -import torch -from torch.testing._internal.common_utils import TestCase -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, ops) -from torch.utils import _pytree as pytree -from torchax import tensor -import torchax - -skiplist = { - "_segment_reduce", - "bincount", # NOTE: dtype for int input torch gives float. This is weird. - "byte", - "cholesky_solve", - "geqrf", - "histogram", # hard op: AssertionError: Tensor-likes are not close! - "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. - "index_reduce", - "linalg.ldl_solve", - "max_pool2d_with_indices_backward", - "nn.functional.adaptive_max_pool1d", - "nn.functional.adaptive_max_pool2d", - "nn.functional.adaptive_max_pool3d", - "nn.functional.alpha_dropout", - "nn.functional.ctc_loss", - "nn.functional.embedding_bag", - "nn.functional.fractional_max_pool2d", - "nn.functional.fractional_max_pool3d", - "nn.functional.multi_head_attention_forward", - "normal", - "ormqr", - "pca_lowrank", - "searchsorted", - "special.airy_ai", - "special.scaled_modified_bessel_k0", - "special.scaled_modified_bessel_k1", - "special.spherical_bessel_j0", - "special.zeta", - "unfold_copy", - "unfold", -} - -not_support_ops_list = { - "chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180 - "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior - # such as 0 to negative power. - "to_sparse", # We are not supporting sparse tensors yet. - "nn.functional.rrelu", # pure torch result match torchax test result, only OpInfo mismatch: https://gist.github.com/ManfeiBai/1a449b15f4e946bfcaa3e5ef86da20f4 -} - -# These inputs are themselves views -# We cannot know how are the views created so cannot replicate the behavior. -variant_test_name_to_skip = { - "partial_views", -} - -random_ops = { - 'empty', - 'empty_like', - 'empty_permuted', - 'empty_strided', - 'bernoulli', - 'geometric', - 'new_empty', - 'new_empty_strided', - 'randint_like', - 'randn', - 'randn_like', - 'rand', - 'rand_like', - 'uniform', - 'multinomial', - # Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html - 'nn.functional.feature_alpha_dropout', - 'cauchy', - 'exponential', - 'log_normal', - 'randint', - 'nn.functional.dropout2d', - 'nn.functional.dropout3d', - 'nn.functional.dropout', -} - -atol_dict = { - "cov": (2e-1, 2e-4), - "linalg.eig": (2e0, 3e0), - "linalg.eigh": (5e1, 3e0), - "linalg.eigvalsh": (5e1, 3e0), - "linalg.pinv": (8e-1, 2e0), - "linalg.svd": (1e0, 1e0), - "svd": (1e0, 1e0), - "svd_lowrank": (1e0, 1e0), - "matrix_exp": (2e-1, 2e-4), - "cdist": (5e1, 3e0) -} - - -def diff_output(testcase, - output1, - output2, - rtol, - atol, - equal_nan=True, - check_output=True): - if isinstance(output1, torch.Tensor): - testcase.assertIsInstance(output2, torch.Tensor) - output2_cpu = output2.detach().cpu() - if output1.layout != torch.strided: - # We only compare dense tensors. We dont currently support sparse tensors - output1 = output1.to_dense() - if check_output: - torch.testing.assert_close( - output2_cpu, output1, rtol=rtol, atol=atol, equal_nan=equal_nan) - else: - testcase.assertEqual((output1.shape, output1.dtype), - (output2.shape, output2.dtype)) - elif isinstance(output1, (tuple, list)): - testcase.assertIsInstance(output2, (tuple, list)) - testcase.assertEqual(len(output1), len(output2)) - for o1, o2 in zip(output1, output2): - diff_output(testcase, o1, o2, rtol, atol) - else: - testcase.assertEqual(output1, output2) - - -def run_export_and_compare(testcase, - func, - sample_input, - check_output=True, - equal_nan=True, - ignore_indices=False): - atol, rtol = (1e-3, 1e-5) - if func.name in atol_dict: - atol, rtol = atol_dict[func.name] - - with testcase.subTest("torch_eval"): - res = func(sample_input.input, *sample_input.args, **sample_input.kwargs) - with testcase.subTest("torchax_eval"): - input2, args2, kwargs2 = testcase.env.to_xla( - (sample_input.input, sample_input.args, sample_input.kwargs)) - if 'device' in kwargs2: - kwargs2['device'] = 'jax' - with testcase.env: - res2 = func(input2, *args2, **kwargs2) - res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2) - with testcase.subTest("torchax_diff:" + str(atol)): - if ignore_indices and isinstance(res, tuple) and len(res) == 2: - diff_output( - testcase, - res[0], - res2[0], - atol=atol, - rtol=rtol, - equal_nan=equal_nan, - check_output=check_output) - else: - diff_output( - testcase, - res, - res2, - atol=atol, - rtol=rtol, - equal_nan=equal_nan, - check_output=check_output) - - -ops_to_test = [ - test for test in op_db - if (test.name not in (skiplist | not_support_ops_list) and - test.variant_test_name not in variant_test_name_to_skip) -] - -# Sort related ops should ignore index; -# For example: sort( [1, 0, 0]) -> [0, 0, 1] -# the correct index can be [1, 2, 0] or [2, 1, 0] -should_ignore_indexes = {"topk", "mode", "kthvalue"} - - -class TestOpInfo(TestCase): - - @classmethod - def setUpClass(cls): - print('op_db size: ', len(op_db), 'testing: ', len(ops_to_test)) - - def setUp(self): - self.env = torchax.default_env() - torchax.enable_accuracy_mode() - #self.env.config.debug_accuracy_for_each_op = True - self.env.config.debug_print_each_op = False - torch.manual_seed(0) - - # Replaces all values in the input torch_tensor that are less than the given threshold - # with the threshold value itself. - def replace_values_below_threshold(self, torch_tensor, threshold): - return torch.where(torch_tensor < threshold, torch.tensor(threshold), - torch_tensor) - - @ops(ops_to_test, allowed_dtypes=(torch.float32, torch.long)) - def test_reference_eager(self, device, dtype, op): - sample_inputs = op.sample_inputs(device, dtype) - for sample_input in sample_inputs: - t = sample_input.input - if isinstance(t, torch.Tensor) and t.is_sparse: - continue - check_output = op.name not in random_ops - - # print("[DEBUG] sample_input: ", sample_input) - - # TODO: this is a workaround to skip int64 cast for linspace - # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments - # we have opened a bug in pytorch: https://github.com/pytorch/pytorch/issues/137546 - if op.name == "linspace": - if 'dtype' in sample_input.kwargs: - if sample_input.kwargs['dtype'] == torch.int64: - sample_input.kwargs['dtype'] = torch.float - if op.name == "polygamma" or op.name == "special.polygamma": - # The polygamma function is inaccurate for values < 1. - # To avoid errors during testing, replace values below 1 with 1. - sample_input.input = self.replace_values_below_threshold( - sample_input.input, 1) - if op.name == "nn.functional.scaled_dot_product_attention": - check_output = sample_input.kwargs.get('dropout_p') == 0.0 - - ignore_index = op.name in should_ignore_indexes - - run_export_and_compare( - self, op, sample_input, check_output, ignore_indices=ignore_index) - - -instantiate_device_type_tests(TestOpInfo, globals(), only_for={'cpu'}) - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_symbolic_shapes.py b/torchax/test/test_symbolic_shapes.py deleted file mode 100644 index 94233ca6a07a..000000000000 --- a/torchax/test/test_symbolic_shapes.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import torchax -import torchax.export -from . import base_test_util - - -class AddOne(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, a): - return a + 1 - - -class ConcatAddModel(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, a, b): - a = torch.concat([a, a], dim=0) - return a + b - - -class SymbolicShapeTest(base_test_util.TestCase): - """Test possible symbolic shape computations that upstream torch export can - emit. Seems to be currently limited to a few binary math operations where one - operand is a symbolic variable/expr and the other is a constant integer. - """ - - def setUp(self): - torch.manual_seed(0) - - def test_constraints_min_max(self): - """Test a model with basic min/max dimension restrictions - """ - - # Arg shapes are a=s0{<=10}, b=s0*2 - model = AddOne() - args = (torch.rand(5),) - sym_a = torch.export.Dim("a", min=3, max=10) - dynamic_shapes = ({0: sym_a},) - - with torch.no_grad(): - exported = torch.export.export( - model, args=args, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - - self.assertRegex(module_str, r"stablehlo.constant.*3") - self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ >= 3") - self.assertRegex(module_str, r"stablehlo.constant.*10") - self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10") - - def test_constraints_multiply(self): - """Test a model with a slightly more complex constraint, where the input - shapes are determined by an equation of the other, in this case input shapes - are s0{<=10} and s0*2. - """ - # Arg shapes are a=s0{<=10}, b=s0*2 - model = ConcatAddModel() - args = (torch.rand(2), torch.rand(4)) - sym_a = torch.export.Dim("a", max=10) - sym_b = sym_a * 2 - dynamic_shapes = ({0: sym_a}, {0: sym_b}) - - with torch.no_grad(): - exported = torch.export.export( - model, args=args, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - - self.assertRegex(module_str, r"stablehlo.constant.*10") - self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10") - self.assertRegex(module_str, r"stablehlo.constant.*2") - self.assertRegex(module_str, r"shape_assertion.*2\*s[0-9]+") - - def test_constraint_indirection(self): - """Test a model where none of the shapes are directly symbolic variables - but all are expressions of symints that don't appear directly in the model. - """ - - # Arg shapes are b=s0{<=10}*2 - args = (torch.randn(10, 10),) - model = AddOne() - sym_a = torch.export.Dim("a", max=10) - sym_b = sym_a * 2 - dynamic_shapes = ({0: sym_b},) - - with torch.no_grad(): - exported = torch.export.export( - model, args=args, dynamic_shapes=dynamic_shapes) - weights, stablehlo = torchax.export.exported_program_to_stablehlo(exported) - module_str = str(stablehlo.mlir_module()) - - self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10") - self.assertRegex(module_str, r"shape_assertion.*2\*s[0-9]+") - - -if __name__ == "__main__": - base_test_util.main() diff --git a/torchax/test/test_train.py b/torchax/test/test_train.py deleted file mode 100644 index ce07b8e3d447..000000000000 --- a/torchax/test/test_train.py +++ /dev/null @@ -1,58 +0,0 @@ -import unittest -import torch -import torchax as tx -import torchax.export -import torchax.train -from torch.testing._internal.common_utils import TestCase - - -class TrainTest(unittest.TestCase): - - def setUp(self): - torch.manual_seed(0) - torchax.enable_accuracy_mode() - - def test_scan_module(self): - x = torch.arange(300).reshape(3, 100).to(torch.float32) - layers = [ - torch.nn.Linear(100, 100), - torch.nn.Linear(100, 100), - torch.nn.Linear(100, 100), - torch.nn.Linear(100, 100), - ] - # repetitively applies the linear - result = x - for layer in layers: - result = layer(result) - - model = tx.train.ScannedModule(layers) - - with torchax.default_env(): - x = x.to('jax') - model.to('jax') - result2 = model(x) - torch.testing.assert_allclose(result, result2.to('cpu')) - - def test_train_step_can_run(self): - import optax - with torchax.default_env(): - model = torch.nn.Linear(100, 100) - model.to('jax') - weights = model.state_dict() - x = torch.randn(2, 100).to('jax') - y = torch.tensor([1, 2]).to('jax') - - def model_fn(weight, buffers, args): - return torch.func.functional_call(model, weight, args) - - loss_fn = torch.nn.CrossEntropyLoss() - - optimizer = optax.adam(0.01) - opt_state = tx.interop.call_jax(optimizer.init, weights) - - step = tx.train.make_train_step(model_fn, loss_fn, optimizer) - print(step(weights, {}, opt_state, x, y)) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/test/test_unbounded_dynamism.py b/torchax/test/test_unbounded_dynamism.py deleted file mode 100644 index d15b1750678a..000000000000 --- a/torchax/test/test_unbounded_dynamism.py +++ /dev/null @@ -1,673 +0,0 @@ -import re -import sys -import unittest - -import torch -from torch.export import Dim, export -from torchax.export import exported_program_to_stablehlo as exp2shlo -import torchax - -## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py` -## To test that torchax has identical behavior. -## The only differences in this test files are that torchax export preserves -## argument order more often than torch_xla export. -## -## This broke ~5 tests, for example: test_bmm_dynamic_out_dim -## args = ( -## torch.rand((8, 128, 256)), -## torch.rand((8, 256, 3)), -## ) -## dynamic_shapes = ((None, {2: Dim("dim")}),) -## ... -## torch_xla_regex = r'%arg.: tensor<8x256x\?xf32>.*%arg.: tensor<8x128x256xf32>.*->.*tensor<8x128x\?xf32>' -## torchax_regex = r'%arg.: tensor<8x128x256xf32>.*%arg.: tensor<8x256x\?xf32>.*->.*tensor<8x128x\?xf32>' - - -# Shim to run tests -class ExportAdapter(): - - def __init__(self, export): - self.export = export - - def get_stablehlo_text(self): - return self.export.mlir_module() - - -def exported_program_to_stablehlo(exported): - return ExportAdapter(exp2shlo(exported)[1]) - - -def wrap_func_as_nn_module(f): - - class M(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, *args): - return f(*args) - - return M().eval() - - -class UnboundedDynamismExportTest(unittest.TestCase): - - def setUp(self): - torchax.enable_accuracy_mode() - - def test_add(self): - args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) - dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.add.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'tensor<\?x197x768xf32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>', - shlo_text) is not None) - - def test_add_scalar(self): - args = (torch.rand((10, 197, 768)), 0.345) - dynamic_shapes = (({0: Dim("dim")}, None),) - m = wrap_func_as_nn_module(torch.ops.aten.add.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>', - shlo_text) is not None) - - def test_addmm(self): - args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5))) - dynamic_shapes = ((None, {0: Dim("dim")}, None),) - m = wrap_func_as_nn_module(torch.ops.aten.addmm.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x5xf32>.*->.*tensor<\?x5xf32>', shlo_text) - is not None) - - def test_bmm(self): - args = ( - torch.rand((24, 197, 64)), - torch.rand((24, 64, 197)), - ) - dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'%arg.: tensor<\?x197x64xf32>.*%arg.: tensor<\?x64x197xf32>.*->.*tensor<\?x197x197xf32>', - shlo_text) is not None) - - def test_bmm_dynamic_out_dim(self): - args = ( - torch.rand((8, 128, 256)), - torch.rand((8, 256, 3)), - ) - dynamic_shapes = ((None, {2: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'%arg.: tensor<8x128x256xf32>.*%arg.: tensor<8x256x\?xf32>.*->.*tensor<8x128x\?xf32>', - shlo_text) is not None) - - def test_bmm_dynamic_reduction_dim(self): - args = ( - torch.rand((8, 128, 3)), - torch.rand((8, 3, 256)), - ) - dynamic_shapes = (({2: Dim("dim")}, {1: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.bmm.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'%arg.: tensor<8x128x\?xf32>.*%arg.: tensor<8x\?x256xf32>.*->.*tensor<8x128x256xf32>', - shlo_text) is not None) - - def test_cat(self): - args = (torch.rand((10, 1, 768)), torch.rand((10, 196, 768))) - dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) - m = wrap_func_as_nn_module( - lambda x, y: torch.ops.aten.cat.default([x, y], 1)) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'%arg.: tensor<\?x1x768xf32>.*%arg.: tensor<\?x196x768xf32>.*->.*tensor<\?x197x768xf32>', - shlo_text) is not None) - - def test_conv(self): - args = ( - torch.rand((10, 3, 224, 224)), - torch.rand((5, 3, 16, 16)), - torch.rand((5)), - ) - dynamic_shapes = (({0: Dim("dim")}, None, None),) - m = wrap_func_as_nn_module( - lambda x, y, z: torch.ops.aten.convolution.default( - x, - y, - z, - [16, 16], - [0, 0], - [1, 1], - False, - [0, 0], - 1, - )) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x14x14xf32>', - shlo_text) is not None) - - def test_conv1d(self): - args = ( - torch.rand((3, 1, 800)), - torch.rand((512, 1, 10)), - ) - dynamic_shapes = (({0: Dim("dim")}, None),) - # dynamic_shapes = None - m = wrap_func_as_nn_module(lambda x, y: torch.ops.aten.convolution.default( - x, - y, - None, - [5], - [0], - [1], - False, - [0], - 1, - )) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x1x800xf32>.*->.*tensor<\?x512x159xf32>', - shlo_text) is not None) - - def test_cumsum(self): - args = (torch.rand((10, 5)), 1) - dynamic_shapes = (({0: Dim("dim")}, None),) - m = wrap_func_as_nn_module(torch.ops.aten.cumsum.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x5xf32>.*->.*tensor<\?x5xf32>', shlo_text) - is not None) - - def test_div(self): - args = (torch.rand((10, 12, 197)), torch.rand((10, 12, 197))) - dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'tensor<\?x12x197xf32>.*tensor<\?x12x197xf32>.*->.*tensor<\?x12x197xf32>', - shlo_text) is not None) - - def test_div_scalar(self): - args = (torch.rand((10, 12, 197)), 8.0) - dynamic_shapes = (({0: Dim("dim")}, None),) - m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x12x197xf32>.*->.*tensor<\?x12x197xf32>', - shlo_text) is not None) - - def test_gelu(self): - args = (torch.rand((3, 5)),) - dynamic_shapes = (({0: Dim("dim")},),) - m = wrap_func_as_nn_module(torch.ops.aten.gelu) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x5xf32>.*->.*tensor<\?x5xf32>', shlo_text) - is not None) - - def test_embedding(self): - - class M(torch.nn.Module): - - def forward(self, x, y): - res = torch.ops.aten.embedding.default(x, y) - return res - - args = (torch.rand((20, 768)), torch.randint(0, 15, - (3, 10)).to(torch.int64)) - dynamic_shapes = (None, {0: Dim("dim")}) - m = M() - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x10xi64>.*->.*tensor<\?x10x768xf32>", - shlo_text) is not None) - - def test_mean(self): - - class M(torch.nn.Module): - - def forward(self, x): - return torch.mean(x, -1, keepdim=True) - - args = (torch.rand((10, 197, 768)),) - dynamic_shapes = ({0: Dim("dim")},) - m = M() - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x197x1xf32>", - shlo_text) is not None) - - def test_mul(self): - args = (torch.rand((10, 2, 768)), torch.rand((10, 2, 768))) - dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'tensor<\?x2x768xf32>.*tensor<\?x2x768xf32>.*->.*tensor<\?x2x768xf32>', - shlo_text) is not None) - - def test_mul_scalar(self): - args = (torch.rand((10, 2, 768)), 0.125) - dynamic_shapes = (({0: Dim("dim")}, None),) - m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x2x768xf32>.*->.*tensor<\?x2x768xf32>', shlo_text) - is not None) - - def test_ne_scalar(self): - - class M(torch.nn.Module): - - def forward(self, x): - return torch.ops.aten.ne.Scalar(x, 1).to(torch.int32) - - args = (torch.rand((3, 5)).to(torch.int64),) - dynamic_shapes = ({0: Dim("dim")},) - m = M() - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x5xi64>.*->.*tensor<\?x5xi32>", shlo_text) - is not None) - - def test_var(self): - - class M(torch.nn.Module): - - def forward(self, x): - return torch.var(x, -1, keepdim=True, correction=0) - - args = (torch.rand((10, 197, 768)),) - dynamic_shapes = ({0: Dim("dim")},) - m = M() - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x197x1xf32>", - shlo_text) is not None) - - def test_native_group_norm(self): - - class M2(torch.nn.Module): - - def __init__(self): - super().__init__() - self.layer_norm = torch.nn.GroupNorm( - num_groups=512, num_channels=512, affine=True) - - def forward(self, x): - x = self.layer_norm(x) - return x - - args = (torch.rand((10, 512, 159)),) - dynamic_shapes = ({0: Dim("dim")},) - m = M2() - out1 = m(*args) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x512x159xf32>.*->.*tensor<\?x512x159xf32>", - shlo_text) is not None) - - def test_native_layer_norm(self): - - class M(torch.nn.Module): - - def forward(self, x, weight, bias): - return torch.ops.aten.native_layer_norm.default(x, [768], weight, bias, - 1e-12)[0] - - args = ( - torch.rand((10, 197, 768)), - torch.rand((768)), - torch.rand((768)), - ) - dynamic_shapes = ({0: Dim("dim")}, None, None) - m = M() - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>", - shlo_text) is not None) - - def test_permute(self): - args = (torch.rand((10, 197, 12, 64)),) - dynamic_shapes = (({0: Dim("dim")},),) - m = wrap_func_as_nn_module( - lambda x: torch.ops.aten.permute.default(x, [0, 2, 1, 3])) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x197x12x64xf32>.*->.*tensor<\?x12x197x64xf32>", - shlo_text) is not None) - - def test_select(self): - args = (torch.rand((10, 197, 768)), 1, 0) - dynamic_shapes = (({0: Dim("dim")}, None, None),) - m = wrap_func_as_nn_module(torch.ops.aten.select.int) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x768xf32>", - shlo_text) is not None) - - def test_slice(self): - args = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) - dynamic_shapes = (({0: Dim("dim")}, None, None, None),) - m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x3x224x224xf32>", - shlo_text) is not None) - - def test_slice_2(self): - args = (torch.rand((10, 3, 224, 224)), 1, 0, 2) - dynamic_shapes = (({0: Dim("dim")}, None, None, None),) - m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x2x224x224xf32>", - shlo_text) is not None) - - def test_softmax(self): - args = (torch.rand((10, 12, 197, 197)), -1, False) - dynamic_shapes = (({0: Dim("dim")}, None, None),) - m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x12x197x197xf32>.*->.*tensor<\?x12x197x197xf32>", - shlo_text) is not None) - - def test_sub(self): - args = (torch.rand((10, 1, 1, 10)), torch.rand((10, 1, 1, 10))) - dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r'tensor<\?x1x1x10xf32>.*tensor<\?x1x1x10xf32>.*->.*tensor<\?x1x1x10xf32>', - shlo_text) is not None) - - def test_softmax_reduce_on_dynamic_dim(self): - args = (torch.rand((1, 8, 128, 3)), -1, False) - dynamic_shapes = (({3: Dim("dim")}, None, None),) - m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<1x8x128x\?xf32>.*->.*tensor<1x8x128x\?xf32>", - shlo_text) is not None) - - @unittest.skip("Converted StableHLO contains i1 dtype, not expected.") - def test_index(self): - args = (torch.rand((2, 10)), torch.arange(5)) - dynamic_shapes = ((None, {0: Dim("dim")}),) - m = wrap_func_as_nn_module( - lambda x, y: torch.ops.aten.index.Tensor(x, [None, y])) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?xi64>.*%arg.: tensor<2x10xf32>.*->.*tensor<2x\?xf32>", - shlo_text) is not None) - - def test_sub_scalar(self): - args = (1.0, torch.rand((10, 1, 1, 10))) - dynamic_shapes = ((None, {0: Dim("dim")}),) - m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r'tensor<\?x1x1x10xf32>.*->.*tensor<\?x1x1x10xf32>', - shlo_text) is not None) - - def test_split_with_sizes(self): - - class M(torch.nn.Module): - - def forward(self, x): - res = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1) - return res[0], res[1], res[2] - - args = (torch.rand((3, 10, 6)),) - dynamic_shapes = ({0: Dim("dim")},) - m = M() - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x10x6xf32>.*->.*tensor<\?x10x1xf32>.*tensor<\?x10x2xf32>.*tensor<\?x10x3xf32>", - shlo_text) is not None) - - def test_transpose_on_dynamic_dim(self): - args = (torch.rand((1, 8, 3, 256)),) - dynamic_shapes = (({2: Dim("dim")},),) - m = wrap_func_as_nn_module( - lambda x: torch.ops.aten.transpose.int(x, -2, -1)) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<1x8x\?x256xf32>.*->.*tensor<1x8x256x\?xf32>", - shlo_text) is not None) - - def test_unsqueeze_1(self): - args = (torch.rand((3, 10)),) - dynamic_shapes = (({0: Dim("dim")},),) - m = wrap_func_as_nn_module(lambda x: torch.ops.aten.unsqueeze.default(x, 1)) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x10xf32>.*->.*tensor<\?x1x10xf32>", - shlo_text) is not None) - - def test_unsqueeze_2(self): - args = (torch.rand((1, 1, 3, 256)),) - dynamic_shapes = (({2: Dim("dim")},),) - m = wrap_func_as_nn_module(lambda x: torch.ops.aten.unsqueeze.default(x, 2)) - ep = export(m, args=args, dynamic_shapes=dynamic_shapes) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<1x1x\?x256xf32>.*->.*tensor<1x1x1x\?x256xf32>", - shlo_text) is not None) - - def test_dynamic_view(self): - - class M(torch.nn.Module): - - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 5, [16, 16]) - - def forward(self, x): - x = self.conv(x) - return x.view(x.shape[0], x.shape[1], -1) - - m = M().eval() - args = (torch.rand((10, 3, 224, 224)),) - dynamic_shapes = ({0: Dim("bs")},) - ep = export(m, args, dynamic_shapes=dynamic_shapes) - out1 = ep.module()(*args) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x43681xf32>", - shlo_text) is not None) - - @unittest.skip("Cannot generate aten.sym_numel in the exported program.") - def test_dynamic_view_sym_numel(self): - - class M(torch.nn.Module): - - def forward(self, x, range): - num_elem = torch.numel(range) - return x.view(x.shape[0], x.shape[2], num_elem, x.shape[4]) - - m = M().eval() - args = (torch.rand((1, 1, 8, 3, 256)), torch.arange(3)) - dynamic_shapes = ({3: Dim("bs")}, {0: Dim("bs")}) - ep = export(m, args, dynamic_shapes=dynamic_shapes) - out1 = ep.module()(*args) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x43681xf32>", - shlo_text) is not None) - - def test_dynamic_view_non_bs(self): - - class M(torch.nn.Module): - - def forward(self, x): - return x.view(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) - - m = M().eval() - args = (torch.rand((1, 3, 2, 16)),) - dynamic_shapes = ({1: Dim("bs")},) - ep = export(m, args, dynamic_shapes=dynamic_shapes) - out1 = ep.module()(*args) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<1x\?x2x16xf32>.*->.*tensor<1x\?x16xf32>", - shlo_text) is not None) - - def test_dynamic_view_multiplier(self): - - class M(torch.nn.Module): - - def forward(self, x): - return x.view(x.shape[0] * x.shape[1], -1) - - m = M().eval() - args = (torch.rand((10, 197, 768)),) - dynamic_shapes = ({0: Dim("bs")},) - ep = export(m, args, dynamic_shapes=dynamic_shapes) - out1 = ep.module()(*args) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x768xf32>", - shlo_text) is not None) - - def test_dynamic_expand(self): - - class M(torch.nn.Module): - - def forward(self, x, image): - return x.expand(image.shape[0], -1, -1) - - m = M().eval() - args = (torch.rand((1, 1, 768)), torch.rand((10, 3, 224, 224))) - dynamic_shapes = ( - None, - { - 0: Dim("bs") - }, - ) - ep = export(m, args, dynamic_shapes=dynamic_shapes) - out1 = ep.module()(*args) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search(r"%arg.: tensor<1x1x768xf32>.*->.*tensor<\?x1x768xf32>", - shlo_text) is not None) - - def test_dynamic_expand_2(self): - - class M(torch.nn.Module): - - def forward(self, x, range): - return x.expand(1, 1, 8, range.shape[0], 256) - - m = M().eval() - args = (torch.rand((1, 1, 1, 3, 256)), torch.arange(3)) - dynamic_shapes = ({3: Dim("bs")}, {0: Dim("bs")}) - ep = export(m, args, dynamic_shapes=dynamic_shapes) - out1 = ep.module()(*args) - shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text() - self.assertTrue( - re.search( - r"%arg.: tensor<1x1x1x\?x256xf32>.*->.*tensor<1x1x8x\?x256xf32>", - shlo_text) is not None) - - -if __name__ == "__main__": - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torchax/test/test_util.py b/torchax/test/test_util.py deleted file mode 100644 index 9cbb24042f57..000000000000 --- a/torchax/test/test_util.py +++ /dev/null @@ -1,118 +0,0 @@ -import unittest -from torchax.util import partition, merge - - -# Helper predicate functions for testing partition -def is_even(n): - return isinstance(n, int) and n % 2 == 0 - - -def is_positive(n): - return isinstance(n, (int, float)) and n > 0 - - -def is_string(s): - return isinstance(s, str) - - -class TestListUtils(unittest.TestCase): - - # --- Tests for partition --- - - def test_partition_empty_list(self): - """Test partition with an empty list.""" - self.assertEqual(partition([], is_even), ([], [])) - - def test_partition_even_odd(self): - """Test partitioning numbers into even and odd.""" - nums = [1, 2, 3, 4, 5, 6] - expected_truthy = [None, 2, None, 4, None, 6] - expected_falsy = [1, None, 3, None, 5, None] - self.assertEqual( - partition(nums, is_even), (expected_truthy, expected_falsy)) - - def test_partition_all_true(self): - """Test partition when the predicate is always true.""" - evens = [2, 4, 6, 8] - expected_truthy = [2, 4, 6, 8] - expected_falsy = [None, None, None, None] - self.assertEqual( - partition(evens, is_even), (expected_truthy, expected_falsy)) - - def test_partition_all_false(self): - """Test partition when the predicate is always false.""" - odds = [1, 3, 5, 7] - expected_truthy = [None, None, None, None] - expected_falsy = [1, 3, 5, 7] - self.assertEqual( - partition(odds, is_even), (expected_truthy, expected_falsy)) - - def test_partition_mixed_types(self): - """Test partition with a list of mixed types.""" - mixed = [1, "hello", 2.5, "world", 3, None] - # Using is_string as the predicate - expected_truthy = [None, "hello", None, "world", None, None] - expected_falsy = [1, None, 2.5, None, 3, - None] # Note: None itself is not a string - self.assertEqual( - partition(mixed, is_string), (expected_truthy, expected_falsy)) - - def test_partition_with_lambda(self): - """Test partition using a lambda function as the predicate.""" - nums = [-2, -1, 0, 1, 2] - expected_truthy = [None, None, None, 1, 2] - expected_falsy = [-2, -1, 0, None, None] - self.assertEqual( - partition(nums, lambda x: isinstance(x, int) and x > 0), - (expected_truthy, expected_falsy)) - - # --- Tests for merge --- - - def test_merge_empty_lists(self): - """Test merge with empty lists.""" - self.assertEqual(merge([], []), []) - - def test_merge_basic(self): - """Test basic merging with None values in the first list.""" - list1 = [1, None, 3, None, 5] - list2 = [None, 2, None, 4, None] - expected = [1, 2, 3, 4, 5] - self.assertEqual(merge(list1, list2), expected) - - def test_merge_no_none_in_list1(self): - """Test merge when list1 has no None values.""" - list1 = ['a', 'b', 'c'] - list2 = [1, 2, 3] - expected = ['a', 'b', 'c'] # Should be identical to list1 - self.assertEqual(merge(list1, list2), expected) - - def test_merge_all_none_in_list1(self): - """Test merge when list1 contains only None.""" - list1 = [None, None, None] - list2 = ['x', 'y', 'z'] - expected = ['x', 'y', 'z'] # Should be identical to list2 - self.assertEqual(merge(list1, list2), expected) - - def test_merge_mixed_types(self): - """Test merge with mixed data types.""" - list1 = [1, None, "hello", None] - list2 = [None, 2.5, None, True] - expected = [1, 2.5, "hello", True] - self.assertEqual(merge(list1, list2), expected) - - def test_merge_unequal_lengths(self): - """Test that merge raises AssertionError for lists of unequal length.""" - list1 = [1, 2, 3] - list2 = [4, 5] - # Use assertRaises as a context manager - with self.assertRaises(AssertionError) as cm: - merge(list1, list2) - - list3 = [6, 7] - list4 = [8, 9, 10] - with self.assertRaises(AssertionError): - merge(list3, list4) # No need to check message again if already checked - - -if __name__ == '__main__': - unittest.main() # For running from command line diff --git a/torchax/test/test_view.py b/torchax/test/test_view.py deleted file mode 100644 index 3f5caee5f1f7..000000000000 --- a/torchax/test/test_view.py +++ /dev/null @@ -1,385 +0,0 @@ -import torch -import torchax -import re -import sys -import unittest - -from torchax.tensor import Tensor -from torchax.view import View - - -class TrainTest(unittest.TestCase): - - def setUp(self): - torch.manual_seed(0) - torchax.enable_globally() - - def test_copy_(self): - x = torch.zeros((10, 10), device="jax") - y = torch.ones((5, 5), device="jax") - x[0:5, :][:, 0:5].copy_(y[:, :]) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), 25) - self.assertEqual(x.sum(), 25) - - def test_transivity(self): - x = torch.zeros((10, 10), device="jax") - x_view = x[0:5, :][:, 0:5].add_(1) - y_view = x_view[0:5, :][:, 0:5].add_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(type(x_view), View) - self.assertEqual(type(y_view), View) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), 50) - self.assertEqual(x.sum(), 50) - - def test_outofplace_add(self): - x = torch.zeros((10, 10), device="jax") - x2 = x[0:5, :][:, 0:5].add(1) - x3 = x2[0:5, :][:, 0:5].add_(x2) - self.assertEqual(type(x), Tensor) - self.assertEqual(type(x2), Tensor) - self.assertEqual(type(x3), View) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 0) - self.assertEqual(x2.sum(), 50) - - def test_multiply_tensor_and_view(self): - x = torch.ones((10, 10), device="jax") * 2 - y = torch.ones((10, 10), device="jax") - x1 = x[:, :] - res = x1.mul(y) - self.assertEqual(type(x), Tensor) - self.assertEqual(type(y), Tensor) - self.assertEqual(type(x1), View) - self.assertEqual(type(res), Tensor) - self.assertEqual(res.sum(), 200) - - def test_multiply_views(self): - x = torch.ones((10, 10), device="jax") * 2 - y = torch.ones((10, 10), device="jax") - x1 = x[0:1, :] - y1 = y[0:1, :] - res = x1.mul(y1) - self.assertEqual(type(x), Tensor) - self.assertEqual(type(y), Tensor) - self.assertEqual(type(x1), View) - self.assertEqual(type(y1), View) - self.assertEqual(type(res), Tensor) - self.assertEqual(res.sum(), 20) - - def test_setitem(self): - a = torch.zeros(10, device="jax") - a[0:5][0:3] = 1 - self.assertEqual(type(a), Tensor) - self.assertEqual(a.shape, (10,)) - self.assertEqual(a.sum(), 3) - - # Test all in-place operations - def test_add_(self): - x = torch.zeros((10, 10), device="jax") - x[0:5, :][:, 0:5].add_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 25) - - def test_sub_(self): - x = torch.zeros((10, 10), device="jax") - x[0:5, :][:, 0:5].sub_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), -25) - - def test_mul_(self): - x = torch.ones((10, 10), device="jax") - x[0:5, :][:, 0:5].mul_(2) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 125) - - def test_div_(self): - x = torch.ones((10, 10), device="jax") - x[0:10, :][:, 0:10].div_(2) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 50) - - def test_pow_(self): - x = torch.full((10, 10), fill_value=2, device="jax") - x[0:5, :][:, 0:5].pow_(2) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 250) - - def test_clamp_(self): - x = torch.arange(100, device="jax", dtype=torch.float).reshape(10, 10) - x[0:5, :][:, 0:5].clamp_(min=50, max=80) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertTrue((x[0:5, 0:5] >= 50).all()) - self.assertTrue((x[0:5, 0:5] <= 80).all()) - - def test_lt_(self): - x = torch.ones((10, 10), device="jax") - y = torch.zeros((10, 10), device="jax") - x[0:5, :][:, 0:5].lt_(0.5) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), - 0) # All False (0) in the modified region - self.assertEqual(x[5:, 5:].sum(), - 25) # All True (1) in the unmodified region - - def test_le_(self): - x = torch.ones((10, 10), device="jax") - x[0:5, :][:, 0:5].le_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 100) # All True (1) - - def test_gt_(self): - x = torch.ones((10, 10), device="jax") - x[0:5, :][:, 0:5].gt_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), - 0) # All False (0) in the modified region - self.assertEqual(x.sum(), 75) # Only the unmodified region is True (1) - - def test_ge_(self): - x = torch.ones((10, 10), device="jax") - x[0:5, :][:, 0:5].ge_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 100) # All True (1) - - def test_eq_(self): - x = torch.ones((10, 10), device="jax") - x[0:5, :][:, 0:5].eq_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 100) # All True (1) - - def test_ne_(self): - x = torch.ones((10, 10), device="jax") - x[0:5, :][:, 0:5].ne_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), - 0) # All False (0) in the modified region - self.assertEqual(x.sum(), 75) # Only the unmodified region is True (1) - - def test_bernoulli_(self): - # Set a fixed seed for deterministic behavior - torch.manual_seed(42) - x = torch.full((10, 10), fill_value=0.5, device="jax") - y = x[0:5, :][:, 0:5] - y.bernoulli_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Values will be 0 or 1 in the modified region - self.assertTrue(((x[0:5, 0:5] == 0) | (x[0:5, 0:5] == 1)).all()) - # Unmodified region remains 0.5 - self.assertTrue((x[5:, 5:] == 0.5).all()) - - def test_geometric_(self): - torch.manual_seed(42) - x = torch.full((10, 10), fill_value=0.5, device="jax") - y = x[0:5, :][:, 0:5] - y.geometric_(p=0.5) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Geometric distribution values are positive integers - self.assertTrue((x[0:5, 0:5] >= 1).all()) - # Unmodified region remains 0.5 - self.assertTrue((x[5:, 5:] == 0.5).all()) - - def test_normal_(self): - torch.manual_seed(42) - x = torch.zeros((10, 10), device="jax") - x[0:5, :][:, 0:5].normal_(mean=0, std=1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Unmodified region remains 0 - self.assertEqual(x[5:, 5:].sum(), 0) - - def test_uniform_(self): - torch.manual_seed(42) - x = torch.zeros((10, 10), device="jax") - x[0:5, :][:, 0:5].uniform_(0, 1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Values in modified region are between 0 and 1 - self.assertTrue((x[0:5, 0:5] >= 0).all()) - self.assertTrue((x[0:5, 0:5] <= 1).all()) - # Unmodified region remains 0 - self.assertEqual(x[5:, 5:].sum(), 0) - - def test_relu_(self): - x = torch.randn((10, 10), device="jax") - x_copy = x.clone() - x[0:5, :][:, 0:5].relu_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Modified region has no negative values - self.assertTrue((x[0:5, 0:5] >= 0).all()) - # Unmodified region remains the same - self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) - - def test_squeeze_(self): - x = torch.randn((10, 1, 10), device="jax") - x_clone = x.clone() - # Squeeze the middle dimension - x.squeeze_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Content should remain the same - self.assertTrue(torch.allclose(x, x_clone.squeeze())) - - def test_sqrt_(self): - x = torch.randn((10, 10), - device="jax").abs() # Use abs to ensure positive values - x_copy = x.clone() - x[0:5, :][:, 0:5].sqrt_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Modified region is square root of original values - self.assertTrue(torch.allclose(x[0:5, 0:5], torch.sqrt(x_copy[0:5, 0:5]))) - # Unmodified region remains the same - self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) - - def test_clamp_min_(self): - x = torch.randn((10, 10), device="jax") - x_copy = x.clone() - x[0:5, :][:, 0:5].clamp_min_(0) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Modified region has no values below 0 - self.assertTrue((x[0:5, 0:5] >= 0).all()) - # Unmodified region remains the same - self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) - - def test_sigmoid_(self): - x = torch.randn((10, 10), device="jax") - x_copy = x.clone() - x[0:5, :][:, 0:5].sigmoid_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Modified region values are between 0 and 1 - self.assertTrue((x[0:5, 0:5] >= 0).all()) - self.assertTrue((x[0:5, 0:5] <= 1).all()) - # Unmodified region remains the same - self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) - - def test_tanh_(self): - x = torch.randn((10, 10), device="jax") - x_copy = x.clone() - x[0:5, :][:, 0:5].tanh_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Modified region values are between -1 and 1 - self.assertTrue((x[0:5, 0:5] >= -1).all()) - self.assertTrue((x[0:5, 0:5] <= 1).all()) - # Unmodified region remains the same - self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) - - def test_ceil_(self): - x = torch.randn((10, 10), device="jax") - x_copy = x.clone() - x[0:5, :][:, 0:5].ceil_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Check that ceil operation was applied correctly - self.assertTrue(torch.allclose(x[0:5, 0:5], torch.ceil(x_copy[0:5, 0:5]))) - # Unmodified region remains the same - self.assertTrue(torch.equal(x[5:, 5:], x_copy[5:, 5:])) - - def test_logical_not_(self): - x = torch.zeros((10, 10), device="jax") - x[0:5, 0:5] = 1 # Set some values to 1 - x[0:5, :][:, 0:5].logical_not_() - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Modified region has all values flipped - self.assertEqual(x[0:5, 0:5].sum(), 0) # All now 0 - # Unmodified region remains 0 - self.assertEqual(x[5:, 5:].sum(), 0) - - def test_unsqueeze_(self): - x = torch.randn((10, 10), device="jax") - x_copy = x.clone() - # Add dimension at index 1 - x.unsqueeze_(1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 1, 10)) - # Content should remain the same - self.assertTrue(torch.equal(x.squeeze(1), x_copy)) - - def test_transpose_(self): - x = torch.randn((10, 5), device="jax") - x_copy = x.clone() - # Transpose dimensions 0 and 1 - x.transpose_(0, 1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (5, 10)) - # Check transposition worked correctly - self.assertTrue(torch.equal(x, x_copy.transpose(0, 1))) - - def test_log_normal_(self): - torch.manual_seed(42) - x = torch.zeros((10, 10), device="jax") - x[0:5, :][:, 0:5].log_normal_(mean=0, std=1) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - # Log-normal values are positive - self.assertTrue((x[0:5, 0:5] > 0).all()) - # Unmodified region remains 0 - self.assertEqual(x[5:, 5:].sum(), 0) - - def test_scatter_add_(self): - # Initialize test tensors - x = torch.zeros((5, 5), device="jax") - indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax") - values = torch.ones((2, 3), device="jax") - - # Apply scatter_add_ operation - x.scatter_add_(0, indices, values) - - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (5, 5)) - # Check specific values were added - self.assertTrue(torch.all(x[0, 0] == 2.0)) - self.assertEqual(x.sum(), 6.0) # Only the 3 specified positions have values - - def test_scatter_(self): - # Initialize test tensors - x = torch.zeros((5, 5), device="jax") - indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax") - values = torch.ones((2, 3), device="jax") * 2.0 - - # Apply scatter_ operation - x.scatter_(0, indices, values) - - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (5, 5)) - # Check specific values were replaced - self.assertEqual(x[0, 0], 2.0) - self.assertEqual(x[1, 1], 2.0) - self.assertEqual(x[2, 2], 2.0) - self.assertEqual(x.sum(), 6.0) # Only the 3 specified positions have values - - def test_scatter_reduce_(self): - # Initialize test tensors - x = torch.ones((5, 5), device="jax") - indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="jax") - values = torch.ones((2, 3), device="jax") * 2.0 - - # Apply scatter_reduce_ operation with "sum" reduction - x.scatter_reduce_(0, indices, values, reduce="sum") - - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (5, 5)) - # Check specific values were reduced - self.assertTrue(torch.all(x[0, 0] == 5.0)) - self.assertEqual(x.sum(), 37.0) diff --git a/torchax/test_dist/test_mesh_util.py b/torchax/test_dist/test_mesh_util.py deleted file mode 100644 index 21890fed331e..000000000000 --- a/torchax/test_dist/test_mesh_util.py +++ /dev/null @@ -1,51 +0,0 @@ -import unittest - -import jax -from jax.sharding import PartitionSpec -import torch -import torchax -from torchax.mesh_util import Mesh, SingleAxisSharder - - -class MeshUtilTest(unittest.TestCase): - - def setUp(self): - torchax.enable_globally() - - def test_init_module_sharded(self): - - class TestModule(torch.nn.Module): - - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(8, 8) - - mesh = Mesh.fsdp_mesh() - - model = mesh.initialize_model_sharded(TestModule, ()) - self.assertEqual( - len(model.a.weight.jax().addressable_shards), len(jax.devices())) - self.assertEqual( - len(model.a.bias.jax().addressable_shards), len(jax.devices())) - - def test_sharder_call(self): - """Test the __call__ method produces the correct PartitionSpec.""" - sharder = SingleAxisSharder(axis_name="fsdp", axis_size=4) - # Use a simple named tuple instead of MagicMock - shaped_type = torch.ones((5, 8, 12)) # Middle dim divisible by 4 - - spec = sharder("param_name", shaped_type) - self.assertEqual(spec, PartitionSpec(None, "fsdp", None)) - - def test_sharder_call_no_shardable(self): - """Test __call__ when no dimension is shardable.""" - sharder = SingleAxisSharder(axis_name="fsdp", axis_size=4) - shaped_type = torch.ones((5, 7, 11)) - - with self.assertRaisesRegex(AssertionError, - "Unable to find a dim to shard"): - sharder("param_name", shaped_type) - - -if __name__ == "__main__": - unittest.main() diff --git a/torchax/test_dist/test_to_device.py b/torchax/test_dist/test_to_device.py deleted file mode 100644 index 78794fad704e..000000000000 --- a/torchax/test_dist/test_to_device.py +++ /dev/null @@ -1,27 +0,0 @@ -import jax -import torch -import torchax -import unittest - -from jax.sharding import NamedSharding, PartitionSpec as P - - -class ToDeviceTest(unittest.TestCase): - - def test_to_device_twice(self): - env = torchax.default_env() - mesh = jax.make_mesh((jax.device_count(),), ('axis',)) - with env: - step1 = torch.ones( - 100, - 100, - ) - step2 = torch.triu(step1, diagonal=1) - step3 = step2.to(dtype=torch.bool, device='jax') - step3.apply_jax_(jax.device_put, NamedSharding(mesh, P())) - print(step3.to('jax')) - self.assertEqual(step3.device.type, 'jax') - - -if __name__ == '__main__': - unittest.main() diff --git a/torchax/torchax/CONTRIBUTING.md b/torchax/torchax/CONTRIBUTING.md deleted file mode 100644 index f908cd2e59bb..000000000000 --- a/torchax/torchax/CONTRIBUTING.md +++ /dev/null @@ -1,43 +0,0 @@ -# Contributing to torchax - -We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels. - - -# Developer setup - -## Mac setup: -@qihqi - -I am able to develop directly on mac (m1) laptop for most of parts. Using steps -in README.md works. The condensed version for easy copy & paste: - -```bash -conda create --name python=3.10 -conda activate -pip install --upgrade "jax[cpu]" torch -pip install -r test_requirements.txt -pip install -e . -pip install pytest-xdist # recommended for running test faster -pytest -n auto test -``` - -## Setup on GPU or TPU - -Same as Mac setup, except, if you run test using pytest, please also -add `JAX_PLATFORMS=cpu`. The reason is because pytest usually runs -test in multiple threads. CPU device can be accessed concurrently where -TPU devices usually only allow one accesor per process; so it could deadlock. - -### VSCode - -I use vscode on my Mac. I loosely followed instruction in -https://code.visualstudio.com/docs/python/python-tutorial -to setup a proper python environment. - -The plugins I installed (a subset of the ones listed above) are: -* VSCode's official Python plugin -* Ruff formatter -* Python Debugger - -I also changed Python interpreter to point at the one in my conda env. -That is all the changes I have. diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py deleted file mode 100644 index d5e964416dcd..000000000000 --- a/torchax/torchax/__init__.py +++ /dev/null @@ -1,128 +0,0 @@ -import contextlib -from typing import List, Dict, Any, Optional -import dataclasses -import jax -import os -import torch -from torch.utils import _pytree as pytree -from torchax import tensor -from contextlib import contextmanager - -__version__ = "0.0.6" -VERSION = __version__ - -__all__ = [ - 'default_env', - 'extract_jax', - 'enable_globally', - 'save_checkpoint', - 'load_checkpoint', -] - -from .checkpoint import save_checkpoint, load_checkpoint - -os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - -# torchax:oss-begin -if getattr(jax.config, 'jax_pjrt_client_create_options', None): - jax.config.update( - 'jax_pjrt_client_create_options', - f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}') -# torchax:oss-end - -env = None - - -def default_env(): - global env - - if env is None: - env = tensor.Environment() - return env - - -def extract_jax(mod: torch.nn.Module, env=None): - """Returns a pytree of jax.ndarray and a jax callable.""" - if env is None: - env = default_env() - states = dict(mod.named_buffers()) - states.update(mod.named_parameters()) - - states = env.t2j_copy(states) - - #@jax.jit - def jax_func(states, args, kwargs=None): - (states, args, kwargs) = env.j2t_iso((states, args, kwargs)) - with env: - res = torch.func.functional_call( - mod, states, args, kwargs, tie_weights=False) - return env.t2j_iso(res) - - return states, jax_func - - -def enable_globally(): - env = default_env().enable_torch_modes() - return env - - -def disable_globally(): - global env - default_env().disable_torch_modes() - - -@contextlib.contextmanager -def disable_temporarily(): - prev = default_env().enabled - if prev: - disable_globally() - yield () - if prev: - enable_globally() - - -torch.utils.rename_privateuse1_backend('jax') -unsupported_dtype = [torch.quint8] - -import jax -import torchax.device_module - -torch._register_device_module('jax', torchax.device_module) - - -def enable_accuracy_mode(): - jax.config.update('jax_enable_x64', True) - jax.config.update('jax_default_matmul_precision', 'highest') - default_env().config.internal_respect_torch_return_dtypes = True - - -def enable_performance_mode(): - jax.config.update('jax_enable_x64', False) - jax.config.update('jax_default_matmul_precision', 'default') - default_env().config.internal_respect_torch_return_dtypes = False - - -@dataclasses.dataclass -class CompileOptions: - # only valid if compiling nn.Module - methods_to_compile: List[str] = dataclasses.field( - default_factory=lambda: ['forward']) - jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - mode: str = 'jax' # or dynamo or export - - -def compile(fn, options: Optional[CompileOptions] = None): - options = options or CompileOptions() - if options.mode == 'jax': - from torchax import interop - if isinstance(fn, torch.nn.Module): - module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs) - for n in options.methods_to_compile: - module.make_jitted(n) - return module - else: - return interop.jax_jit(fn) - elif options.mode == 'dynamo': - raise RuntimeError('dynamo mode is not supported yet') - elif options.mode == 'export': - raise RuntimeError('export mode is not supported yet') diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py deleted file mode 100644 index ccbc63bead63..000000000000 --- a/torchax/torchax/amp.py +++ /dev/null @@ -1,332 +0,0 @@ -import contextlib -import enum -import torch -from torch.utils import _pytree as pytree - - -# enum class CastPolicy : uint8_t { -# lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before -# // running the op. Currently, lower_precision_fp is -# // fp16 for AutocastCUDA, and is defined by user -# // (default bf16) for AutocastCPU or other device. -# fp32, // Cast all inputs to at::kFloat before running the op. -# fp32_set_opt_dtype, // Treats functions (like softmax) that -# // 1. we'd like to run in fp32 and -# // 2. have a std::optional arg that controls -# // the output type. -# // fp32_set_opt_dtype wrappers' policy is: if the output -# // type is already set, don't touch it, otherwise, set -# // it to at::kFloat. -# fp32_append_dtype, // Treats functions (like norm) that -# // 1. we'd like to run in fp32 and -# // 2. have some overloads that accept an output type and -# // other overloads that don't. -# // fp32_append_dtype wrappers wrap the overloads that don't -# // have an output dtype. -# // The wrapper policy is: append at::kFloat to the args, -# // and redispatch to the type-aware overload. -# promote, // Run in the widest dtype among several args. -# }; -class CastPolicy(enum.Enum): - LOWER_PRECISION_FP = 0 - FP32 = 1 - FP32_SET_OPT_DTYPE = 2 - FP32_APPEND_DTYPE = 3 - PROMOTE = 4 - - -def execute_policy(policy, args, kwargs, target_lower_fp): - - def is_float(a): - return isinstance(a, torch.Tensor) and a.is_floating_point() - match policy: - case CastPolicy.LOWER_PRECISION_FP: - return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp), - (args, kwargs)) - case CastPolicy.FP32: - return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32), - (args, kwargs)) - case CastPolicy.PROMOTE: - dtypes = set(a.dtype for a in args) - widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1] - return pytree.tree_map_only(is_float, lambda a: a.to(widest), - (args, kwargs)) - case _: - raise AssertionError(f'Policy {policy} not implemented yet.') - - -@contextlib.contextmanager -def autocast(device, dtype=torch.bfloat16, env=None): - del device - if env is None: - import torchax - env = torchax.default_env() - with env.override_property(autocast_dtype=dtype): - yield - - -# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327 -autocast_policy = { - torch.ops.aten.conv1d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv1d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.bmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linalg_vecdot.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.baddbmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._addmm_activation.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addbmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linear.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._convolution.deprecated: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.matmul.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_tbc.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mkldnn_rnn_layer.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose1d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose2d.input: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose3d.input: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.prelu.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.scaled_dot_product_attention.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._native_multi_head_attention.default: - CastPolicy.LOWER_PRECISION_FP, - - # fp32 cast policy - torch.ops.aten.avg_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler.default: - CastPolicy.FP32, - torch.ops.aten.polar.default: - CastPolicy.FP32, - torch.ops.aten.prod.default: - CastPolicy.FP32, - torch.ops.aten.prod.dim_int: - CastPolicy.FP32, - torch.ops.aten.prod.dim_Dimname: - CastPolicy.FP32, - torch.ops.aten.quantile.default: - CastPolicy.FP32, - torch.ops.aten.quantile.scalar: - CastPolicy.FP32, - torch.ops.aten.nanquantile.default: - CastPolicy.FP32, - torch.ops.aten.nanquantile.scalar: - CastPolicy.FP32, - torch.ops.aten.stft.default: - CastPolicy.FP32, - torch.ops.aten.stft.center: - CastPolicy.FP32, - torch.ops.aten.cdist.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler_2d.default: - CastPolicy.FP32, - torch.ops.aten._grid_sampler_2d_cpu_fallback.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler_3d.default: - CastPolicy.FP32, - torch.ops.aten.trace.default: - CastPolicy.FP32, - torch.ops.aten.view_as_complex.default: - CastPolicy.FP32, - torch.ops.aten.cholesky.default: - CastPolicy.FP32, - torch.ops.aten.cholesky_inverse.default: - CastPolicy.FP32, - torch.ops.aten.cholesky_solve.default: - CastPolicy.FP32, - torch.ops.aten.inverse.default: - CastPolicy.FP32, - torch.ops.aten.lu_solve.default: - CastPolicy.FP32, - torch.ops.aten.orgqr.default: - CastPolicy.FP32, - torch.ops.aten.ormqr.default: - CastPolicy.FP32, - torch.ops.aten.pinverse.default: - CastPolicy.FP32, - torch.ops.aten.max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.max_unpool2d.default: - CastPolicy.FP32, - torch.ops.aten.max_unpool3d.default: - CastPolicy.FP32, - torch.ops.aten.adaptive_avg_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.reflection_pad1d.default: - CastPolicy.FP32, - torch.ops.aten.reflection_pad2d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad1d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad2d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad3d.default: - CastPolicy.FP32, - torch.ops.aten.mse_loss.default: - CastPolicy.FP32, - torch.ops.aten.cosine_embedding_loss.default: - CastPolicy.FP32, - torch.ops.aten.nll_loss.default: - CastPolicy.FP32, - torch.ops.aten.nll_loss2d.default: - CastPolicy.FP32, - torch.ops.aten.hinge_embedding_loss.default: - CastPolicy.FP32, - torch.ops.aten.poisson_nll_loss.default: - CastPolicy.FP32, - torch.ops.aten.smooth_l1_loss.default: - CastPolicy.FP32, - torch.ops.aten.cross_entropy_loss.default: - CastPolicy.FP32, - torch.ops.aten.l1_loss.default: - CastPolicy.FP32, - torch.ops.aten.huber_loss.default: - CastPolicy.FP32, - torch.ops.aten.margin_ranking_loss.default: - CastPolicy.FP32, - torch.ops.aten.soft_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.triplet_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.multi_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.ctc_loss.IntList: - CastPolicy.FP32, - torch.ops.aten.ctc_loss.Tensor: - CastPolicy.FP32, - torch.ops.aten.kl_div.default: - CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy_with_logits.default: - CastPolicy.FP32, - torch.ops.aten.fft_fft.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifft.default: - CastPolicy.FP32, - torch.ops.aten.fft_fft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_fftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_hfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_ihfft.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cond.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cond.p_str: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.default: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.tol_tensor: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_float: - CastPolicy.FP32, - torch.ops.aten.linalg_solve.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cholesky.default: - CastPolicy.FP32, - torch.ops.aten.linalg_svdvals.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigvals.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigvalsh.default: - CastPolicy.FP32, - torch.ops.aten.linalg_inv.default: - CastPolicy.FP32, - torch.ops.aten.linalg_householder_product.default: - CastPolicy.FP32, - torch.ops.aten.linalg_tensorinv.default: - CastPolicy.FP32, - torch.ops.aten.linalg_tensorsolve.default: - CastPolicy.FP32, - torch.ops.aten.fake_quantize_per_tensor_affine.default: - CastPolicy.FP32, - torch.ops.aten.geqrf.default: - CastPolicy.FP32, - torch.ops.aten._lu_with_info.default: - CastPolicy.FP32, - torch.ops.aten.qr.default: - CastPolicy.FP32, - torch.ops.aten.svd.default: - CastPolicy.FP32, - torch.ops.aten.triangular_solve.default: - CastPolicy.FP32, - torch.ops.aten.fractional_max_pool2d.default: - CastPolicy.FP32, - torch.ops.aten.fractional_max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.adaptive_max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss_forward.default: - CastPolicy.FP32, - torch.ops.aten.linalg_qr.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cholesky_ex.default: - CastPolicy.FP32, - torch.ops.aten.linalg_svd.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eig.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigh.default: - CastPolicy.FP32, - torch.ops.aten.linalg_lstsq.default: - CastPolicy.FP32, - torch.ops.aten.linalg_inv_ex.default: - CastPolicy.FP32, - - # promote - torch.ops.aten.stack.default: - CastPolicy.PROMOTE, - torch.ops.aten.cat.default: - CastPolicy.PROMOTE, - torch.ops.aten.index_copy.default: - CastPolicy.PROMOTE, - torch.ops.aten.index_copy.dimname: - CastPolicy.PROMOTE, -} diff --git a/torchax/torchax/checkpoint.py b/torchax/torchax/checkpoint.py deleted file mode 100644 index daded1c3afad..000000000000 --- a/torchax/torchax/checkpoint.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -import os -from typing import Any, Dict -from flax.training import checkpoints -import jax -import jax.numpy as jnp -import numpy as np - - -def _to_jax(pytree): - return jax.tree_util.tree_map( - lambda x: jnp.asarray(x.cpu().numpy()) - if isinstance(x, torch.Tensor) else x, pytree) - - -def _to_torch(pytree): - return jax.tree_util.tree_map( - lambda x: torch.from_numpy(np.asarray(x)) - if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree) - - -def save_checkpoint(state: Dict[str, Any], path: str, step: int): - """Saves a checkpoint to a file in JAX style. - - Args: - state: A dictionary containing the state to save. torch.Tensors will be - converted to jax.Array. - path: The path to save the checkpoint to. This is a directory. - step: The training step. - """ - state = _to_jax(state) - checkpoints.save_checkpoint(path, state, step=step, overwrite=True) - - -def load_checkpoint(path: str) -> Dict[str, Any]: - """Loads a checkpoint and returns it in JAX format. - - This function can load both PyTorch-style (single file) and JAX-style - (directory) checkpoints. - - If the checkpoint is in PyTorch format, it will be converted to JAX format. - - Args: - path: The path to the checkpoint. - - Returns: - The loaded state in JAX format (pytree with jax.Array leaves). - """ - if os.path.isdir(path): - # JAX-style checkpoint - state = checkpoints.restore_checkpoint(path, target=None) - if state is None: - raise FileNotFoundError(f"No checkpoint found at {path}") - return state - elif os.path.isfile(path): - # PyTorch-style checkpoint - state = torch.load(path, weights_only=False) - return _to_jax(state) - else: - raise FileNotFoundError(f"No such file or directory: {path}") diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py deleted file mode 100644 index f439c656287b..000000000000 --- a/torchax/torchax/config.py +++ /dev/null @@ -1,30 +0,0 @@ -import dataclasses - - -@dataclasses.dataclass -class Configuration: - debug_print_each_op: bool = False - debug_accuracy_for_each_op: bool = False - debug_mixed_tensor: bool = False - debug_print_each_op_operands: bool = False - - use_int32_for_index: bool = False - - # normally, math between CPU torch.Tensor with torchax.Tensor is not - # allowed. However, if that torch.Tensor happens to be scalar, then we - # can use scalar * tensor math to handle it - allow_mixed_math_with_scalar_tensor: bool = True - - # If true, we will convert Views into torchax.Tensors eagerly - force_materialize_views: bool = False - - # Use DLPack for converting jax.Arrays <-> and torch.Tensor - use_dlpack_for_data_conversion: bool = False - - # Flash attention - use_tpu_flash_attention: bool = False - shmap_flash_attention: bool = False - - # device - treat_cuda_as_jax_device: bool = True - internal_respect_torch_return_dtypes: bool = False diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py deleted file mode 100644 index d1c1f463d88a..000000000000 --- a/torchax/torchax/decompositions.py +++ /dev/null @@ -1,776 +0,0 @@ -"""This file contains some decompositons that are not available in torch stable. - -Most likely from Content of -https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py -at main branch HEAD that we find useful here. - -Can also contain decompositions of a torch op in terms of other torch ops. -""" - -import functools -from typing import Any, Callable, List, Tuple - -import torch -from torch import Tensor -import torch._decomp as decomp -from torch._decomp import decompositions_for_rng -from torch._decomp import register_decomposition -import torch._prims_common as utils -from torch._prims_common.wrappers import out_wrapper - -DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] - -# None of these functions are publicly accessible; get at them -# from torch._decomps -__all__: List[str] = [] - -aten = torch._ops.ops.aten - - -def _try_register(op, impl): - try: - register_decomposition(op)(impl) - except: - pass - - -@out_wrapper() -def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: - - def idx(left, middle, right): - dim_idx = torch.arange(-left, middle + right, device=a.device) - return middle - 1 - (middle - 1 - dim_idx.abs()).abs() - - return _reflection_or_replication_pad( - a, - padding, - idx, - ) - - -_try_register(aten.reflection_pad1d, _reflection_pad) -_try_register(aten.reflection_pad2d, _reflection_pad) -_try_register(aten.reflection_pad3d, _reflection_pad) - - -@out_wrapper() -def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: - - def idx(left, middle, right): - dim_idx = torch.arange(-left, middle + right, device=a.device) - return torch.clamp(dim_idx, 0, middle - 1) - - return _reflection_or_replication_pad( - a, - padding, - idx, - ) - - -decomp.global_decomposition_table["post_autograd"][ - aten.replication_pad2d.default] = _replication_pad - - -def _reflection_or_replication_pad( - a: Tensor, - padding: Tuple[int, ...], - idx_fn: Callable[[int, int, int], Tensor], -) -> Tensor: - dim = len(padding) // 2 - torch._check( - a.dim() in (dim + 1, dim + 2), - lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", - ) - inp_shape = a.shape[-dim:] - nc_dim = a.dim() - dim - - padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] - padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] - - result = a - for i in range(dim): - idx: List[Any] = [None] * result.dim() - idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) - result = aten._unsafe_index(result, idx) - - # convert output to correct memory format, if necessary - memory_format = utils.suggest_memory_format(result) - result = result.contiguous(memory_format=memory_format) - return result - - -_try_register(aten.replication_pad1d, _replication_pad) -_try_register(aten.replication_pad3d, _replication_pad) - - -def bernoulli(self, *, generator=None): - return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype) - - -_try_register(aten.bernoulli.default, bernoulli) - - -def rand_like(self, **kwargs): - dtype = kwargs.get("dtype", self.dtype) - return torch.rand(self.shape, dtype=dtype) - - -def channel_shuffle(self, groups): - batchsize, channels, height, width = self.shape - channels_per_group = channels // groups - self = self.reshape(batchsize, groups, channels_per_group, height, width) - self = self.transpose(1, 2) - self = self.reshape(batchsize, channels, height, width) - return self - - -_try_register(aten.channel_shuffle, channel_shuffle) - -_try_register(aten.bernoulli, bernoulli) -_try_register(aten.rand_like, rand_like) - - -def bernoulli_float(self, p=0.5): - return self.bernoulli_(p) - - -_try_register(aten.bernoulli_.float, bernoulli_float) -_try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_) - - -def _sum_tensors(ts) -> Tensor: - return functools.reduce(torch.add, ts) - - -@register_decomposition(aten.grid_sampler_3d) -def _grid_sampler_3d( - a: torch.Tensor, - grid: torch.Tensor, - interpolation_mode: int = 0, - padding_mode: int = 0, - align_corners: bool = False, -) -> Tensor: - """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 - - The above implement the 2d case. - """ - _expand_grid = False - torch._check( - interpolation_mode in (0, 1), - lambda: f"Invalid interpolation mode {interpolation_mode}", - ) - torch._check( - padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}") - - # a is 5D: [B, C, D, H, W] - - def unnormalize(coords: Tensor, size: int) -> Tensor: - # Rescale coordinates from [-1, 1] to: - # [0, size - 1] if align_corners is True - # [-.5, size -.5] if align_corners is False - mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) - ofs = size * 0.5 - 0.5 - return coords * mul + ofs - - # Reflects coordinates until they fall between low and high (inclusive). - # The bounds are passed as twice their value so that half-integer values - # can be represented as ints. - def reflect_coordinates(coords: Tensor, twice_low: int, - twice_high: int) -> Tensor: - if twice_low == twice_high: - return torch.zeros_like(coords) - coords_min = twice_low / 2 - coords_span = (twice_high - twice_low) / 2 - coords2 = (coords - coords_min).abs() - extra = torch.fmod(coords2, coords_span) - flips = (coords2 / coords_span).floor().to(dtype=torch.int8) - return torch.where(flips & 1 == 0, extra + coords_min, - coords_span + coords_min - extra) - - def compute_coordinates(coords: Tensor, size: int) -> Tensor: - if padding_mode == 0: # Zero - return coords - elif padding_mode == 1: # Borders - return torch.clamp(coords, 0, size - 1) - else: # padding_mode == 2, Reflection - if align_corners: - coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) - else: - coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) - return torch.clamp(coords_reflected, 0, size - 1) - - def compute_source_index(coords: Tensor, size: int) -> Tensor: - coords_un = unnormalize(coords, size) - return compute_coordinates(coords_un, size) - - N, C, iD, iH, iW = a.shape - _, oD, oH, oW, three = grid.shape - assert three == 3, "Last dim of grid must be 3. got {}".format(three) - - def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor: - xcheck = torch.logical_and(0 <= xs, xs < iW) - ycheck = torch.logical_and(0 <= ys, ys < iH) - zcheck = torch.logical_and(0 <= zs, zs < iD) - return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck)) - - N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1) - C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1) - - def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): - cond = in_bounds_cond(xs, ys, zs) - # To clip to inside valid coordinates, we map the coordinates - # to (x, y) = (0, 0) and also set the weight to 0 - # We also change the shape of the tensor to the appropriate one for - # broadcasting with N_idx, C_idx for the purposes of advanced indexing - c = C if _expand_grid else 1 - return tuple( - torch.where(cond, t, 0).view(N, c, oD, oH, oW) for t in ( - xs.to(dtype=torch.int64), - ys.to(dtype=torch.int64), - zs.to(dtype=torch.int64), - ws, - )) - - def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, - w) -> Tensor: - # Perform clipping, index into input tensor and multiply by weight - idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) - return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ - - x = grid[..., 0] - y = grid[..., 1] - d = grid[..., 2] - - if interpolation_mode == 0: # Bilinear - ix = compute_source_index(x, iW) - iy = compute_source_index(y, iH) - id_ = compute_source_index(d, iD) - - ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor() - ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf - ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf - ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf - ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1 - ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1 - ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1 - ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1 - - w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_) - w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_) - w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_) - w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_) - w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef) - w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf) - w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef) - w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf) - - return _sum_tensors( - get_summand(ix, iy, id_, w) for (ix, iy, id_, w) in ( - (ix_nwf, iy_nwf, id_nwf, w_nwf), - (ix_nef, iy_nef, id_nef, w_nef), - (ix_swf, iy_swf, id_swf, w_swf), - (ix_sef, iy_sef, id_sef, w_sef), - (ix_nwb, iy_nwb, id_nwb, w_nwb), - (ix_neb, iy_neb, id_neb, w_neb), - (ix_swb, iy_swb, id_swb, w_swb), - (ix_seb, iy_seb, id_seb, w_seb), - )) - else: # interpolation_mode == 1: # Nearest - ix = compute_source_index(x, iW) - iy = compute_source_index(y, iH) - iz = compute_source_index(d, iD) - - ix_nearest = ix.round() - iy_nearest = iy.round() - iz_nearest = iz.round() - - return get_summand(ix_nearest, iy_nearest, iz_nearest, 1) - - -DECOMPOSITIONS = decomp.get_decompositions([ - torch.ops.aten.upsample_bicubic2d, - torch.ops.aten.upsample_nearest1d, - torch.ops.aten.upsample_nearest2d, - torch.ops.aten.upsample_nearest3d, - torch.ops.aten._upsample_nearest_exact1d, - torch.ops.aten._upsample_nearest_exact2d, - torch.ops.aten._upsample_nearest_exact3d, - torch.ops.aten._native_batch_norm_legit.no_stats, - torch.ops.aten._native_batch_norm_legit_functional.default, - torch.ops.aten._adaptive_avg_pool2d, - torch.ops.aten._adaptive_avg_pool3d, - torch.ops.aten.grid_sampler_2d, - torch.ops.aten.grid_sampler_3d, - torch.ops.aten.native_dropout, - torch.ops.aten.reflection_pad1d, - torch.ops.aten.reflection_pad2d, - torch.ops.aten.reflection_pad3d, - torch.ops.aten.replication_pad1d, - torch.ops.aten.replication_pad2d, - torch.ops.aten.replication_pad3d, - torch.ops.aten.bernoulli, - torch.ops.aten.rand_like, - torch.ops.aten._batch_norm_with_update, - torch.ops.aten.channel_shuffle, - torch.ops.aten.nll_loss2d_forward, - torch.ops.aten.nll_loss2d_backward, - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, - torch.ops.aten.log_normal, - torch.ops.aten.addcdiv.default, - torch.ops.aten.addcdiv.out, - torch.ops.aten.addcdiv_.default, - torch.ops.aten.addcmul.default, - torch.ops.aten.addcmul.out, - torch.ops.aten.addcmul_.default, - torch.ops.aten.addr.default, - torch.ops.aten.addr.out, - torch.ops.aten.affine_grid_generator.default, - torch.ops.aten.affine_grid_generator.out, - torch.ops.aten.alias_copy.default, - torch.ops.aten.alias_copy.out, - torch.ops.aten.all.default, - torch.ops.aten.all.dim, - torch.ops.aten.all.dims, - torch.ops.aten.all.out, - torch.ops.aten.all.dims_out, - torch.ops.aten.all.all_out, - torch.ops.aten.all.dimname, - torch.ops.aten.all.dimname_out, - torch.ops.aten.aminmax.default, - torch.ops.aten.aminmax.out, - torch.ops.aten.arange.default, - torch.ops.aten.arange.start, - torch.ops.aten.baddbmm.default, - torch.ops.aten.baddbmm.out, - torch.ops.aten.binary_cross_entropy.default, - torch.ops.aten.binary_cross_entropy.out, - torch.ops.aten.binary_cross_entropy_backward.default, - torch.ops.aten.binary_cross_entropy_backward.grad_input, - torch.ops.aten.binary_cross_entropy_with_logits.default, - torch.ops.aten.binary_cross_entropy_with_logits.out, - torch.ops.aten.block_diag.default, - torch.ops.aten.block_diag.out, - torch.ops.aten.celu.default, - torch.ops.aten.celu.out, - torch.ops.aten.celu_.default, - torch.ops.aten.channel_shuffle.default, - torch.ops.aten.channel_shuffle.out, - torch.ops.aten.clamp_max.default, - torch.ops.aten.clamp_max.Tensor, - torch.ops.aten.clamp_max.out, - torch.ops.aten.clamp_max.Tensor_out, - torch.ops.aten.clamp_min.default, - torch.ops.aten.clamp_min.Tensor, - torch.ops.aten.clamp_min.out, - torch.ops.aten.clamp_min.Tensor_out, - torch.ops.aten.col2im.default, - torch.ops.aten.col2im.out, - torch.ops.aten.count_nonzero.dim_IntList, - torch.ops.aten.count_nonzero.dim_IntList_out, - torch.ops.aten.count_nonzero.default, - torch.ops.aten.count_nonzero.out, - torch.ops.aten.linalg_cross.default, - torch.ops.aten.linalg_cross.out, - torch.ops.aten.cudnn_batch_norm.default, - torch.ops.aten.cudnn_batch_norm.out, - torch.ops.aten.cudnn_batch_norm_backward.default, - torch.ops.aten.cudnn_batch_norm_backward.out, - torch.ops.aten.miopen_batch_norm_backward.default, - torch.ops.aten.miopen_batch_norm_backward.out, - torch.ops.aten.deg2rad.default, - torch.ops.aten.deg2rad.out, - torch.ops.aten.deg2rad_.default, - torch.ops.aten.detach.default, - torch.ops.aten.diag_embed.default, - torch.ops.aten.diag_embed.out, - torch.ops.aten.diagonal_backward.default, - torch.ops.aten.diagonal_backward.out, - torch.ops.aten.dot.default, - torch.ops.aten.dot.out, - torch.ops.aten.vdot.default, - torch.ops.aten.vdot.out, - torch.ops.aten.elu.default, - torch.ops.aten.elu.out, - torch.ops.aten.elu_.default, - torch.ops.aten.elu_backward.default, - torch.ops.aten.elu_backward.grad_input, - torch.ops.aten.embedding_dense_backward.default, - torch.ops.aten.embedding_dense_backward.out, - torch.ops.aten.empty_like.default, - torch.ops.aten.empty_like.out, - torch.ops.aten._euclidean_dist.default, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand_copy.out, - torch.ops.aten.eye.default, - torch.ops.aten.eye.m, - torch.ops.aten.eye.out, - torch.ops.aten.eye.m_out, - torch.ops.aten.fill.Scalar, - torch.ops.aten.fill.Tensor, - torch.ops.aten.fill_.Scalar, - torch.ops.aten.fill_.Tensor, - torch.ops.aten.floor_divide.default, - torch.ops.aten.floor_divide.Scalar, - torch.ops.aten.floor_divide.out, - torch.ops.aten.floor_divide.Scalar_out, - torch.ops.aten.frac.default, - torch.ops.aten.frac.out, - torch.ops.aten.frac_.default, - torch.ops.aten.gelu_.default, - torch.ops.aten.gelu_backward.default, - torch.ops.aten.gelu_backward.grad_input, - torch.ops.aten.glu.default, - torch.ops.aten.glu.out, - torch.ops.aten.glu_backward.default, - torch.ops.aten.glu_backward.grad_input, - torch.ops.aten.hardshrink.default, - torch.ops.aten.hardshrink.out, - torch.ops.aten.hardsigmoid.default, - torch.ops.aten.hardsigmoid.out, - torch.ops.aten.hardsigmoid_.default, - torch.ops.aten.hardsigmoid_backward.default, - torch.ops.aten.hardsigmoid_backward.grad_input, - torch.ops.aten.hardswish.default, - torch.ops.aten.hardswish.out, - torch.ops.aten.hardswish_.default, - torch.ops.aten.hardswish_backward.default, - torch.ops.aten.hardswish_backward.out, - torch.ops.aten.hardtanh_.default, - torch.ops.aten.hardtanh_backward.default, - torch.ops.aten.hardtanh_backward.grad_input, - torch.ops.aten.heaviside.default, - torch.ops.aten.heaviside.out, - torch.ops.aten.heaviside_.default, - torch.ops.aten.huber_loss.default, - torch.ops.aten.huber_loss.out, - torch.ops.aten.huber_loss_backward.default, - torch.ops.aten.huber_loss_backward.out, - torch.ops.aten.im2col.default, - torch.ops.aten.im2col.out, - torch.ops.aten.index_add.default, - torch.ops.aten.index_add.out, - torch.ops.aten.index_add.dimname, - torch.ops.aten.index_add_.default, - torch.ops.aten.index_copy.default, - torch.ops.aten.index_copy.dimname, - torch.ops.aten.index_copy.out, - torch.ops.aten.index_copy_.default, - torch.ops.aten.index_copy_.dimname, - torch.ops.aten.index_fill.int_Tensor, - torch.ops.aten.index_fill.int_Scalar, - torch.ops.aten.index_fill.Dimname_Scalar, - torch.ops.aten.index_fill.Dimname_Tensor, - torch.ops.aten.index_fill.int_Scalar_out, - torch.ops.aten.index_fill.int_Tensor_out, - torch.ops.aten.index_fill_.int_Tensor, - torch.ops.aten.index_fill_.int_Scalar, - torch.ops.aten.index_fill_.Dimname_Scalar, - torch.ops.aten.index_fill_.Dimname_Tensor, - torch.ops.aten.isin.Tensor_Tensor, - torch.ops.aten.isin.Tensor_Tensor_out, - torch.ops.aten.isin.Tensor_Scalar, - torch.ops.aten.isin.Tensor_Scalar_out, - torch.ops.aten.isin.Scalar_Tensor, - torch.ops.aten.isin.Scalar_Tensor_out, - torch.ops.aten.isneginf.default, - torch.ops.aten.isneginf.out, - torch.ops.aten.isposinf.default, - torch.ops.aten.isposinf.out, - torch.ops.aten.leaky_relu_.default, - torch.ops.aten.leaky_relu_backward.default, - torch.ops.aten.leaky_relu_backward.grad_input, - torch.ops.aten.lerp.Scalar, - torch.ops.aten.lerp.Tensor, - torch.ops.aten.lerp.Scalar_out, - torch.ops.aten.lerp.Tensor_out, - torch.ops.aten.lerp_.Scalar, - torch.ops.aten.lerp_.Tensor, - torch.ops.aten.linspace.Tensor_Tensor, - torch.ops.aten.linspace.Tensor_Scalar, - torch.ops.aten.linspace.Scalar_Tensor, - torch.ops.aten.linspace.default, - torch.ops.aten.linspace.out, - torch.ops.aten.linspace.Tensor_Tensor_out, - torch.ops.aten.linspace.Tensor_Scalar_out, - torch.ops.aten.linspace.Scalar_Tensor_out, - torch.ops.aten.logaddexp.default, - torch.ops.aten.logaddexp.out, - torch.ops.aten.logaddexp2.default, - torch.ops.aten.logaddexp2.out, - torch.ops.aten.logit.default, - torch.ops.aten.logit.out, - torch.ops.aten.logit_.default, - torch.ops.aten.logit_backward.default, - torch.ops.aten.log_sigmoid_backward.default, - torch.ops.aten.log_sigmoid_backward.grad_input, - torch.ops.aten.log_sigmoid_forward.default, - torch.ops.aten.log_sigmoid_forward.output, - torch.ops.aten._log_softmax_backward_data.default, - torch.ops.aten._log_softmax_backward_data.out, - torch.ops.aten.logspace.Tensor_Tensor, - torch.ops.aten.logspace.Tensor_Scalar, - torch.ops.aten.logspace.Scalar_Tensor, - torch.ops.aten.logspace.default, - torch.ops.aten.logspace.out, - torch.ops.aten.logspace.Tensor_Tensor_out, - torch.ops.aten.logspace.Tensor_Scalar_out, - torch.ops.aten.logspace.Scalar_Tensor_out, - torch.ops.aten.logsumexp.default, - torch.ops.aten.masked_fill.Scalar, - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar_out, - torch.ops.aten.masked_fill.Tensor_out, - torch.ops.aten.masked_fill_.Scalar, - torch.ops.aten.masked_fill_.Tensor, - torch.ops.aten.mish.default, - torch.ops.aten.mish.out, - torch.ops.aten.mish_.default, - torch.ops.aten.mse_loss.default, - torch.ops.aten.mse_loss.out, - torch.ops.aten.mse_loss_backward.default, - torch.ops.aten.mse_loss_backward.grad_input, - torch.ops.aten.multi_margin_loss.default, - torch.ops.aten.multi_margin_loss.out, - torch.ops.aten.multilabel_margin_loss_forward.default, - torch.ops.aten.multilabel_margin_loss_forward.output, - torch.ops.aten.mv.default, - torch.ops.aten.mv.out, - torch.ops.aten.mvlgamma.default, - torch.ops.aten.mvlgamma.out, - torch.ops.aten.mvlgamma_.default, - torch.ops.aten.nansum.default, - torch.ops.aten.nansum.out, - torch.ops.aten.nan_to_num.default, - torch.ops.aten.nan_to_num.out, - torch.ops.aten.nan_to_num_.default, - torch.ops.aten.native_batch_norm_backward.default, - torch.ops.aten.native_batch_norm_backward.out, - torch.ops.aten.native_dropout_backward.default, - torch.ops.aten.native_dropout_backward.out, - torch.ops.aten.native_group_norm_backward.default, - torch.ops.aten.native_group_norm_backward.out, - torch.ops.aten.native_layer_norm_backward.default, - torch.ops.aten.native_layer_norm_backward.out, - torch.ops.aten.new_empty.default, - torch.ops.aten.new_empty.out, - torch.ops.aten.new_full.default, - torch.ops.aten.new_full.out, - torch.ops.aten.new_ones.default, - torch.ops.aten.new_ones.out, - torch.ops.aten.new_zeros.default, - torch.ops.aten.new_zeros.out, - torch.ops.aten.nll_loss2d_forward.default, - torch.ops.aten.nll_loss2d_forward.output, - torch.ops.aten.nll_loss2d_backward.default, - torch.ops.aten.nll_loss2d_backward.grad_input, - torch.ops.aten.nll_loss_backward.default, - torch.ops.aten.nll_loss_backward.grad_input, - torch.ops.aten.nll_loss_forward.default, - torch.ops.aten.nll_loss_forward.output, - torch.ops.aten.norm.Scalar, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.norm.names_ScalarOpt_dim, - torch.ops.aten.norm.ScalarOpt_dim_dtype, - torch.ops.aten.norm.dtype_out, - torch.ops.aten.norm.out, - torch.ops.aten.norm.ScalarOpt_dtype, - torch.ops.aten.norm.ScalarOpt_dtype_out, - torch.ops.aten.norm.Scalar_out, - torch.ops.aten.norm.names_ScalarOpt_dim_dtype, - torch.ops.aten.norm.names_dtype_out, - torch.ops.aten.norm.names_out, - torch.ops.aten.ones.default, - torch.ops.aten.ones_like.default, - torch.ops.aten.ones_like.out, - torch.ops.aten.pixel_shuffle.default, - torch.ops.aten.pixel_shuffle.out, - torch.ops.aten.pixel_unshuffle.default, - torch.ops.aten.pixel_unshuffle.out, - torch.ops.aten._prelu_kernel.default, - torch.ops.aten._prelu_kernel_backward.default, - torch.ops.aten._reshape_alias.default, - torch.ops.aten.rad2deg.default, - torch.ops.aten.rad2deg.out, - torch.ops.aten.rad2deg_.default, - torch.ops.aten.reflection_pad1d.default, - torch.ops.aten.reflection_pad1d.out, - torch.ops.aten.reflection_pad1d_backward.default, - torch.ops.aten.reflection_pad1d_backward.grad_input, - torch.ops.aten.reflection_pad2d.default, - torch.ops.aten.reflection_pad2d.out, - torch.ops.aten.reflection_pad2d_backward.default, - torch.ops.aten.reflection_pad2d_backward.grad_input, - torch.ops.aten.reflection_pad3d.default, - torch.ops.aten.reflection_pad3d.out, - torch.ops.aten.reflection_pad3d_backward.default, - torch.ops.aten.reflection_pad3d_backward.grad_input, - torch.ops.aten.replication_pad1d.default, - torch.ops.aten.replication_pad1d.out, - torch.ops.aten.replication_pad2d.default, - torch.ops.aten.replication_pad2d.out, - torch.ops.aten.replication_pad3d.default, - torch.ops.aten.replication_pad3d.out, - torch.ops.aten.renorm.default, - torch.ops.aten.renorm.out, - torch.ops.aten.renorm_.default, - torch.ops.aten.resize_as.default, - torch.ops.aten.resize_as.out, - torch.ops.aten.roll.default, - torch.ops.aten.roll.out, - torch.ops.aten.rot90.default, - torch.ops.aten.rot90.out, - torch.ops.aten.rrelu_with_noise.default, - torch.ops.aten.rrelu_with_noise.out, - torch.ops.aten.rrelu_with_noise_.default, - torch.ops.aten.rsub.Tensor, - torch.ops.aten.rsub.Scalar, - torch.ops.aten.rsub.Tensor_out, - torch.ops.aten.rsub.Scalar_out, - torch.ops.aten._safe_softmax.default, - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, - torch.ops.aten.select_backward.default, - torch.ops.aten.select_backward.out, - torch.ops.aten.select_scatter.default, - torch.ops.aten.select_scatter.out, - torch.ops.aten.sgn.default, - torch.ops.aten.sgn.out, - torch.ops.aten.sgn_.default, - torch.ops.aten.sigmoid_backward.default, - torch.ops.aten.sigmoid_backward.grad_input, - torch.ops.aten.silu.default, - torch.ops.aten.silu.out, - torch.ops.aten.silu_.default, - torch.ops.aten.silu_backward.default, - torch.ops.aten.silu_backward.grad_input, - torch.ops.aten.sinc.default, - torch.ops.aten.sinc.out, - torch.ops.aten.sinc_.default, - torch.ops.aten.slice_backward.default, - torch.ops.aten.slice_backward.out, - torch.ops.aten.smooth_l1_loss.default, - torch.ops.aten.smooth_l1_loss.out, - torch.ops.aten.smooth_l1_loss_backward.default, - torch.ops.aten.smooth_l1_loss_backward.grad_input, - torch.ops.aten.soft_margin_loss.default, - torch.ops.aten.soft_margin_loss.out, - torch.ops.aten.soft_margin_loss_backward.default, - torch.ops.aten.soft_margin_loss_backward.grad_input, - torch.ops.aten._softmax_backward_data.default, - torch.ops.aten._softmax_backward_data.out, - torch.ops.aten.softplus.default, - torch.ops.aten.softplus.out, - torch.ops.aten.softplus_backward.default, - torch.ops.aten.softplus_backward.grad_input, - torch.ops.aten.softshrink.default, - torch.ops.aten.softshrink.out, - torch.ops.aten.special_entr.default, - torch.ops.aten.special_entr.out, - torch.ops.aten.special_log_ndtr.default, - torch.ops.aten.special_log_ndtr.out, - torch.ops.aten.special_xlog1py.default, - torch.ops.aten.special_xlog1py.other_scalar, - torch.ops.aten.special_xlog1py.self_scalar, - torch.ops.aten.special_xlog1py.out, - torch.ops.aten.special_xlog1py.self_scalar_out, - torch.ops.aten.special_xlog1py.other_scalar_out, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes_copy.default, - torch.ops.aten.split_with_sizes_copy.out, - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze.dim, - torch.ops.aten.std.default, - torch.ops.aten.std.dim, - torch.ops.aten.std.correction, - torch.ops.aten.std.names_dim, - torch.ops.aten.std.names_out, - torch.ops.aten.std.out, - torch.ops.aten.std.correction_out, - torch.ops.aten.std.correction_names, - torch.ops.aten.std.correction_names_out, - torch.ops.aten.std_mean.default, - torch.ops.aten.std_mean.dim, - torch.ops.aten.std_mean.correction, - torch.ops.aten.std_mean.names_dim, - torch.ops.aten.std_mean.correction_names, - torch.ops.aten.std_mean.correction_out, - torch.ops.aten.stack.default, - torch.ops.aten.stack.out, - torch.ops.aten.sum.default, - torch.ops.aten.sum.out, - torch.ops.aten.t.default, - torch.ops.aten.t_copy.out, - torch.ops.aten.t_copy.default, - torch.ops.aten.take.default, - torch.ops.aten.take.out, - torch.ops.aten.tanh_backward.default, - torch.ops.aten.tanh_backward.grad_input, - torch.ops.aten.threshold.default, - torch.ops.aten.threshold.out, - torch.ops.aten.threshold_.default, - torch.ops.aten.threshold_backward.default, - torch.ops.aten.threshold_backward.grad_input, - torch.ops.aten.trace.default, - torch.ops.aten.trace.out, - torch.ops.aten.transpose.int, - torch.ops.aten.tril.default, - torch.ops.aten.tril.out, - torch.ops.aten.tril_.default, - torch.ops.aten.triu.default, - torch.ops.aten.triu.out, - torch.ops.aten.triu_.default, - torch.ops.aten.unbind.int, - torch.ops.aten.unbind.Dimname, - torch.ops.aten.unfold_backward.default, - torch.ops.aten.unfold_backward.out, - torch.ops.aten.unfold_copy.default, - torch.ops.aten.unfold_copy.out, - torch.ops.aten._unsafe_index.Tensor, - torch.ops.aten._unsafe_index_put.default, - torch.ops.aten._unsafe_masked_index.default, - torch.ops.aten._unsafe_masked_index_put_accumulate.default, - torch.ops.aten.unsafe_split.Tensor, - torch.ops.aten.unsafe_split_with_sizes.default, - torch.ops.aten.unsqueeze_copy.out, - torch.ops.aten.unsqueeze_copy.default, - torch.ops.aten._unsafe_view.default, - torch.ops.aten._unsafe_view.out, - torch.ops.aten.upsample_linear1d.default, - torch.ops.aten.upsample_linear1d.out, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.upsample_bilinear2d.default, - torch.ops.aten.upsample_bilinear2d.out, - torch.ops.aten.upsample_trilinear3d.vec, - torch.ops.aten.upsample_trilinear3d.default, - torch.ops.aten.upsample_trilinear3d.out, - torch.ops.aten.xlogy.Tensor, - torch.ops.aten.xlogy.Scalar_Other, - torch.ops.aten.xlogy.Scalar_Self, - torch.ops.aten.xlogy.OutTensor, - torch.ops.aten.xlogy.OutScalar_Self, - torch.ops.aten.xlogy.OutScalar_Other, - torch.ops.aten.xlogy_.Tensor, - torch.ops.aten.xlogy_.Scalar_Other, - torch.ops.aten.zero.default, - torch.ops.aten.zero.out, - torch.ops.aten.zero_.default, - torch.ops.aten.zeros.default, - torch.ops.aten.zeros_like.default, - torch.ops.aten.zeros_like.out, - torch.ops.aten._chunk_cat.default, - torch.ops.aten._chunk_cat.out, - torch.ops.aten._weight_norm_interface.default, - torch.ops.aten._weight_norm_interface.out, - torch.ops.aten.__iand__.Tensor, - torch.ops.aten.__ixor__.Tensor, - torch.ops.aten.__ilshift__.Tensor, - torch.ops.aten.__ilshift__.Scalar, - torch.ops.aten.__irshift__.Tensor, - torch.ops.aten.__irshift__.Scalar, - torch.ops.aten.__ior__.Tensor, -]) - -MUTABLE_DECOMPOSITION = [ - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, -] diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py deleted file mode 100644 index be028cfcc21d..000000000000 --- a/torchax/torchax/device_module.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch - - -def _is_in_bad_fork(): - return False - - -def manual_seed_all(seed): - pass - - -def device_count(): - return 1 - - -def get_rng_state(): - return [] - - -def set_rng_state(new_state, device): - pass - - -def is_available(): - return True - - -def current_device(): - return 0 - - -def get_amp_supported_dtype(): - return [torch.float16, torch.bfloat16] diff --git a/torchax/torchax/export.py b/torchax/torchax/export.py deleted file mode 100644 index 987fb92ba6ee..000000000000 --- a/torchax/torchax/export.py +++ /dev/null @@ -1,245 +0,0 @@ -# pylint: disable -"""Utilities for exporting a torch program to jax/stablehlo.""" -import copy -from typing import Any, Dict, Tuple -import torch -from torch.utils import _pytree as pytree -import torchax -from torchax import tensor -from torchax.ops import ops_registry, mappings -from torchax import decompositions -import jax -import jax.export -import sympy - -DEBUG = False - - -class JaxInterpreter(torch.fx.Interpreter): - """Experimental.""" - - def __init__(self, graph_module): - super().__init__(graph_module) - import torchax.ops.jaten - import torchax.ops.jtorch - - def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: - if not isinstance(target, - (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): - return super().call_function(target, args, kwargs) - - if DEBUG: - print('Running ', target.name(), '--------') - - op = ops_registry.all_aten_ops.get(target) - if op is None: - op = ops_registry.all_aten_ops.get(target.overloadpacket) - assert op is not None, target - assert op.is_jax_function, op - if op is None: - op = ops_registry.all_aten_ops.get(target.overloadpacket) - if op is None: - print(target.name(), target.tags) - raise RuntimeError('No lowering found for', target.name()) - return op.func(*args, **kwargs) - - def run_node(self, n) -> Any: - res = super().run_node(n) - if DEBUG: - if n.op == 'call_function': - if hasattr(res, 'shape'): - print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape) - return res - - -from torch._decomp import get_decompositions -import torch._refs - -_extra_decomp = get_decompositions([torch.ops.aten.unfold]) - - -def _extract_states_from_exported_program(exported_model): - # NOTE call convention: (parameters, buffers, user_inputs) - param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers - state_dict = copy.copy(exported_model.state_dict) - if (constants := getattr(exported_model, 'constants', None)) is not None: - state_dict.update(constants) - param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys) - - if hasattr(exported_model.graph_signature, "lifted_tensor_constants"): - for name in exported_model.graph_signature.lifted_tensor_constants: - param_buffer_values.append(exported_model.tensor_constants[name]) - - return param_and_buffer_keys, param_buffer_values - - -def exported_program_to_jax(exported_program, export_raw: bool = False): - """returns a pytree of jax arrays(state), and - - a callable(func) that is jax function. - - func(state, input) would be how you call it. - """ - if torch.__version__ >= '2.2': - # torch version 2.1 didn't expose this yet - exported_program = exported_program.run_decompositions() - exported_program = exported_program.run_decompositions( - decompositions.DECOMPOSITIONS) - if DEBUG: - print(exported_program.graph_module.code) - - names, states = _extract_states_from_exported_program(exported_program) - - def _extract_args(args, kwargs): - flat_args, received_spec = pytree.tree_flatten( - (args, kwargs)) # type: ignore[possibly-undefined] - return flat_args - - num_mutations = len(exported_program.graph_signature.buffers_to_mutate) - - def func(states, inputs): - args = _extract_args(inputs, {}) - res = JaxInterpreter(exported_program.graph_module).run( - *states, - *args, - enable_io_processing=False, - ) - res = res[num_mutations:] - return res - - if export_raw: - return names, states, func - env = torchax.default_env() - states = env.t2j_copy(states) - return states, func - - -def extract_avals(exported): - """Return JAX Abstract Value shapes for all input parameters of the exported - program. This supports dynamic batch dimensions, including with constraints. - """ - - def _to_aval(arg_meta, symbolic_shapes): - """Convet from torch type to jax abstract value for export tracing - """ - - def _get_dim(d): - if isinstance(d, torch.SymInt): - return symbolic_shapes[str(d)] - return d - - val = arg_meta['val'] - is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance( - val, bool) - if is_scalar: - return jax.ShapeDtypeStruct([], type(arg_meta['val'])) - - tensor_meta = arg_meta['tensor_meta'] - shape = [_get_dim(d) for d in tensor_meta.shape] - return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype)) - - def _get_inputs(exported): - """Return placeholders with input metadata""" - placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"] - input_placeholders = [ - p for p, s in zip(placeholders, exported.graph_signature.input_specs) - if s.kind == torch.export.graph_signature.InputKind.USER_INPUT - ] - return input_placeholders - - def _build_symbolic_shapes(range_constraints): - """Convert torch SymInt to JAX symbolic_shape and stores in a map using the - string name of the torch symbolic int. - - TODO: There is probably a better way of storing a key for a symbolic int. - This value needs to be looked up again in `_to_aval` to figure out which - JAX symbolic to map to for a given torch tensor. - """ - if len(range_constraints) == 0: - return None - - def _build_symbolic_constraints(symbol_name, torch_constraint): - """Convert torch SymInt constraints to string for JAX symbolic_shape - Using sympy may be overkill here, currently PyTorch only uses ValueRanges - which allow specifying the min and the max of a value, for example: - torch.export.Dim("a", min=5, max=10) - ==> ("a >= 5", "a <= 10",) - """ - if not isinstance(torch_constraint, torch.utils._sympy.value_ranges. - ValueRanges) or torch_constraint.is_bool: - raise TypeError( - f"No symbolic constraint handler for: {torch_constraint}") - - constraints = [] - symbol = sympy.Symbol(symbol_name) - if torch_constraint.lower != 2: - constraints.append(symbol >= torch_constraint.lower) - from sympy.core.singleton import S - if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity: - constraints.append(symbol <= torch_constraint.upper) - - return tuple(sympy.pretty(c, use_unicode=False) for c in constraints) - - def _build_symbolic_shape(sym, constraint, free_symbols): - """Returns a JAX symbolic shape for a given symbol and constraint - - There are two possible sympy `sym` inputs: - 1. Symbol - (s0) These can have custom constraints. - 2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override. - - Currently support is limited to operations with a symbol and and int, - in `torch/export/dynamic_shapes.py`: - "Only increasing linear operations with integer coefficients are supported." - """ - symbol_name = str(sym) - constraints = _build_symbolic_constraints(symbol_name, constraint) - if sym.is_symbol: - symbolic_shape = jax.export.symbolic_shape( - symbol_name, constraints=constraints) - else: - assert len(sym.free_symbols) > 0 - scope = free_symbols[str(list(sym.free_symbols)[0])].scope - symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope) - assert len(symbolic_shape) == 1 - return symbolic_shape[0] - - # Populate symbol variables before expressions, exprs need to use the same - # Symbolic scope as the variable they operate on. Expressions can only be - # integer compuations on symbol variables, so each symbol variable is OK to - # have its own scope. - symbolic_shapes = {} - symbol_variables = [ - (s, v) for s, v in range_constraints.items() if s.is_symbol - ] - symbol_exprs = [ - (s, v) for s, v in range_constraints.items() if not s.is_symbol - ] - for sym, constraint in symbol_variables + symbol_exprs: - symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes) - symbolic_shapes[str(sym)] = symbolic_shape - return symbolic_shapes - - symbolic_shapes = _build_symbolic_shapes(exported.range_constraints) - args = _get_inputs(exported) - - if DEBUG: - print('Inputs to aval:', args, '--------') - print('Symbolic shapes:', symbolic_shapes) - for arg in args: - print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes)) - - return [_to_aval(arg.meta, symbolic_shapes) for arg in args] - - -def exported_program_to_stablehlo(exported_program): - """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo - - Convert a program exported via torch.export to StableHLO. - - This supports dynamic dimension sizes and generates explicit checks for - dynamo guards in the IR using shape_assertion custom_call ops. - """ - weights, func = exported_program_to_jax(exported_program) - jax_avals = extract_avals(exported_program) - jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,)) - return weights, jax_export diff --git a/torchax/torchax/flax.py b/torchax/torchax/flax.py deleted file mode 100644 index 28542d79c90e..000000000000 --- a/torchax/torchax/flax.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Flax interop.""" - -import torch -import torchax as tx -import torchax.interop - - -class FlaxNNModule(torch.nn.Module): - - def __init__(self, env, flax_module, sample_args, sample_kwargs=None): - super().__init__() - prng = env.prng_key - sample_kwargs = sample_kwargs or {} - parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args, - **sample_kwargs) - - self._params = self._encode_nested_dict(parameter_dict) - - self._flax_module = flax_module - - def _encode_nested_dict(self, nested_dict): - child_module = torch.nn.Module() - for k, v in nested_dict.items(): - if isinstance(v, dict): - child_module.add_module(k, self._encode_nested_dict(v)) - else: - child_module.register_parameter(k, torch.nn.Parameter(v)) - return child_module - - def _decode_nested_dict(self, child_module): - result = dict(child_module.named_parameters(recurse=False)) - for k, v in child_module.named_children(): - result[k] = self._decode_nested_dict(v) - return result - - def forward(self, *args, **kwargs): - nested_dict_params = self._decode_nested_dict(self._params) - return tx.interop.call_jax(self._flax_module.apply, nested_dict_params, - *args, **kwargs) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py deleted file mode 100644 index 34ab79b10838..000000000000 --- a/torchax/torchax/interop.py +++ /dev/null @@ -1,356 +0,0 @@ -import collections -import copy -import functools -import torch -from inspect import signature -from functools import wraps -from torch.nn.utils import stateless as torch_stateless -import jax -import jax.numpy as jnp -from jax import tree_util as pytree -from jax.experimental.shard_map import shard_map -from torchax import tensor -from torchax import util -from torchax.ops import mappings -import torchax - -from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable - - -def extract_all_buffers(m: torch.nn.Module): - buffers = {} - params = {} - - def extract_one(module, prefix): - for k in dir(module): - try: - v = getattr(module, k) - except: - continue - qual_name = prefix + k - if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad: - params[qual_name] = v - elif isinstance(v, torch.Tensor): - buffers[qual_name] = v - for name, child in module.named_children(): - extract_one(child, prefix + name + '.') - - extract_one(m, '') - return params, buffers - - -def set_all_buffers(m, params, buffers): - - def set_one(module, prefix): - for k in dir(module): - qual_name = prefix + k - if (potential_v := buffers.get(qual_name)) is not None: - setattr(module, k, potential_v) - elif (potential_v := params.get(qual_name)) is not None: - print(k, potential_v) - setattr(module, k, torch.nn.Parameter(potential_v)) - for name, child in module.named_children(): - set_one(child, prefix + name + '.') - - set_one(m, '') - - -class JittableModule(torch.nn.Module): - - def __init__(self, - m: torch.nn.Module, - extra_jit_args={}, - dedup_parameters=True): - super().__init__() - self.params, self.buffers = extract_all_buffers(m) - self._model = m - self._jitted = {} - - self._extra_jit_args = extra_jit_args - - self._extra_dumped_weights = {} - - if dedup_parameters: - temp = collections.defaultdict(list) - for k, v in self.params.items(): - temp[id(v)].append(k) - - for v in temp.values(): - if len(v) > 1: - # duplicated weights with different name - self._extra_dumped_weights[v[0]] = v[1:] - for extra_keys in v[1:]: - del self.params[extra_keys] - - @property - def __class__(self): - # Lie about the class type so that - # isinstance(jittable_module, self._model.__class__) works - return self._model.__class__ - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def functional_call(self, method_or_name, params, buffers, *args, **kwargs): - kwargs = kwargs or {} - params_copy = copy.copy(params) - params_copy.update(buffers) - # reinflate the state dict so there are not any missing keys - for k, v in self._extra_dumped_weights.items(): - for new_key in v: - params_copy[new_key] = params_copy[k] - - if isinstance(method_or_name, str): - method = getattr(self._model, method_or_name) - else: - if not callable(method_or_name): - raise TypeError( - f"method_or_name should be a callable or a string, got {type(method_or_name)}" - ) - method = method_or_name - args = (self._model,) + args - with torch_stateless._reparametrize_module(self._model, params_copy): - res = method(*args, **kwargs) - return res - - def jittable_call(self, method_name: str, *args, **kwargs): - if method_name not in self._jitted: - jitted = jax_jit( - functools.partial(self.functional_call, method_name), - kwargs_for_jax_jit=self._extra_jit_args, - ) - - def jitted_forward(*args, **kwargs): - return jitted(self.params, self.buffers, *args, **kwargs) - - self._jitted[method_name] = jitted_forward - return self._jitted[method_name](*args, **kwargs) - - def forward(self, *args, **kwargs): - return self.jittable_call('forward', *args, **kwargs) - - def __getattr__(self, key): - if key == '_model': - return super().__getattr__(key) - if key in self._jitted: - return self._jitted[key] - return getattr(self._model, key) - - def make_jitted(self, key): - jitted = jax_jit( - functools.partial(self.functional_call, key), - kwargs_for_jax_jit=self._extra_jit_args) - - def call(*args, **kwargs): - return jitted(self.params, self.buffers, *args, **kwargs) - - self._jitted[key] = call - - -class CompileMixin: - - def functional_call(self, method, params, buffers, *args, **kwargs): - kwargs = kwargs or {} - params_copy = copy.copy(params) - params_copy.update(buffers) - with torch_stateless._reparametrize_module(self, params_copy): - res = method(*args, **kwargs) - return res - - def jit(self, method): - jitted = jax_jit(functools.partial(self.functional_call, method_name)) - - def call(*args, **kwargs): - return jitted(self.named_paramters(), self.named_buffers(), *args, - **kwargs) - - return call - - -def compile_nn_module(m: torch.nn.Module, methods=None): - if methods is None: - methods = ['forward'] - - new_parent = type( - m.__class__.__name__ + '_with_CompileMixin', - (CompileMixin, m.__class__), - ) - m.__class__ = NewParent - - -def _torch_view(t: JaxValue) -> TorchValue: - # t is an object from jax land - # view it as-if it's a torch land object - if isinstance(t, jax.Array): - # TODO - return tensor.Tensor(t, torchax.default_env()) - if isinstance(t, jnp.dtype): - return mappings.j2t_dtype(t) - if callable(t): # t is a JaxCallable - return functools.partial(call_jax, t) - # regular types are not changed - return t - - -torch_view = functools.partial(pytree.tree_map, _torch_view) - - -def _jax_view(t: TorchValue) -> JaxValue: - # t is an object from torch land - # view it as-if it's a jax land object - if isinstance(t, torch.Tensor): - assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t) - return t.jax() - if isinstance(t, type(torch.int32)): - return mappings.t2j_dtype(t) - - # torch.nn.Module needs special handling - if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable - return functools.partial(call_torch, t) - # regular types are not changed - return t - - -jax_view = functools.partial(pytree.tree_map, _jax_view) - - -def call_jax(jax_func: JaxCallable, *args: TorchValue, - **kwargs: TorchValue) -> TorchValue: - args, kwargs = jax_view((args, kwargs)) - res: JaxValue = jax_func(*args, **kwargs) - return torch_view(res) - - -def call_torch(torch_func: TorchCallable, *args: JaxValue, - **kwargs: JaxValue) -> JaxValue: - args, kwargs = torch_view((args, kwargs)) - with torchax.default_env(): - res: TorchValue = torch_func(*args, **kwargs) - return jax_view(res) - - -def j2t_autograd(fn, call_jax=call_jax): - """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. - - It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate - activations). The wrapped function is then run via `call_jax` and integrated into - the PyTorch autograd framework by saving the residuals into the context object. - """ - - # NOTE(qihqi): This function cannot be inlined from the callsite - # Becuase if it does, then it won't hit the compilation cache for - # call_jax. Call jax uses functions' id as key. - # It is nested inside j2t_autograd to ensure it gets a unique ID for each - # wrapped pure function, preventing cache collisions between different pure modules. - def _jax_forward(fn, other, tree_def, tensors): - """JAX function to compute output and vjp function. - - primals should be a tuple (args, kwargs). - """ - import jax - from jax.tree_util import tree_flatten, tree_unflatten - - def fn_wrapper(*tensors): - # Reconstruct the original args and kwargs - flat_inputs = util.merge(tensors, other) - args, kwargs = tree_unflatten(tree_def, flat_inputs) - return fn(*args, **kwargs) - - return jax.vjp(fn_wrapper, *tensors) - - def _jax_backward(vjp_spec, saved_tensors, grad_out): - """JAX function to compute input gradients. - - Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. - """ - from jax.tree_util import tree_unflatten - fun_vjp = tree_unflatten(vjp_spec, saved_tensors) - return fun_vjp(grad_out) - - @wraps(fn) - def inner(*args, **kwargs): - from jax.tree_util import tree_flatten - - class JaxFun(torch.autograd.Function): - - @staticmethod - def forward(ctx, tree_def, *flat_args_kwargs): - - tensors, other = util.partition(flat_args_kwargs, - lambda x: isinstance(x, torch.Tensor)) - # We want the arguments that don't require grads to be closured? - - y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors) - - # Save necessary information for backward - # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass. - # `residuals` contains the tensors needed for the backward pass.` - residuals, vjp_spec = tree_flatten(fun_vjp) - ctx.vjp_spec = vjp_spec - ctx.save_for_backward(*residuals) - return y - - @staticmethod - def backward(ctx, *grad_out): - assert len(grad_out) > 0 - grad_out = grad_out if len(grad_out) > 1 else grad_out[0] - - input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec, - ctx.saved_tensors, grad_out) - - # Construct the gradient tuple to be returned. - # It needs to match the inputs to forward: (tree_def, *flat_inputs) - # The first gradient (for tree_def) is None. - # The subsequent gradients correspond to flat_inputs. - # We need to put a None for inputs that did not require gradients. - final_grads = [None] - for needs_grad, grad in zip( - ctx.needs_input_grad[1:], input_grads_structured, strict=True): - final_grads.append(grad if needs_grad else None) - - return tuple(final_grads) - - sig = signature(fn) - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs)) - y = JaxFun.apply(tree_def, *flat_args_kwargs) - return y - - return inner - - -fori_loop = torch_view(jax.lax.fori_loop) - - -def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None): - kwargs_for_jax = kwargs_for_jax or {} - jax_func = jax_view(torch_function) - jitted = jax_jit_func(jax_func, **kwargs_for_jax) - return torch_view(jitted) - - -def jax_jit(torch_function, - kwargs_for_jax_jit=None, - fix_for_buffer_donation=False): - return wrap_jax_jit( - torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit) - - -def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): - return wrap_jax_jit( - torch_function, - jax_jit_func=shard_map, - kwargs_for_jax=kwargs_for_jax_shard_map) - - -def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): - return wrap_jax_jit( - torch_function, - jax_jit_func=jax.value_and_grad, - kwargs_for_jax=kwargs_for_value_and_grad) - - -def gradient_checkpoint(torch_function, kwargs=None): - return wrap_jax_jit( - torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs) diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py deleted file mode 100644 index 208d86a1bac6..000000000000 --- a/torchax/torchax/mesh_util.py +++ /dev/null @@ -1,220 +0,0 @@ -import jax -from jax.sharding import PartitionSpec, NamedSharding -import torch -import torchax -from torchax import interop - - -def _shard_first_multiple_of(axis_name, shape, multiple_of): - """Creates a PartitionSpec to shard the first dimension divisible by a number. - - Iterates through the dimensions specified by `shape`. Finds the first dimension - whose size is a multiple of `multiple_of` and returns a PartitionSpec that - shards that dimension along the given `axis_name`. All preceding dimensions - are not sharded (marked as None in the PartitionSpec). All subsequent dimensions - skipped, which would be implicitly treated as replicated. - - Args: - axis_name: The name of the mesh axis to shard along (e.g., "data", "mdl"). - shape: A tuple or list representing the shape of the tensor to be sharded. - multiple_of: The integer value that a dimension size must be divisible by - in order to be sharded. Typically the size of the mesh axis. - - Returns: - A jax.sharding.PartitionSpec object specifying how to shard the tensor. - For example, if shape=(10, 20, 30), axis_name='x', multiple_of=4, - it would return PartitionSpec(None, 'x', None). - If none divides then it should return a replicated PartitionSpec - """ - sharding = [] - found = False - for size in shape: - if not found and size % multiple_of == 0: - found = True - sharding.append(axis_name) - else: - sharding.append(None) - return PartitionSpec(*sharding) - - -class SingleAxisSharder: - """A callable object that generates PartitionSpecs for single-axis sharding. - - This sharder strategy attempts to shard the *first* dimension of a tensor - that is divisible by the specified `axis_size` along the given `axis_name`. - It's useful for simple 1D mesh sharding scenarios like FSDP where parameters - are typically sharded along one dimension. - - Attributes: - axis_name: The name of the mesh axis to shard along. - axis_size: The size of the mesh axis (number of devices along that axis). - """ - - def __init__(self, axis_name, axis_size, replicate_unshardable=False): - """Initializes the SingleAxisSharder. - - Args: - axis_name: The name of the mesh axis (e.g., "fsdp", "data"). - axis_size: The number of devices along the specified mesh axis. - replicate_unshardable: indicate whether it should return replicated sharding - (P()) when none of the axis is divisible by the axis size. - """ - self.axis_name = axis_name - self.axis_size = axis_size - self.replicate_unshardable = replicate_unshardable - - def __call__(self, name, shapedtype): - """Generates a PartitionSpec for a given tensor name and shaped type. - - Args: - name: The name of the tensor (e.g., parameter name). This argument is - provided for compatibility with more complex sharders but is not used - by this simple sharder. - shapedtype: An object with a `.shape` attribute describing the tensor's shape, - and `.dtype` describing it's dtype. Example: jax.Array, jax.ShapeDtypeStruct - or a torch.Tensor) - - Returns: - A jax.sharding.PartitionSpec determined by finding the first dimension - in `shapedtype.shape` divisible by `self.axis_size` using the helper - `_shard_first_multiple_of`. - """ - del name - sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape, - self.axis_size) - if not self.replicate_unshardable and all(s is None for s in sharding): - raise AssertionError( - f"Unable to find a dim to shard because " - f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" - ) - return sharding - - -class Mesh: - """A helper class that wraps `jax.sharding.Mesh` object. - - The goal of this class is to provide helper methods that facilitate the - sharding of PyTorch tensors or models given a JAX device mesh configuration. - It simplifies initializing models directly into a sharded state. - - Attributes: - jax_mesh: The underlying `jax.sharding.Mesh` object defining the device grid - and axis names. - _sharder: The default sharding strategy callable (like SingleAxisSharder) - used to determine the PartitionSpec for each parameter if not overridden - during method calls. Can be None if no default is appropriate or set. - """ - - @classmethod - def fsdp_mesh(cls, axis_name="fsdp"): - """Creates a Mesh instance suitable for 1D FSDP-style sharding. - - This named constructor creates a 1D mesh encompassing all available XLA - devices. It assigns the specified `axis_name` to this single dimension. - It then creates a `Mesh` instance using this JAX mesh and a - `SingleAxisSharder` configured appropriately for this 1D mesh. - - Args: - axis_name: The name to assign to the single mesh axis (default: "fsdp"). - This name will be used by the default `SingleAxisSharder`. - - Returns: - A Mesh instance configured with a 1D JAX mesh across all devices and a - corresponding SingleAxisSharder. - """ - ndevice = jax.device_count() - jax_mesh = jax.make_mesh((ndevice,), (axis_name,)) - # replicate_unshardable so scalars and small model attributes are replicated. - return cls(jax_mesh, SingleAxisSharder(axis_name, ndevice, True)) - - def __init__(self, jax_mesh, sharder=None): - """Initializes the Mesh helper. - - Args: - jax_mesh: A pre-configured `jax.sharding.Mesh` object defining the - physical device grid and logical axis names. - sharder: An optional callable (e.g., an instance of SingleAxisSharder) - that takes (name, shapedtype) and returns a `jax.sharding.PartitionSpec`. - This serves as the default sharding strategy. - If None, and the provided `jax_mesh` has exactly one axis, a - `SingleAxisSharder` is created automatically for that single axis. - If None and the mesh has multiple axes, `_sharder` remains None, and - an `override_sharder` must be provided to methods like - `initialize_model_sharded`. - """ - self.jax_mesh = jax_mesh - if sharder is None: - assert len(self.jax_mesh.axis_names) == 1 - sharder = SingleAxisSharder(self.jax_mesh.axis_names[0], - len(self.mesh.device_ids)) - self._sharder = sharder - - def initialize_model_sharded(self, - model_class, - init_args, - init_kwargs=None, - override_sharder=None): - """Initializes a PyTorch model with its parameters sharded across the mesh. - - This method orchestrates the initialization of a `torch.nn.Module` such - that its parameters are created directly on the target devices according - to the sharding specifications derived from the mesh and the chosen sharder. - It leverages `torchax.interop.jax_jit` to achieve this. - - Args: - model_class: The PyTorch model class (a subclass of `torch.nn.Module`). - init_args: A tuple containing the positional arguments required by the - `model_class.__init__` method. - init_kwargs: An optional dictionary containing the keyword arguments for - the `model_class.__init__` method. Defaults to None (treated as {}). - override_sharder: An optional callable sharding strategy to use - specifically for this initialization. If provided, it takes precedence - over the mesh's default `_sharder`. It must accept `(name, shapedtype)` - and return a `PartitionSpec`. If None, the mesh's default `_sharder` - is used. - - Returns: - An instance of `model_class` whose parameters have been initialized and - are represented by sharded tensors distributed across the devices in the - `jax_mesh`. - - Raises: - ValueError: If no sharder is available (i.e., `override_sharder` is None - and the mesh's default `_sharder` is also None). - AssertionError: Can be raised by the sharder (e.g., `SingleAxisSharder`) - if it fails to determine a valid sharding for any parameter. - TypeError: If `shapedtype` passed to the sharder doesn't have a `.shape`. - Other errors from JAX JIT compilation or PyTorch model initialization. - """ - init_kwargs = init_kwargs or {} - with torch.device("meta"), torchax.disable_temporarily(): - model = model_class(*init_args, **init_kwargs) - - sharder = override_sharder or self._sharder - - states = model.state_dict() - output_shards = { - name: NamedSharding(self.jax_mesh, sharder(name, tensor)) - for name, tensor in states.items() - } - - def model_initializer(): - with torchax.default_env(), torch.device('meta'): - model = model_class(*init_args, **init_kwargs) - return dict(model.state_dict()) - - jitted = interop.jax_jit( - model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}) - weights_dict = jitted() - - model.load_state_dict(weights_dict, assign=True) - return model - - def shard_model(self, model, override_sharder=None): - sharder = override_sharder or self._sharder - states = model.state_dict() - output_shards = { - name: NamedSharding(self.jax_mesh, sharder(name, tensor)) - for name, tensor in states.items() - } - model.load_state_dict(output_shards, assign=True) diff --git a/torchax/torchax/ops/__init__.py b/torchax/torchax/ops/__init__.py deleted file mode 100644 index 71c1b137132f..000000000000 --- a/torchax/torchax/ops/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -def all_aten_jax_ops(): - # to load the ops - import torchax.ops.jaten # type: ignore - import torchax.ops.ops_registry # type: ignore - - return { - key: val.func - for key, val in torchax.ops.ops_registry.all_aten_ops.items() - if val.is_jax_function - } diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py deleted file mode 100644 index 700d581d7736..000000000000 --- a/torchax/torchax/ops/jaten.py +++ /dev/null @@ -1,5638 +0,0 @@ -"""Torch ops implemented using jax.""" - -import sys -from typing import Optional, Sequence, Tuple, Union, Callable -import functools - -import math -import jax -from jax import numpy as jnp -import functools -import numpy as np -import torch -import torch.distributed._functional_collectives -from torchax.ops import ops_registry -from torchax.ops import op_base, mappings -from torchax import interop -from torchax.ops import jax_reimplement -from torchax.view import View -# Keys are OpOverload, value is a callable that takes -# Tensor -all_ops = {} - - -def op(*aten, **kwargs): - - def inner(func): - for a in aten: - ops_registry.register_torch_dispatch_op(a, func, **kwargs) - continue - - if isinstance(a, torch._ops.OpOverloadPacket): - opname = a.default.name() if 'default' in a.overloads( - ) else a._qualified_op_name - elif isinstance(a, torch._ops.OpOverload): - opname = a.name() - else: - raise RuntimeError(f'oops {a}') - - torchfunc = functools.partial(interop.call_jax, func) - # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor - torch.library.impl(opname, 'privateuseone')( - torchfunc if a != torch.ops.aten._to_copy else func) - return func - - return inner - - -@op( - torch.ops.aten.view_copy, - torch.ops.aten.view, - torch.ops.aten._unsafe_view, - torch.ops.aten.reshape, -) -def _aten_unsafe_view(x, shape): - return jnp.reshape(x, shape) - - -@op(torch.ops.aten.add.Tensor) -@op(torch.ops.aten.add.Scalar) -def _aten_add(x, y, *, alpha=1): - """if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray): - - assert x.dtype == y.dtype, (x.dtype, y.dtype) - """ - res = x + y * alpha - if isinstance(x, float) or isinstance(y, float): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = res.astype(new_dtype) - return res - - -@op(torch.ops.aten.copy_, - is_jax_function=False, - is_view_op=True, - needs_env=True) -def _aten_copy(x, y, memory_format=None, env=None): - - if y.device.type == 'cpu': - y = env.to_xla(y) - - if isinstance(x, View): - x.update(y) - return x - - if x.ndim == 1 and y.ndim == 0: - # case of torch.empty((1,)).copy_(tensor(N)) - # we need to return 0D tensor([N]) and not scalar tensor(N) - # ref: https://github.com/pytorch/xla/issues/7505#issuecomment-2395319131 - x._elem = jnp.array([y._elem.astype(x._elem.dtype)]) - else: - x._elem = y._elem.astype(x._elem.dtype) - return x - - -@op(torch.ops.aten.clone) -def _aten_clone(x, memory_format=None): - return x - - -# aten.trunc -@op(torch.ops.aten.trunc) -def _aten_trunc(x): - res = jnp.trunc(x) - return res.astype(x) - - -@op(torch.ops.aten.index_copy) -def _aten_index_copy(x, dim, indexes, source): - if x.ndim == 0: - return source - if x.ndim == 1: - source = jnp.squeeze(source) - # return jax.lax.scatter(x, index, dim) - if dim < 0: - dim = dim + x.ndim - dims = [] - for i in range(len(x.shape)): - if i == dim: - dims.append(indexes) - else: - dims.append(slice(None, None, None)) - return x.at[tuple(dims)].set(source) - - -# aten.cauchy_ -@op(torch.ops.aten.cauchy_) -def _aten_cauchy_(x, median=0, sigma=1): - """ - Fills the input array with values drawn from a Cauchy distribution. - - Args: - x: An array to be filled with Cauchy samples. - median: The median of the Cauchy distribution. - sigma: The scale parameter of the Cauchy distribution. - - Returns: - The input array filled with Cauchy samples. - """ - key = jax.random.PRNGKey(0) # You should use a different key for each call - samples = jax.random.cauchy(key, x.shape) * sigma + median - return x.at[:].set(samples) - - -@op(torch.ops.aten.atleast_2d) -def _aten_atleast_2d(inputs): - return jnp.atleast_2d(inputs) - - -@op(torch.ops.aten.atleast_1d) -def _aten_atleast_1d(inputs): - return jnp.atleast_1d(inputs) - - -# aten.complex -@op(torch.ops.aten.complex) -def _aten_complex(real, imag): - """ - Constructs a complex array from real and imaginary parts. - - Args: - real: An array of real values. - imag: An array of imaginary values. - - Returns: - A complex array with the specified real and imaginary parts. - """ - return jnp.array( - real, dtype=jnp.float32) + 1j * jnp.array( - imag, dtype=jnp.float32) - - -# aten.exponential_ -@op(torch.ops.aten.exponential_) -def _aten_exponential_(x, lambd=1.0): - """ - Fills the input array with values drawn from an exponential distribution. - - Args: - x: An array to be filled with exponential samples. - lambd: The rate parameter of the exponential distribution. - - Returns: - The input array filled with exponential samples. - """ - key = jax.random.PRNGKey(0) # Use a different key for each call - samples = jax.random.exponential(key, x.shape) / lambd - return x.at[:].set(samples) - - -# aten.linalg_householder_product -@op(torch.ops.aten.linalg_householder_product) -def _aten_linalg_householder_product(input, tau): - return jax.lax.linalg.householder_product(a=input, taus=tau) - - -@op(torch.ops.aten.select) -def _aten_select(x, dim, indexes): - return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) - - -@op(torch.ops.aten.index_select) -@op(torch.ops.aten.select_copy) -def _aten_index_select(x, dim, index): - if x.shape == (): - return x - return jnp.take(x, index, dim) - - -@op(torch.ops.aten.cholesky) -def _aten_cholesky(input, upper=False): - return jax.scipy.linalg.cholesky(input, lower=(not upper)) - - -@op(torch.ops.aten.linalg_cholesky_ex) -def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False): - if check_errors: - raise NotImplementedError( - "check_errors=True is not supported in this JAX implementation. " - "Check for positive definiteness using jnp.linalg.eigvalsh before " - "calling this function.") - - L = jax.scipy.linalg.cholesky(input, lower=not upper) - if len(L.shape) > 2: - info = jnp.zeros(shape=L.shape[:-2], dtype=jnp.int32) - else: - info = jnp.array(0, dtype=jnp.int32) - return L, info - - -@op(torch.ops.aten.cholesky_solve) -def _aten_cholesky_solve(input, input2, upper=False): - # Ensure input2 is lower triangular for cho_solve - L = input2 if not upper else input2.T - # Use cho_solve to solve the linear system - solution = jax.scipy.linalg.cho_solve((L, True), input) - return solution - - -@op(torch.ops.aten.special_zeta) -def _aten_special_zeta(x, q): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = jax.scipy.special.zeta(x, q) - if isinstance(x, int) or isinstance(q, int): - res = res.astype(new_dtype) - return res # jax.scipy.special.zeta(x, q) - - -# aten.igammac -@op(torch.ops.aten.igammac) -def _aten_igammac(input, other): - if isinstance(input, jnp.ndarray): - input = jnp.where(input < 0, jnp.nan, input) - if isinstance(other, jnp.ndarray): - other = jnp.where(other < 0, jnp.nan, other) - else: - if (input == 0 and other == 0) or (input < 0) or (other < 0): - other = jnp.nan - return jnp.array(jax.scipy.special.gammaincc(input, other)) - - -@op(torch.ops.aten.mean) -def _aten_mean(x, dim=None, keepdim=False): - if x.shape == () and dim is not None: - dim = None # disable dim for jax array without dim - return jnp.mean(x, dim, keepdims=keepdim) - - -def _torch_binary_scalar_type(scalar, tensor): - if "float" in str(tensor.dtype) or "complex" in str(tensor.dtype): - return tensor.dtype - - if isinstance(scalar, int): - if "int" in str(tensor.dtype): - return tensor.dtype - - return jnp.float32 - - -@op(torch.ops.aten.searchsorted.Tensor) -def _aten_searchsorted(sorted_sequence, values): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = jnp.searchsorted(sorted_sequence, values) - if sorted_sequence.dtype == np.dtype( - np.int32) or sorted_sequence.dtype == np.dtype(np.int32): - # res = res.astype(new_dtype) - res = res.astype(np.dtype(np.int64)) - return res # jnp.searchsorted(sorted_sequence, values) - - -@op(torch.ops.aten.sub.Tensor) -@op(torch.ops.aten.sub.Scalar) -def _aten_sub(x, y, alpha=1): - if isinstance(x, float): - dtype = _torch_binary_scalar_type(x, y) - x = jnp.array(x, dtype=dtype) - if isinstance(y, float): - dtype = _torch_binary_scalar_type(y, x) - y = jnp.array(y, dtype=dtype) - return x - y * alpha - - -@op(torch.ops.aten.numpy_T) -def _aten_numpy_T(input): - """ - Jax implementation of torch.numpy_T. - - Args: - input: JAX array. - - Returns: - Transposed JAX array. - """ - return jnp.transpose(input) - - -@op(torch.ops.aten.mm) -def _aten_mm(x, y): - res = x @ y - return res - - -@op(torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar) -def _aten_mul(x, y): - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - res = x * y - if isinstance(x, float) or isinstance(y, float): - res = res.astype(new_dtype) - else: - if (not isinstance(x, int)) and (not isinstance(y, int)): - if x.dtype == np.dtype(np.float64) or y.dtype == np.dtype(np.float64): - res = res.astype(new_dtype) - return res - - -@op(torch.ops.aten.silu) -@op(torch.ops.aten.silu.default) -def _aten_silu(x): - return jax.nn.silu(x) - - -@op(torch.ops.aten.t) -def _aten_t(x): - return jnp.transpose(x) - - -@op(torch.ops.aten.transpose) -@op(torch.ops.aten.transpose_copy) -def _aten_transpose(x, dim0, dim1): - if x.ndim == 0: - return x - dim0 = dim0 if dim0 >= 0 else dim0 + x.ndim - dim1 = dim1 if dim1 >= 0 else dim1 + x.ndim - return jnp.swapaxes(x, dim0, dim1) - - -@op(torch.ops.aten.triu) -def _aten_triu(m, k=0): - return jnp.triu(m, k) - - -@op(torch.ops.aten.slice) -@op(torch.ops.aten.slice_copy) -def _aten_slice(self, dim=0, start=None, end=None, step=1): - if dim < 0: - dim += self.ndim - if end == sys.maxsize: - end = self.shape[dim] - sl = slice(start, end, step) - dims = [] - for i in range(len(self.shape)): - if i == dim: - dims.append(sl) - else: - dims.append(slice(None, None, None)) - return self[tuple(dims)] - - -@op(torch.ops.aten.positive) -@op(torch.ops.aten.detach) -def _aten_detach(self): - return self - - -@op(torch.ops.aten.imag) -def _aten_imag(x): - return jnp.imag(x) - - -@op(torch.ops.aten.isfinite) -def _aten_isfinite(x): - return jnp.isfinite(x) - - -@op(torch.ops.aten.real) -def _aten_real(x): - return jnp.real(x) - - -@op(torch.Tensor.resize_) -def _aten_resize_(x, size, interpolation='linear'): - new_size = tuple(size) - return jax.numpy.resize(x, new_size) - - -@op(torch.ops.aten.resize_as_) -def _aten_resize_as_(x, y): - return jax.numpy.resize(x, y.shape) - - -@op(torch.ops.aten.repeat_interleave.Tensor) -def repeat_interleave(repeats, dim=0): - return jnp.repeat(np.arange(repeats.shape[dim]), repeats) - - -@op(torch.ops.aten.repeat_interleave.self_int) -@op(torch.ops.aten.repeat_interleave.self_Tensor) -def repeat_interleave(self, repeats, dim=0): - total_repeat_length = None - if isinstance(repeats, int): - total_repeat_length = self.shape[dim] * repeats - repeats = np.array([repeats] * self.shape[dim]) - return jnp.repeat(self, repeats, dim, total_repeat_length=total_repeat_length) - - -@op(torch.ops.aten.view_as_real) -def _aten_view_as_real(x): - real = jnp.real(x) - im = jnp.imag(x) - res = jnp.stack([real, im], -1) - return res - - -@op(torch.ops.aten.stack) -def _aten_stack(tensors, dim=0): - return jnp.stack(tensors, dim) - - -@op(torch.ops.aten._softmax) -@op(torch.ops.aten.softmax) -@op(torch.ops.aten.softmax.int) -def _aten_softmax(x, dim, halftofloat=False): - if x.shape == (): - return jax.nn.softmax(x.reshape([1]), axis=0).reshape([]) - return jax.nn.softmax(x, dim) - - -def _is_int(x): - if isinstance(x, int): - return True - if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or - x.dtype.name.startswith('uint')): - return True - return False - - -def highest_precision_int_dtype(tensor1, tensor2): - if isinstance(tensor1, int): - return tensor2.dtype - if isinstance(tensor2, int): - return tensor1.dtype - - dtype_hierarchy = { - 'uint8': 8, - 'int8': 8, - 'uint16': 16, - 'int16': 16, - 'uint32': 32, - 'int32': 32, - 'uint64': 64, - 'int64': 64, - } - return max( - tensor1.dtype, - tensor2.dtype, - key=lambda dtype: dtype_hierarchy[str(dtype)]) - - -@op(torch.ops.aten.pow) -def _aten_pow(x, y): - y_orig = y - if isinstance(y, int): - y = float(y) - if _is_int(x) and _is_int(y_orig): - # Do the math in float then cast - res = jnp.power(jnp.astype(x, jnp.dtype('float')), y) - return res.astype(highest_precision_int_dtype(x, y_orig)) - res = jnp.power(x, y) - if isinstance(x, float): - return res.astype(_torch_binary_scalar_type(x, y_orig)) - if isinstance(y_orig, float): - return res.astype(_torch_binary_scalar_type(y_orig, x)) - return res - - -@op(torch.ops.aten.view_as_complex) -def _aten_view_as_complex(input): - if input.dtype == jnp.bfloat16: - input = input.astype(jnp.float32) - x, y = input[..., 0], input[..., 1] - return jax.lax.complex(x, y) - - -@op(torch.ops.aten.div) -def _aten_div(x, y, rounding_mode=""): - res_dtype = None - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('float32') - - if (isinstance(x, float) or isinstance(y, float)): - res_dtype = new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - - if rounding_mode == "floor": - res = jnp.floor_divide(x, y) - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('int64') - else: - res = x / y - if rounding_mode == "trunc": - res = jnp.trunc(res) - if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('int64') - if res_dtype: - res = res.astype(res_dtype) - return res - - -@op(torch.ops.aten.true_divide) -def _aten_true_divide(x, y): - return x / y - - -@op(torch.ops.aten.dist) -def _aten_dist(input, other, p=2): - diff = jnp.abs(jnp.subtract(input, other)) - return _aten_linalg_vector_norm(diff, ord=p) - - -@op(torch.ops.aten.bmm) -def _aten_bmm(x, y): - res = x @ y - return res - # return jnp.einsum('bnm,bmk->bnk', x, y) - - -@op(torch.ops.aten.embedding) -# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -def _aten_embedding(a, - w, - padding_idx=-1, - scale_grad_by_freq=False, - sparse=False): - return jnp.take(a, w, axis=0) - - -@op(torch.ops.aten.embedding_renorm_) -def _aten_embedding_renorm_(weight, indices, max_norm, norm_type): - # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp - unique_indices = jnp.unique(indices) - - norm = jnp.linalg.norm( - _aten_embedding(weight, unique_indices), - ord=norm_type, - axis=1, - ) - - indice_idx = jnp.where(norm > max_norm) - - scale = max_norm / (norm[indice_idx] + 1e-7) - - indices_to_update = unique_indices[indice_idx] - - weight = weight.at[indices_to_update].set(weight[indices_to_update] * - scale[:, None]) - return weight - - -#- func: _embedding_bag_forward_only( -# Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, -# int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) -@op(torch.ops.aten._embedding_bag) -@op(torch.ops.aten._embedding_bag_forward_only) -def _aten__embedding_bag(weight, - indices, - offsets=None, - scale_grad_by_freq=False, - mode=0, - sparse=False, - per_sample_weights=None, - include_last_offset=False, - padding_idx=-1): - """Jax implementation of the PyTorch _embedding_bag function. - - Args: - weight: The learnable weights of the module of shape (num_embeddings, embedding_dim). - indices: A LongTensor containing the indices to extract. - offsets: A LongTensor containing the starting offset of each bag. - scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch. - mode: 0 = "sum", 1 = "mean" or 2 = "max" - sparse: Whether the gradients with respect to weight should be a sparse tensor. - per_sample_weights: If given, each embedding vector is weighted by per_sample_weights - include_last_offset: Whether to include the last offset as a valid bag. - padding_idx: If specified, the entries at padding_idx do not contribute to the gradient. - - Returns: - A tuple of (output, offset2bag, bag_size, max_indices). - """ - embedded = _aten_embedding(weight, indices, padding_idx) - - if offsets is None: - # offsets is None only when indices.ndim > 1 - if mode == 0: # sum - output = jnp.sum(embedded, axis=1) - elif mode == 1: # mean - output = jnp.mean(embedded, axis=1) - elif mode == 2: # max - output = jnp.max(embedded, axis=1) - return output, None, None, None - - if isinstance(offsets, jax.Array): - offsets_np = np.array(offsets) - else: - offsets_np = offsets - offset2bag = np.zeros(indices.shape[0], dtype=np.int64) - bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64) - max_indices = jnp.full_like(indices, -1) - - for bag in range(offsets_np.shape[0]): - start = int(offsets_np[bag]) - - end = int(indices.shape[0] if bag + - 1 == offsets_np.shape[0] else offsets_np[bag + 1]) - bag_size[bag] = end - start - offset2bag = offset2bag.at[start:end].set(bag) - - if end - start > 0: - if mode == 0: - output_bag = jnp.sum(embedded[start:end], axis=0) - elif mode == 1: - output_bag = jnp.mean(embedded[start:end], axis=0) - elif mode == 2: - output_bag = jnp.max(embedded[start:end], axis=0) - max_indices = max_indices.at[start:end].set( - jnp.argmax(embedded[start:end], axis=0)) - - # The original code returned offset2bag, bag_size, and max_indices as numpy arrays. - # Converting them to JAX arrays for consistency. - offset2bag = jnp.array(offset2bag) - bag_size = jnp.array(bag_size) - - return output_bag, offset2bag, bag_size, max_indices - - -@op(torch.ops.aten.rsqrt) -@op_base.promote_int_input -def _aten_rsqrt(x): - return jax.lax.rsqrt(x) - - -@op(torch.ops.aten.expand) -@op(torch.ops.aten.expand_copy) -def _aten_expand(x, dims): - - def fix_dims(d, xs): - if d == -1: - return xs - return d - - shape = list(x.shape) - if len(shape) < len(dims): - shape = [ - 1, - ] * (len(dims) - len(shape)) + shape - # make sure that dims and shape is the same by - # left pad with 1s. Otherwise the zip below will - # truncate - dims = [fix_dims(p, s) for p, s in zip(dims, shape)] - return jnp.broadcast_to(x, dims) - - -@op(torch.ops.aten.dot) -def _aten_dot(x, y): - return jnp.dot(x, y) - - -@op(torch.ops.aten._to_copy) -def _aten__to_copy(self, **kwargs): - dtype = mappings.t2j_dtype(kwargs["dtype"]) - if dtype != self.dtype: - return self.astype(dtype) - return jnp.copy(self) - - -@op(torch.ops.aten.empty) -@op_base.convert_dtype(use_default_dtype=False) -def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs): - return jnp.empty(size, dtype=dtype) - - -@op(torch.ops.aten.empty_like) -@op_base.convert_dtype(use_default_dtype=False) -def _aten_empty_like(input, *, dtype=None, **kwargs): - return jnp.empty_like(input, dtype) - - -@op(torch.ops.aten.ones) -@op_base.convert_dtype() -def _ones(size: Sequence[int], dtype=None, **kwargs): - return jnp.ones(size, dtype) - - -@op(torch.ops.aten.zeros) -@op_base.convert_dtype() -def _zeros(size: Sequence[int], dtype=None, **kwargs): - return jnp.zeros(size, dtype) - - -@op(torch.ops.aten.full) -@op_base.convert_dtype() -def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) - - -@op(torch.ops.aten.empty_permuted) -@op_base.convert_dtype() -def _aten_empty_permuted(sizes, physical_layout, dtype=None, **kwargs): - # Ignore the physical layout, - # since JAX and torch tensor doesn't share the same memory. - return jnp.empty(sizes, dtype=dtype) - - -@op(torch.ops.aten.empty_strided) -@op_base.convert_dtype() -def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): - # Ignore stride, since JAX and torch tensor doesn't share the same memory. - return jnp.empty(sizes, dtype=dtype) - - -@op(torch.ops.aten.index_put) -def _aten_index_put(self, indexes, values, accumulate=False): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - if accumulate: - return self.at[indexes].add(values) - else: - return self.at[indexes].set(values) - - -@op(torch.ops.aten.index) -@op(torch.ops.aten._unsafe_index) -@op(torch.ops.aten.index.Tensor) -def _aten_index(self, indexes): - indexes = [slice(None, None, None) if i is None else i for i in indexes] - indexes = tuple(indexes) - return self[indexes] - - -@op(torch.ops.aten.split) -@op(torch.ops.aten.split_copy) -@op(torch.ops.aten.split_with_sizes) -def split_with_sizes(x, sizes, dim=0): - """Splits an array `x` into sub-arrays based on static sizes `sizes`. - - Args: - x: The input array to split. - sizes: A 1D array of integer sizes for each sub-array. - - Returns: - A list of sub-arrays. - """ - if isinstance(sizes, int): - # split equal size, round up - new_sizes = [sizes] * (-(-x.shape[dim] // sizes)) - sizes = new_sizes - rank = x.ndim - splits = np.cumsum(sizes) # Cumulative sum for split points - - def make_range(rank, dim, start, end): - res = [slice(None, None, None)] * rank - res[dim] = slice(start, end) - return tuple(res) - - return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) - ] - - -@op(torch.ops.aten.permute) -@op(torch.ops.aten.permute_copy) -def permute(t, dims): - # TODO: return a View instead - return jnp.transpose(t, dims) - - -@op(torch.ops.aten.unsqueeze) -@op(torch.ops.aten.unsqueeze_copy) -def _aten_unsqueeze(self, dim): - if dim < 0: - dim += self.ndim + 1 - return jnp.expand_dims(self, dim) - - -@op(torch.ops.aten.ne) -def _aten_ne(x, y): - return jnp.not_equal(x, y) - - -# Create indices along a specific axis -# -# For example -# x = jnp.zeros((3,4)) -# -# _indices_along_axis(x, axis=0) -# >> [[0], [1], [2]] shape (3, 1) -# -# _indices_along_axis(x, axis=1) -# >> [[0, 1, 2, 3]] shape (1, 4) -def _indices_along_axis(x, axis): - return jnp.expand_dims( - jnp.arange(x.shape[axis]), - axis=[d for d in range(len(x.shape)) if d != axis]) - - -def _broadcast_indices(indices, shape): - return jnp.broadcast_to(indices, shape) - - -@op(torch.ops.aten.cummax) -def _aten_cummax(x, dim): - if not x.shape: - return x, jnp.zeros_like(x, dtype=jnp.int64) - - axis = dim - - indice_along_axis = _indices_along_axis(x, axis) - indices = _broadcast_indices(indice_along_axis, x.shape) - - def cummax_reduce_func(carry, elem): - v1, v2 = carry['val'], elem['val'] - i1, i2 = carry['idx'], elem['idx'] - - v = jnp.maximum(v1, v2) - i = jnp.where(v1 > v2, i1, i2) - return {'val': v, 'idx': i} - - res = jax.lax.associative_scan( - cummax_reduce_func, { - 'val': x, - 'idx': indices - }, axis=axis) - return res['val'], res['idx'] - - -@op(torch.ops.aten.cummin) -def _aten_cummin(x, dim): - if not x.shape: - return x, jnp.zeros_like(x, dtype=jnp.int64) - - axis = dim - - indice_along_axis = _indices_along_axis(x, axis) - indices = _broadcast_indices(indice_along_axis, x.shape) - - def cummin_reduce_func(carry, elem): - v1, v2 = carry['val'], elem['val'] - i1, i2 = carry['idx'], elem['idx'] - - v = jnp.minimum(v1, v2) - i = jnp.where(v1 < v2, i1, i2) - return {'val': v, 'idx': i} - - res = jax.lax.associative_scan( - cummin_reduce_func, { - 'val': x, - 'idx': indices - }, axis=axis) - return res['val'], res['idx'] - - -@op(torch.ops.aten.cumsum) -def _aten_cumsum(x, y, dtype=None): - if dtype: - dtype = mappings.t2j_dtype(dtype) - if not x.shape: - return x - res = jnp.cumsum(x, y, dtype) - return res - - -@op(torch.ops.aten.cumprod) -def _aten_cumprod(input, dim, dtype=None, out=None): - if dtype: - dtype = mappings.t2j_dtype(dtype) - if len(input.shape) > 0: - res = jnp.cumprod(input, axis=dim, dtype=dtype) - elif dtype: - res = input.astype(dtype) - else: - res = input - return res - - -@op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm(input, - normalized_shape, - weight=None, - bias=None, - eps=1e-5): - """Implements layer normalization in Jax as defined by `aten::native_layer_norm`. - - Args: - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - output: The normalized tensor. - mean: The calculated mean tensor. - std: The calculated standard deviation tensor. - """ - if isinstance(normalized_shape, int): - normalized_shape = [normalized_shape] - axis = [len(input.shape) - i - 1 for i in range(len(normalized_shape))] - - # Calculate mean and standard deviation - mean = jnp.mean(input, axis=axis, keepdims=True) - var = jnp.var(input, axis=axis, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) - - # Normalize the input - norm_x = (input - mean) * rstd - - # Apply affine transformation (if provided) - if weight is not None: - norm_x *= weight - if bias is not None: - norm_x += bias - return norm_x, mean, rstd - - -@op(torch.ops.aten.matmul) -def _aten_matmul(x, y): - return x @ y - - -# - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor -@op(torch.ops.aten.addmm) -@op(torch.ops.aten.addmv) -def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) - return self - - -@op(torch.ops.aten.sparse_sampled_addmm) -def _aten_sparse_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0): - alpha = jnp.array(alpha).astype(mat1.dtype) - beta = jnp.array(beta).astype(mat1.dtype) - self *= beta - self += alpha * jnp.matmul(mat1, mat2) * (self != 0) - return self - - -@op(torch.ops.aten.addbmm.default) -def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): - alpha = jnp.array(alpha).astype(batch1.dtype) - beta = jnp.array(beta).astype(batch1.dtype) - mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) - return jax.lax.cond(beta == 0, lambda: alpha * mm, - lambda: beta * input + alpha * mm) - - -@op(torch.ops.aten.gelu) -def _aten_gelu(self, *, approximate="none"): - approx = approximate == "tanh" - return jax.nn.gelu(self, approx) - - -@op(torch.ops.aten.squeeze) -@op(torch.ops.aten.squeeze_copy) -def _aten_squeeze_dim(self, dim=None): - if self.ndim == 0: - return self - if dim is not None: - if isinstance(dim, int): - if self.shape[dim] != 1: - return self - if dim < 0: - dim += self.ndim - else: - # NOTE: torch leaves the dims that is not 1 unchanged, - # but jax raises error. - dim = [ - i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1 - ] - - return jnp.squeeze(self, dim) - - -@op(torch.ops.aten.bucketize) -def _aten_bucketize(input, - boundaries, - *, - out_int32=False, - right=False, - out=None): - return_type = jnp.int32 if out_int32 else jnp.int64 - return jnp.digitize(input, boundaries, right=not right).astype(return_type) - - -@op(torch.ops.aten.conv2d) -def _aten_conv2d( - input, - weight, - bias=None, - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - groups=1, -): - return _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed=False, - output_padding=1, - groups=groups) - - -@op(torch.ops.aten.convolution) -def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, -): - num_shape_dim = weight.ndim - 1 - batch_dims = input.shape[:-num_shape_dim] - - input = input.reshape((-1, *input.shape[-num_shape_dim:])) - - def make_padding(padding, num_spatial_dims): - # Expand single padding to pairs expected by jax - if len(padding) == 1 and len(padding) < num_spatial_dims: - padding *= num_spatial_dims - if transposed: - # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html - pad_out = [] - for i in range(num_spatial_dims): - front = dilation[i] * (weight.shape[i + 2] - 1) - padding[i] - back = front + output_padding[i] - pad_out.append((front, back)) - return pad_out - else: - return ((p, p) for p in padding) - - def create_default_conv_dimension_numbers(num_spatial_dims): - # Ref: https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc#L4211 - # (batch dimension, feature dimension, spatial dimensions...) - lhs_spec = [0, 1] - # (out feature dimension, in feature dimension, spatial dimensions...) - # swapped for transposed convolution - rhs_spec = [1, 0] if transposed else [0, 1] - # (batch dimension, feature dimension, spatial dimensions...) - out_spec = [0, 1] - for i in range(0, num_spatial_dims): - lhs_spec.append(i + 2) - rhs_spec.append(i + 2) - out_spec.append(i + 2) - return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec))) - - if transposed: - rhs = jnp.flip(weight, range(2, 1 + num_shape_dim)) - if groups != 1: - # reshape filters for tranposed depthwise convolution - assert rhs.shape[0] % groups == 0 - rhs_shape = [rhs.shape[0] // groups, rhs.shape[1] * groups] - rhs_shape.extend(rhs.shape[2:]) - rhs = jnp.reshape(rhs, rhs_shape) - res = jax.lax.conv_general_dilated( - input, - rhs, - (1,) * len(stride), - make_padding(padding, len(stride)), - lhs_dilation=stride, - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, - ) - else: - res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding, len(stride)), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, - ) - - if bias is not None: - # TODO(qihqi): bias always on channel? - if len(bias.shape) == 1: - shape = [1] * len(res.shape) - shape[1] = bias.shape[0] - bias = bias.reshape(tuple(shape)) - res = res + bias - - res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:])) - return res - - -# _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -@op(torch.ops.aten._native_batch_norm_legit.default) -def _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps): - """JAX implementation of batch normalization with optional parameters. - Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. - - Args: - input (DeviceArray): Input data (N, C, H, W). - running_mean ([DeviceArray]): Running mean of input (C,). - running_var ([DeviceArray]): Running variance of input (C,). - weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None. - bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None. - training (bool): If True, use batch statistics for normalization. - If False, use running statistics. - momentum (float): Momentum factor for updating running statistics. - eps (float): Small constant for numerical stability. - - Returns: - DeviceArray: Normalized output - DeviceArray: Batch mean (C,) or empty if training is False - DeviceArray: Reversed batch variance (C,) or empty if training is False - """ - reduction_dims = [0] + list(range(2, input.ndim)) - reshape_dims = [1, -1] + [1] * (input.ndim - 2) - if training: - # Calculate batch mean and variance - mean = jnp.mean(input, axis=reduction_dims, keepdims=True) - saved_mean = jnp.squeeze(mean, reduction_dims) - var = jnp.var(input, axis=reduction_dims) - rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps) - # Update running statistics using momentum - running_mean = (1 - momentum) * running_mean + momentum * saved_mean - running_var = (1 - momentum) * running_var + momentum * var - saved_rstd = jnp.squeeze(rstd, reduction_dims) - else: - rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) - saved_mean = jnp.array( - [], dtype=input.dtype - ) # No need to calculate batch statistics in inference mode - saved_rstd = jnp.array([], dtype=input.dtype) - - # Normalize - if training: - # use batch statistics if training - x_hat = (input - mean) * rstd - else: - # Use running statistics in inference mode - x_hat = (input - running_mean.reshape(reshape_dims)) * rstd - - # Scale and shift - if weight is not None: - x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting - if bias is not None: - x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting - - return x_hat, saved_mean, saved_rstd - - -@op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps): - return _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, False, momentum, eps) - - -@op(torch.ops.aten.relu) -def _aten_relu(self): - return jax.nn.relu(self) - - -@op(torch.ops.aten.cat) -def _aten_cat(tensors, dims=0): - # handle empty tensors as a special case. - # torch.cat will ignore the empty tensor, while jnp.concatenate - # will error if the dims > 0. - filtered_tensors = [ - t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) - ] - if filtered_tensors: - return jnp.concatenate(filtered_tensors, dims) - return tensors[0] - - -def _ceil_mode_padding( - padding: list[int], - input_shape: list[int], - kernel_size: list[int], - stride: list[int], - dilation: list[int], - ceil_mode: bool, -): - """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. - - Additional high padding could be required when ceil mode is set. - """ - ceil_mode_padding = [] - for i in range(len(padding)): - left_padding = padding[i] - right_padding = left_padding - - input_size = input_shape[2 + i] - output_size_rem = (input_size + 2 * left_padding - - (kernel_size[i] - 1) * dilation[i] - 1) % stride[i] - if ceil_mode and output_size_rem != 0: - extra_padding = stride[i] - output_size_rem - new_output_size = (input_size + left_padding + right_padding + - extra_padding - (kernel_size[i] - 1) * dilation[i] - - 1 + stride[i] - 1) // stride[i] + 1 - # Ensure that the last pooling starts inside the image. - size_to_compare = input_size + left_padding - - if (new_output_size - 1) * stride[i] < size_to_compare: - right_padding += extra_padding - - ceil_mode_padding.append((left_padding, right_padding)) - return ceil_mode_padding - - -@op(torch.ops.aten.max_pool2d_with_indices) -@op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices(inputs, - kernel_size, - strides=None, - padding=0, - dilation=1, - ceil_mode=False): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - # Default stride is kernel_size - strides = tuple(strides) if strides else kernel_size - if isinstance(padding, int): - padding = [padding for _ in range(len(kernel_size))] - if isinstance(dilation, int): - dilation = tuple(dilation for _ in range(len(kernel_size))) - elif isinstance(dilation, list): - dilation = tuple(dilation) - - input_shape = inputs.shape - if num_batch_dims == 0: - input_shape = [1, *input_shape] - padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides, - dilation, ceil_mode) - - assert len(kernel_size) == len( - strides), f"len({kernel_size=}) must equal len({strides=})" - assert len(kernel_size) == len( - dilation), f"len({kernel_size=}) must equal len({dilation=})" - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + kernel_size - dilation = (1,) * (1 + num_batch_dims) + dilation - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - dilation = (1,) + dilation - is_single_input = True - - assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(kernel_size), ( - f"padding {padding} must specify pads for same number of dims as " - f"kernel_size {kernel_size}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" - padding = ((0, 0), (0, 0)) + padding - - indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):])) - indices = indices.reshape(inputs.shape[-len(kernel_size):]) - indices = jnp.broadcast_to(indices, inputs.shape) - - def reduce_fn(a, b): - ai, av = a - bi, bv = b - which = av >= bv # torch breaks ties in favor of later indices - return jnp.where(which, ai, bi), jnp.where(which, av, bv) - - init_val = -jnp.inf - if inputs.dtype in (jnp.int32, jnp.int64): - init_val = -(1 << 31) - init_val = jnp.array(init_val).astype(inputs.dtype) - - # Separate maxpool result and indices into two reduce_window ops. Since - # the indices tensor is usually unused in inference, separating the two - # can help DCE computations for argmax. - y = jax.lax.reduce_window( - inputs, - init_val, - jax.lax.max, - dims, - strides, - padding, - window_dilation=dilation) - indices, _ = jax.lax.reduce_window( - (indices, inputs), - (0, init_val), - reduce_fn, - dims, - strides, - padding, - window_dilation=dilation, - ) - if is_single_input: - indices = jnp.squeeze(indices, axis=0) - y = jnp.squeeze(y, axis=0) - - return y, indices - - -# Aten ops registered under the `xla` library. -try: - - @op(torch.ops.xla.max_pool2d_forward) - def _xla_max_pool2d_forward(*args, **kwargs): - return _aten_max_pool2d_with_indices(*args, **kwargs)[0] - - @op(torch.ops.xla.aot_mark_sharding) - def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str): - from jax.sharding import PartitionSpec as P, NamedSharding - import ast - import torch_xla.distributed.spmd as xs - pmesh = xs.Mesh.from_str(mesh) - assert pmesh is not None - partition_spec_eval = ast.literal_eval(partition_spec) - jmesh = pmesh.get_jax_mesh() - return jax.lax.with_sharding_constraint( - t, NamedSharding(jmesh, P(*partition_spec_eval))) - - @op(torch.ops.xla.einsum_linear_forward) - def _xla_einsum_linear_forward(input, weight, bias): - with jax.named_scope('einsum_linear_forward'): - product = jax.numpy.einsum('...n,mn->...m', input, weight) - if bias is not None: - return product + bias - return product - -except AttributeError: - pass - -# TODO add more ops - - -@op(torch.ops.aten.min) -def _aten_min(x, dim=None, keepdim=False): - if dim is not None: - return _with_reduction_scalar(jnp.min, x, dim, - keepdim), _with_reduction_scalar( - jnp.argmin, x, dim, - keepdim).astype(jnp.int64) - else: - return _with_reduction_scalar(jnp.min, x, dim, keepdim) - - -@op(torch.ops.aten.mode) -def _aten_mode(input, dim=-1, keepdim=False, *, out=None): - if input.ndim == 0: # single number - return input, jnp.array(0) - dim = (input.ndim + - dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim - # keepdims must be True for accurate broadcasting - mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True) - mode_broadcast = jnp.broadcast_to(mode, input.shape) - if not keepdim: - mode = mode.squeeze(axis=dim) - indices = jnp.argmax( - jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim) - return mode, indices - - -@op(torch.ops.aten.amin) -def _aten_amin(x, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amin, x, dim, keepdim) - - -@op(torch.ops.aten.argmin) -def _aten_argmin(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.argmin, self, dim, keepdim) - - -@op(torch.ops.aten.sin) -@op_base.promote_int_input -def _aten_sin(x): - return jnp.sin(x) - - -@op(torch.ops.aten.sym_size) -def _aten_sym_size(x, dim): - return x.shape[dim] - - -@op(torch.ops.aten.var.correction) -@op(torch.ops.prims.var) -def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): - return jnp.var(x, axis=dim, ddof=correction, keepdims=keepdim) - - -@op(torch.ops.prims.broadcast_in_dim) -def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): - return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions) - - -# aten.native_group_norm -- should use decomp table -# func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) - - -@op(torch.ops.aten.native_group_norm) -def _aten_native_group_norm(input, weight, bias, N, C, HxW, group, eps=1e-5): - """Group Normalization implementation in JAX. - - Args: - input: Input tensor. Expected shape (batch_size, channels, ... spatial dims - ...) - weight: Optional scaling (gamma) parameter. Shape (channels,) - bias: Optional shifting (beta) parameter. Shape (channels,) - N: Batch size. - C: Number of channels. - HxW: Product of spatial dimensions (number of elements per channel after - flattening). - group: Number of groups for Group Normalization. - eps: Small value added for numerical stability. - - Returns: - A tuple of (normalized_output, mean, rstd) - """ - - input_shape = input.shape - - if 0 in input_shape: - return input, input, input - - # Reshape for group-wise normalization - reshaped_input = jnp.reshape(input, (1, N * group, -1)) - - # **Core Group Normalization** - def group_norm_body(x): # Function to apply within each group - mean = jnp.mean(x, axis=-1, keepdims=True) - var = jnp.var(x, axis=-1, keepdims=True) - rstd = jax.lax.rsqrt(var + eps) # Reciprocal of std with epsilon - normalized = (x - mean) * rstd - return normalized, mean, rstd - - normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, - reshaped_input) - - # Reshape back to original input shape - output = jnp.reshape(normalized, input_shape) - - # **Affine transformation** - affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape for broadcasting - if weight is not None and bias is not None: - output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) - elif weight is not None: - output = output * weight.reshape(affine_shape) - elif bias is not None: - output = output + bias.reshape(affine_shape) - - # Reshape mean and rstd - mean = jnp.reshape(group_mean, (N, group)) - rstd = jnp.reshape(group_rstd, (N, group)) - - return output, mean, rstd - - -@op(torch.ops.aten.linalg_vector_norm) -def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): - """Calculates the vector norm along specified dimensions. - - Args: - self: The input tensor. - ord: The order of the norm. Can be a float or 'inf', '-inf', 'fro'. - Default is 2 (Euclidean norm). - dim: Dimensions along which to calculate the norm. If None, the norm is - calculated over all dimensions. - keepdim: Whether to keep the reduced dimensions. - dtype: Optional data type for the output. - - Returns: - The tensor containing the calculated vector norms. - """ - - if ord not in {2, float("inf"), float("-inf"), "fro" - } and not isinstance(ord, (int, float)): - raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'.") - - # Special cases (for efficiency and clarity) - if ord == 0: - if self.shape == (): - # float sets it to float64. set it back to input type - result = jnp.astype(jnp.array(float(self != 0)), self.dtype) - else: - result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, - keepdim) - - elif ord == 2: # Euclidean norm - result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, - jnp.abs(self)**2, dim, keepdim)) - - elif ord == float("inf"): - result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim) - - elif ord == float("-inf"): - result = _with_reduction_scalar(jnp.min, jnp.abs(self), dim, keepdim) - - elif ord == "fro": # Frobenius norm - result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, - jnp.abs(self)**2, dim, keepdim)) - - else: # General case (e.g., ord = 1, ord = 3) - result = _with_reduction_scalar(jnp.sum, - jnp.abs(self)**ord, dim, - keepdim)**(1.0 / ord) - - # (Optional) dtype conversion - if dtype is not None: - result = jnp.astype(result, self.dtype) - - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if result.dtype == jax.numpy.int64: - result = result.astype(new_dtype) - return result - - -# aten.reflection_pad1d -@op(torch.ops.aten.reflection_pad1d) -def _aten_reflection_pad1d(input, padding): - rank = len(input.shape) - pad_size = [(0, 0)] * rank - pad_size[-1] = padding - return jnp.pad(input, pad_size, mode="reflect") - - -# aten.alias -@op(torch.ops.aten.alias) -def _aten_alias(self, *args): - return self - - -# aten.sinh -@op(torch.ops.aten.sinh) -@op_base.promote_int_input -def _aten_sinh(self): - return jnp.sinh(self) - - -# aten.native_layer_norm_backward -@op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward(grad_out, - input, - normalized_shape, - weight, - bias, - eps=1e-5): - """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. - - Args: - grad_out: The gradient of the output tensor. - input: The input tensor. - normalized_shape: A list of integer dimensions to be normalized over. - weight: Optional weight tensor for the affine transformation. - bias: Optional bias tensor for the affine transformation. - eps: A small epsilon value for numerical stability. - - Returns: - A tuple of (grad_input, grad_weight, grad_bias). - """ - return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, - weight, bias, eps) - - -# aten.reflection_pad3d_backward -# aten.reflection_pad2d - - -# aten.atanh -@op(torch.ops.aten.atanh) -@op_base.promote_int_input -def _aten_atanh(self): - res = jnp.arctanh(self) - return res - - -# aten.bincount -@op(torch.ops.aten.bincount) -def _aten_bincount(input, weights=None, minlength=0): - return jnp.bincount(input, weights, minlength) - - -# aten.bitwise_not -@op(torch.ops.aten.bitwise_not) -def _aten_bitwise_not(self): - return ~self - - -# aten.bitwise_left_shift -@op(torch.ops.aten.__lshift__) -@op(torch.ops.aten.bitwise_left_shift) -def _aten_bitwise_left_shift(input, other): - return jnp.left_shift(input, other) - - -# aten.bitwise_right_shift -@op(torch.ops.aten.__rshift__) -@op(torch.ops.aten.bitwise_right_shift) -def _aten_bitwise_right_shift(input, other): - return jnp.right_shift(input, other) - - -# aten.embedding_dense_backward - - -# aten.sum -@op(torch.ops.aten.sum) -def _aten_sum(self, dim=None, keepdim=False, dtype=None): - if not dim: - dim = None - return _with_reduction_scalar(jnp.sum, self, dim, keepdim) - - -# aten.sqrt -@op(torch.ops.aten.sqrt) -@op_base.promote_int_input -def _aten_sqrt(self): - return jnp.sqrt(self) - - -@op(torch.ops.aten.tan) -@op_base.promote_int_input -def _aten_tanh(self): - res = jnp.tan(self) - return res - - -# aten.tanh -@op(torch.ops.aten.tanh) -@op_base.promote_int_input -def _aten_tanh(self): - res = jnp.tanh(self) - return res - - -# aten.ceil -@op(torch.ops.aten.ceil) -def _aten_ceil(self): - return jnp.ceil(self).astype(self) - - -# aten.asin -@op(torch.ops.aten.asin) -@op_base.promote_int_input -def _aten_asin(self): - res = jnp.arcsin(self) - return res - - -# aten.minimum -@op(torch.ops.aten.minimum) -def _aten_minimum(self, other): - return jnp.minimum(self, other) - - -# aten.max_pool2d_backward - - -def _scatter_index(dim, index): - """Returns a tuple of indexes; - - The first is to select in input (to modify), - the second is to select from the values. - """ - index_shape = list(index.shape) - input_indexes = [] - source_indexes = [] - if dim < 0: - dim += len(index_shape) - for i in range(len(index_shape)): - source_indexes.append(slice(0, index_shape[i])) - if i == dim: - input_indexes.append(index) - else: - target_shape = [1] * len(index_shape) - target_shape[i] = index_shape[i] - input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) - return tuple(input_indexes), tuple(source_indexes) - - -# aten.scatter_add -@op(torch.ops.aten.scatter_add) -def _aten_scatter_add(input, dim, index, src): - """JAX implementation of scatter, mimicking torch.scatter behavior""" - - input_indexes, source_indexes = _scatter_index(dim, index) - return input.at[input_indexes].add(src[source_indexes]) - - -# aten.masked_scatter -@op(torch.ops.aten.masked_scatter) -def _aten_masked_scatter(self, mask, source): - - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) - - if self.shape != broadcast_shape: - self = jnp.broadcast_to(self, broadcast_shape) - elif mask.shape != broadcast_shape: - mask = jnp.broadcast_to(mask, broadcast_shape) - - self_flat = self.flatten() - mask_flat = mask.flatten() - source_flat = source.flatten() - - true_indices = jnp.where(mask_flat)[0] - self_flat = self_flat.at[true_indices].set(source_flat[:len(true_indices)]) - final_arr = self_flat.reshape(self.shape) - - return final_arr - - -@op(torch.ops.aten.masked_select) -def _aten_masked_select(self, mask, *args, **kwargs): - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) - - if self.shape != broadcast_shape: - self = jnp.broadcast_to(self, broadcast_shape) - if mask.shape != broadcast_shape: - mask = jnp.broadcast_to(mask, broadcast_shape) - - self_flat = self.flatten() - mask_flat = mask.flatten() - true_indices = jnp.where(mask_flat)[0] - - return self_flat[true_indices] - - -# aten.logical_not - - -# aten.sign -@op(torch.ops.aten.sign) -def _aten_sign(x): - return jnp.sign(x) - - -# aten.signbit -@op(torch.ops.aten.signbit) -def _aten_signbit(x): - return jnp.signbit(x) - - -# aten.sigmoid -@op(torch.ops.aten.sigmoid) -@op_base.promote_int_input -def _aten_sigmoid(x): - return jax.nn.sigmoid(x) - - -# implement aten.asinh in jax -@op(torch.ops.aten.asinh) -@op_base.promote_int_input -def _aten_asinh(self): - res = jnp.arcsinh(self) - return res - - -# aten.atan -@op(torch.ops.aten.atan) -@op_base.promote_int_input -def _aten_atan(self): - res = jnp.arctan(self) - return res - - -@op(torch.ops.aten.scatter_reduce) -@op(torch.ops.aten.scatter) -def _aten_scatter_reduce(input, - dim, - index, - src, - reduce=None, - *, - include_self=True): - if not isinstance(src, jnp.ndarray): - src = jnp.array(src, dtype=input.dtype) - input_indexes, source_indexes = _scatter_index(dim, index) - # "Zero out" target elements when not included - if not include_self: - if reduce in ["sum", "mean"]: - base_input = jnp.zeros_like(src) - elif reduce == "prod": - base_input = jnp.ones_like(src) - elif reduce == "amax": - base_input = jnp.full_like(src, -jnp.inf) - else: # amin - base_input = jnp.full_like(src, jnp.inf) - input = input.at[input_indexes].set(base_input[source_indexes]) - - if reduce == "sum" or reduce == "add": - return input.at[input_indexes].add(src[source_indexes]) - elif reduce == "prod" or reduce == "multiply": - return input.at[input_indexes].multiply(src[source_indexes]) - elif reduce == "mean": - if include_self: - count = jnp.ones_like(input) - else: - count = jnp.zeros_like(input) - count = count.at[input_indexes].add(jnp.ones_like(src)[source_indexes]) - count = jnp.clip(count, min=1) - mean = input.at[input_indexes].add(src[source_indexes]) - if _is_int(input): - return mean // count - return mean / count - elif reduce == "amax": - return input.at[input_indexes].max(src[source_indexes]) - elif reduce == "amin": - return input.at[input_indexes].min(src[source_indexes]) - else: - return input.at[input_indexes].set(src[source_indexes]) - - -# aten.acos -@op(torch.ops.aten.acos) -@op_base.promote_int_input -def _aten_acos(self): - return jnp.arccos(self) - - -# aten.sym_storage_offset -# aten.native_layer_norm_backward -# aten.max_pool3d_with_indices - - -# aten.gt -@op(torch.ops.aten.gt) -def _aten_gt(self, other): - return self > other - - -# aten.sym_stride -# aten.lt -@op(torch.ops.aten.lt) -def _aten_lt(self, other): - return self < other - - -def pool(inputs, init, reduce_fn, window_shape, strides, padding): - """Helper function to define pooling functions. - - Pooling functions are implemented using the ReduceWindow XLA op. - NOTE: Be aware that pooling is not generally differentiable. - That means providing a reduce_fn that is differentiable does not imply that - pool is differentiable. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - init: the initial value for the reduction - reduce_fn: a reduce function of the form ``(T, T) -> T``. - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of ``n`` integers, representing the inter-window - strides (default: ``(1, ..., 1)``). - padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence - of ``n`` ``(low, high)`` integer pairs that give the padding to apply before - and after each spatial dimension. - Returns: - The output of the reduction for each window slice. - """ - num_batch_dims = inputs.ndim - (len(window_shape) + 1) - strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f"len({window_shape}) must equal len({strides})" - strides = (1,) * (1 + num_batch_dims) + strides - dims = (1,) * (1 + num_batch_dims) + window_shape - - is_single_input = False - if num_batch_dims == 0: - # add singleton batch dimension because lax.reduce_window always - # needs a batch dimension. - inputs = inputs[None] - strides = (1,) + strides - dims = (1,) + dims - is_single_input = True - - assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})" - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" - padding = ((0, 0), (0, 0)) + padding - y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) - if is_single_input: - y = jnp.squeeze(y, axis=0) - return y - - -@op(torch.ops.aten._adaptive_avg_pool2d) -@op(torch.ops.aten._adaptive_avg_pool3d) -def adaptive_avg_pool2or3d(input: jnp.ndarray, - output_size: Tuple[int, int]) -> jnp.ndarray: - """ - Applies a 2/3D adaptive average pooling over an input signal composed of several input planes. - - See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. - - Args: - input: input tensor - output_size: the target output size (single integer or double-integer tuple) - - Context: - https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401 - """ - shape = input.shape - ndim = len(shape) - out_dim = len(output_size) - num_spatial_dim = ndim - out_dim - - # Preconditions - - assert ndim in ( - out_dim + 1, out_dim + 2 - ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}" - for d in input.shape[-2:]: - assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \ - f"non-batch dimensions, but input has shape {tuple(shape)}." - - # Optimisation (we should also do this in the kernel implementation) - if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])): - stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size)) - kernel = tuple(i - (o - 1) * s - for i, o, s in zip(shape[-out_dim:], output_size, stride)) - return _aten_avg_pool( - input, - kernel, - strides=stride, - ) - - def start_index(a, b, c): - return (a * c) // b - - def end_index(a, b, c): - return ((a + 1) * c + b - 1) // b - - def compute_idx(in_size, out_size): - orange = jnp.arange(out_size, dtype=jnp.int64) - i0 = start_index(orange, out_size, in_size) - # Let length = end_index - start_index, i.e. the length of the pooling kernels - # length.max() can be computed analytically as follows: - maxlength = in_size // out_size + 1 - in_size_mod = in_size % out_size - # adaptive = True iff there are kernels with different lengths - adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) - if adaptive: - maxlength += 1 - elif in_size_mod == 0: - maxlength -= 1 - - range_max = jnp.arange(maxlength, dtype=jnp.int64) - idx = i0[:, None] + range_max - if adaptive: - # Need to clamp to avoid accessing out-of-bounds memory - idx = jnp.minimum(idx, in_size - 1) - - # Compute the length - i1 = end_index(orange, out_size, in_size) - length = i1 - i0 - else: - length = maxlength - return idx, length, range_max, adaptive - - idx, length, range_max, adaptive = [[None] * out_dim for _ in range(4)] - # length is not None if it's constant, otherwise we'll need to compute it - for i, (s, o) in enumerate(zip(shape[-out_dim:], output_size)): - idx[i], length[i], range_max[i], adaptive[i] = compute_idx(s, o) - - def _unsqueeze_to_dim(x, dim): - ndim = len(x.shape) - return jax.lax.expand_dims(x, tuple(range(ndim, dim))) - - if out_dim == 2: - # NOTE: unsqueeze to insert extra 1 in ranks; so they - # would broadcast - vals = input[..., _unsqueeze_to_dim(idx[0], 4), idx[1]] - reduce_axis = (-3, -1) - else: - assert out_dim == 3 - vals = input[..., - _unsqueeze_to_dim(idx[0], 6), - _unsqueeze_to_dim(idx[1], 4), idx[2]] - reduce_axis = (-5, -3, -1) - - # Shortcut for the simpler case - if not any(adaptive): - return jnp.mean(vals, axis=reduce_axis) - - def maybe_mask(vals, length, range_max, adaptive, dim): - if isinstance(length, int): - return vals, length - else: - # zero-out the things we didn't really want to select - assert dim < 0 - # hack - mask = range_max >= length[:, None] - if dim == -2: - mask = _unsqueeze_to_dim(mask, 4) - elif dim == -3: - mask = _unsqueeze_to_dim(mask, 6) - vals = jnp.where(mask, 0.0, vals) - # Compute the length of each window - length = _unsqueeze_to_dim(length, -dim) - return vals, length - - for i in range(len(length)): - vals, length[i] = maybe_mask( - vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim)) - - # We unroll the sum as we assume that the kernels are going to be small - ret = jnp.sum(vals, axis=reduce_axis) - # NOTE: math.prod because we want to expand it to length[0] * length[1] * ... - # this is multiplication with broadcasting, not regular pointwise product - return ret / math.prod(length) - - -@op(torch.ops.aten.avg_pool1d) -@op(torch.ops.aten.avg_pool2d) -@op(torch.ops.aten.avg_pool3d) -def _aten_avg_pool( - inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, -): - num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 - kernel_size = tuple(kernel_size) - strides = tuple(strides) if strides else kernel_size - if isinstance(padding, list) and len(padding) == 1: - padding = padding[0] - if isinstance(padding, int): - padding = [padding for _ in range(len(kernel_size))] - - input_shape = inputs.shape - if num_batch_dims == 0: - input_shape = [1, *input_shape] - padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides, - [1] * len(kernel_size), ceil_mode) - - y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) - if divisor_override is not None: - y = y / jnp.array(divisor_override, y.dtype) - elif count_include_pad: - div_shape = list(y.shape) - div_by = jnp.ones(div_shape, y.dtype) * np.prod(kernel_size) - unequal_paddings = map(lambda pad: pad[0] != pad[1], padding) - unequal_padding_indices = np.where(list(unequal_paddings))[0] - if len(unequal_padding_indices) > 0: - # indices to update kernel size - offset = len(div_shape) - len(padding) - skip_indices = list(map(lambda x: x + offset, unequal_padding_indices)) - indices = _generate_indices(div_shape, skip_dim_indices=skip_indices) - # updated kernel size accounting for maximum padding - new_kernel_size = list(kernel_size) - for j in unequal_padding_indices: - new_kernel_size[j] = kernel_size[j] - padding[j][1] + padding[j][0] - - for idx in indices: - for j in unequal_padding_indices: - idx[j + offset] = -1 - div_by = div_by.at[tuple(idx)].set(np.prod(new_kernel_size)) - - y = y / div_by - else: - div_shape = list(inputs.shape) - div_shape[num_batch_dims] = 1 - div_shape = tuple(div_shape) - if len(div_shape) - 2 == len(kernel_size): - div_shape = (1,) + div_shape[1:] - y = y / pool( - jnp.ones(div_shape, y.dtype), - jnp.array(0.0, y.dtype), - jax.lax.add, - kernel_size, - strides, - padding, - ) - return y.astype(inputs.dtype) - - -# helper function to generate all indices to iterate through ndarray -def _generate_indices(dims, skip_dim_indices=[]): - res = [] - - def _helper(curr_dim_idx, sofar): - if curr_dim_idx in skip_dim_indices: - _helper(curr_dim_idx + 1, sofar[:]) - return - if curr_dim_idx >= len(dims): - res.append(sofar) - return - for i in range(dims[curr_dim_idx]): - sofar[curr_dim_idx] = i - _helper(curr_dim_idx + 1, sofar[:]) - - _helper(0, [0 for _ in dims]) - return res - - -# aten.sym_numel -# aten.reciprocal -@op(torch.ops.aten.reciprocal) -def _aten_reciprocal(a): - if _is_int(a): - return (1 / a).astype(jnp.dtype('float32')) - return 1 / a - - -# aten.select_scatter -@op(torch.ops.aten.select_scatter) -def _aten_select_scatter(input, src, dim, index): - input_indexes = [] - if dim < 0: - dim += len(input.shape) - for x in range(len(input.shape)): - if x == dim: - input_indexes.append(index) - else: - input_indexes.append(slice(None, None, None)) - return input.at[tuple(input_indexes)].set(src) - - -@op(torch.ops.aten.scatter.src) -def _aten_scatter_src(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src[source_indexes]) - - -@op(torch.ops.aten.scatter.value) -def _aten_scatter(input, dim, index, src, reduce=None): - input_index, source_indexes = _scatter_index(dim, index) - return input.at[input_index].set(src) - - -# aten.acosh -@op(torch.ops.aten.acosh) -@op_base.promote_int_input -def _aten_acosh(self): - return jnp.arccosh(self) - - -# aten.avg_pool2d_backward -# aten.col2im -# aten.avg_pool3d -# aten.round -@op(torch.ops.aten.round) -def _aten_round(input, decimals=0): - return jnp.round(input, decimals) - - -# aten.max -@op(torch.ops.aten.max) -def _aten_max(self, dim=None, keepdim=False): - if dim is not None: - return _with_reduction_scalar(jnp.max, self, dim, - keepdim), _with_reduction_scalar( - jnp.argmax, self, dim, - keepdim).astype(jnp.int64) - else: - return _with_reduction_scalar(jnp.max, self, dim, keepdim) - - -# aten.maximum -@op(torch.ops.aten.maximum) -def _aten_maximum(self, other): - return jnp.maximum(self, other) - - -# aten.abs -@op(torch.ops.aten.abs) -def _aten_abs(self): - return jnp.abs(self) - - -# generate aten.amax only -@op(torch.ops.aten.amax) -def _aten_amax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.amax, self, dim, keepdim) - - -def _with_reduction_scalar(jax_func, self, dim, keepdim): - expanded = False - if self.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - self = jnp.expand_dims(self, 0) - res = jax_func(self, axis=dim, keepdims=keepdim) - if expanded: - res = res.squeeze() - return res - - -# aten.any -@op(torch.ops.aten.any) -def _aten_any(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.any, self, dim, keepdim) - - -# aten.arange -@op(torch.ops.aten.arange.start_step) -@op(torch.ops.aten.arange.start) -@op(torch.ops.aten.arange.default) -@op_base.convert_dtype(use_default_dtype=False) -def _aten_arange( - start, - end=None, - step=None, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False, -): - return jnp.arange( - op_base.maybe_convert_constant_dtype(start, dtype), - op_base.maybe_convert_constant_dtype(end, dtype), - op_base.maybe_convert_constant_dtype(step, dtype), - dtype=dtype, - ) - - -# aten.argmax -@op(torch.ops.aten.argmax) -def _aten_argmax(self, dim=None, keepdim=False): - return _with_reduction_scalar(jnp.argmax, self, dim, keepdim) - - -def _strided_index(sizes, strides, storage_offset=None): - ind = jnp.zeros(sizes, dtype=jnp.int32) - - for i, (size, stride) in enumerate(zip(sizes, strides)): - result_shape = (1,) * i + (size,) + (1,) * (len(sizes) - i - 1) - indexes = (jnp.arange(size) * stride).reshape(result_shape) - ind += indexes - - if storage_offset is not None: - ind += storage_offset - return ind - - -# aten.as_strided -@op(torch.ops.aten.as_strided) -@op(torch.ops.aten.as_strided_copy) -def _aten_as_strided(x, sizes, strides, storage_offset=None): - ind = _strided_index(sizes, strides, storage_offset) - flattened = jnp.ravel(x) - return flattened[ind] - - -@op(torch.ops.aten.as_strided_scatter) -def _aten_as_strided_scatter(x, src, sizes, strides, storage_offset): - ind = _strided_index(sizes, strides, storage_offset) - flattened = jnp.ravel(x) - modified = flattened.at[ind].set(src) - return modified.reshape(x.shape) - - -# aten.atan2 -@op(torch.ops.aten.atan2) -@op_base.promote_int_input -def _aten_atan2(input, other): - return jnp.arctan2(input, other) - - -# aten.bitwise_and -@op(torch.ops.aten.bitwise_and) -@op(torch.ops.aten.__and__) -def _aten_bitwise_and(self, other): - return self & other - - -# aten.bitwise_or -@op(torch.ops.aten.bitwise_or) -def _aten_bitwise_or(self, other): - return self | other - - -# aten.bitwise_xor -@op(torch.ops.aten.bitwise_xor) -def _aten_bitwise_xor(self, other): - return self ^ other - - -# aten.broadcast_tensors -@op(torch.ops.aten.broadcast_tensors) -def _aten_broadcast_tensors(*tensors): - - def _get_broadcast_shape(shapes): - """ - Determines the output shape by broadcasting all input shapes. - - Args: - shapes: A list of tuples representing the shapes of the input tensors. - - Returns: - A tuple representing the broadcasted output shape. - """ - - # Find the maximum number of dimensions among all input tensors - max_dims = max(len(shape) for shape in shapes) - # Pad shorter shapes with 1s on the left to match the maximum number of dimensions - padded_shapes = [(1,) * (max_dims - len(shape)) + shape for shape in shapes] - - # Initialize the output shape with 1s - output_shape = [1] * max_dims - # Iterate through each dimension and apply broadcasting rules - for dim in range(max_dims): - dim_sizes = [shape[dim] for shape in padded_shapes] - max_size = max(dim_sizes) - if all(size == 1 or size == max_size for size in dim_sizes): - output_shape[dim] = max_size - else: - raise ValueError("Incompatible shapes for broadcasting") - return tuple(output_shape) - - def _broadcast_dimensions(input_shape, output_shape): - """ - Determines the broadcast_dimensions argument for jax.lax.broadcast_in_dim. - - Args: - input_shape: The shape of the input tensor. - output_shape: The desired output shape after broadcasting. - - Returns: - A tuple specifying which dimensions of the input tensor should be broadcasted. - """ - - res = tuple( - i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape))) - return res - - # clean some function's previous wrap - if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance( - tensors[0][0], jax.Array): - tensors = tensors[0] - - # Get the shapes of all input tensors - shapes = [t.shape for t in tensors] - # Find the output shape by broadcasting all input shapes - output_shape = _get_broadcast_shape(shapes) - # Broadcast each tensor to the output shape - broadcasted_tensors = [ - jax.lax.broadcast_in_dim(t, output_shape, - _broadcast_dimensions(t.shape, output_shape)) - for t in tensors - ] - - return broadcasted_tensors - - -# aten.broadcast_to -@op(torch.ops.aten.broadcast_to) -def _aten_broadcast_to(input, shape): - return jnp.broadcast_to(input, shape) - - -# aten.clamp -@op(torch.ops.aten.clamp.default) -@op(torch.ops.aten.clamp.Tensor) -def _aten_clamp(self, min=None, max=None): - return jnp.clip(self, min, max) - - -@op(torch.ops.aten.clamp_min) -def _aten_clamp_min(input, min): - return jnp.clip(input, min=min) - - -# aten.constant_pad_nd -@op(torch.ops.aten.constant_pad_nd) -def _aten_constant_pad_nd(input, padding, value=0): - # NOTE: Torch padding is flat and reversed: (1, 1, 2, 2) - # means last dim get padded 1 in front and 1 in back; - # and second last dim get padded 2 in front and 2 in back. - # Jax padding tuple of 3-tuple: the same padding is - # [(0, 0, 0), ..., (2,2,0), (1,1,0)], where the last dimension - # is the amount of padding added between any two elements in each dimension - m = len(padding) - rev_padding = [(padding[i - 1], padding[i], 0) for i in range(m - 1, 0, -2)] - pad_dim = tuple(([(0, 0, 0)] * (len(input.shape) - m // 2)) + rev_padding) - value_casted = jax.numpy.array(value, dtype=input.dtype) - return jax.lax.pad(input, padding_value=value_casted, padding_config=pad_dim) - - -# aten.convolution_backward -@op(torch.ops.aten.lift_fresh_copy) -def _aten_lift_fresh_copy(x): - return jnp.copy(x) - - -@op(torch.ops.aten.copy) -def _aten_copy(self, src): - return jnp.broadcast_to(src, self.shape).astype(self.dtype) - - -@op(torch.ops.aten._cdist_forward) -def _aten_cdist_forward(x1, x2, p, compute_mode=""): - # x1 is B x P x M - # x2 is B x Q x M - # res is B x P x Q - x1 = jnp.expand_dims(x1, len(x1.shape) - 1) - x2 = jnp.expand_dims(x2, len(x2.shape) - 2) - return jnp.linalg.norm(x1 - x2, ord=p, axis=-1) - - -@op(torch.ops.aten._pdist_forward) -def _aten__pdist_forward(x, p=2): - pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[jnp.triu_indices( - pairwise_dists.shape[0], k=1)] - return condensed_dists - - -@op(torch.ops.aten.cholesky_inverse) -def _aten_cholesky_inverse(input, upper=False): - t = jnp.matrix_transpose(input) - if "complex" in str(input.dtype): - t = t.conjugate() - return jnp.linalg.inv(input @ t) - - -# aten.cos -@op(torch.ops.aten.cos) -@op_base.promote_int_input -def _aten_cos(input): - return jnp.cos(input) - - -# aten.cosh -@op(torch.ops.aten.cosh) -@op_base.promote_int_input -def _aten_cosh(input): - return jnp.cosh(input) - - -@op(torch.ops.aten.diag) -def _aten_diag(input, diagonal=0): - return jnp.diag(input, diagonal) - - -# aten.diagonal -@op(torch.ops.aten.diagonal) -@op(torch.ops.aten.diagonal_copy) -def _aten_diagonal(input, offset=0, dim1=0, dim2=1): - return jnp.diagonal(input, offset, dim1, dim2) - - -def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1): - input_len = len(input_shape) - if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len): - raise ValueError("dim1 and dim2 must be different and in range [0, " + - str(input_len - 1) + "]") - - size1, size2 = input_shape[dim1], input_shape[dim2] - if offset >= 0: - indices1 = jnp.arange(min(size1, size2 - offset)) - indices2 = jnp.arange(offset, offset + len(indices1)) - else: - indices2 = jnp.arange(min(size1 + offset, size2)) - indices1 = jnp.arange(-offset, -offset + len(indices2)) - return [indices1, indices2] - - -@op(torch.ops.aten.diagonal_scatter) -def _aten_diagonal_scatter(input, src, offset=0, dim1=0, dim2=1): - indexes = diag_indices_with_offset(input.shape, offset, dim1, dim2) - - if input.ndim == 2: - return input.at[tuple(indexes)].set(src) - else: - # src has the same shape as the output of - # jnp.diagonal(input, offset, dim1, dim2). - # Last dimension always contains the diagonal elements, - # while the preceding dimensions represent the "slices" - # from which these diagonals are extracted. Thus, - # we alter input axes to match this assumption, write src - # and then move the axes back to the original state. - input = jnp.moveaxis(input, (dim1, dim2), (-2, -1)) - multi_indexes = [slice(None)] * (input.ndim - 2) + indexes - input = input.at[tuple(multi_indexes)].set(src) - return jnp.moveaxis(input, (-2, -1), (dim1, dim2)) - - -# aten.diagflat -@op(torch.ops.aten.diagflat) -def _aten_diagflat(input, offset=0): - return jnp.diagflat(jnp.array(input), offset) - - -@op(torch.ops.aten.movedim) -def _aten_movedim(input, source, destination): - return jnp.moveaxis(input, source, destination) - - -# aten.eq -@op(torch.ops.aten.eq) -def _aten_eq(input1, input2): - return input1 == input2 - - -# aten.equal -@op(torch.ops.aten.equal) -def _aten_equal(input, other): - res = jnp.array_equal(input, other) - return bool(res) - - -# aten.erf -@op(torch.ops.aten.erf) -@op_base.promote_int_input -def _aten_erf(x): - return jax.lax.erf(x) - - -@op(torch.ops.aten.erfinv) -@op_base.promote_int_input -def _aten_erfinv(input): - return jax.lax.erf_inv(input) - - -# aten.exp -@op(torch.ops.aten.exp) -def _aten_exp(input): - res = jnp.exp(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res - - -# aten.expm1 -@op(torch.ops.aten.expm1) -def _aten_expm1(input): - res = jnp.expm1(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res - - -# aten.exp2 -@op(torch.ops.aten.exp2) -def _aten_exp2(input): - res = jnp.exp2(input) - new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) - if input.dtype == jax.numpy.int64: - res = res.astype(new_dtype) - return res - - -# aten.fill -@op(torch.ops.aten.fill) -@op(torch.ops.aten.full_like) -def _aten_fill(x, - value, - dtype=None, - pin_memory=None, - memory_format=None, - device=None): - if dtype is None: - dtype = x.dtype - else: - dtype = mappings.t2j_dtype(dtype) - return jnp.full(x.shape, value, dtype) - - -# aten.flip -@op(torch.ops.aten.flip) -def _aten_flip(input, dims): - if dims is not None: - return jnp.flip(input, tuple(dims)) - else: - return jnp.flip(input) - - -# aten.floor -@op(torch.ops.aten.floor) -def _aten_floor(input): - return jnp.floor(input).astype(input.dtype) - - -# aten.fmax -@op(torch.ops.aten.fmax) -def _aten_fmax(input, other): - return jnp.fmax(input, other) - - -# aten.fmin -@op(torch.ops.aten.fmin) -def _aten_fmin(input, other): - return jnp.fmin(input, other) - - -# aten.fmod -@op(torch.ops.aten.fmod) -def _aten_fmod(input, other): - return input - other * _aten_div(input, other, "trunc") - - -# aten.frexp -@op(torch.ops.aten.frexp) -def _aten_frexp(input): - return jnp.frexp(input) - - -# aten.gather -@op(torch.ops.aten.gather) -def _aten_gather(input, dim, index): - if input.ndim == 0: - return jnp.broadcast_to(input, index.shape) - # short circuit for empty outputs - if not all(index.shape): - return jnp.zeros(index.shape, dtype=input.dtype) - if dim < 0: - dim += input.ndim - input_indexes, source_indexes = _scatter_index(dim, index) - return input[input_indexes] - - -# aten.ge -@op(torch.ops.aten.ge) -def _aten_ge(self, other): - return self >= other - - -@op(torch.ops.aten.glu) -def _aten_glu(x, dim=-1): - return jax.nn.glu(x, dim) - - -# aten.hardtanh -@op(torch.ops.aten.hardtanh) -def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False): - if input.dtype == np.int64 and isinstance(max_val, float) and isinstance( - min_val, float): - min_val = int(min_val) - max_val = int(max_val) - return jnp.clip(input, min_val, max_val) - - -# aten.histc -@op(torch.ops.aten.histc) -def _aten_histc(input, bins=100, min=0, max=0): - # TODO(@manfei): this function might cause some uncertainty - if min == 0 and max == 0: - if isinstance(input, jnp.ndarray) and input.size == 0: - min = 0 - max = 0 - else: - min = jnp.min(input) - max = jnp.max(input) - range_value = (min, max) - hist, bin_edges = jnp.histogram( - input, bins=bins, range=range_value, weights=None, density=None) - return hist - - -@op(torch.ops.aten.hypot) -def _aten_hypot(input, other): - return jnp.hypot(input, other) - - -@op(torch.ops.aten.digamma) -def _aten_digamma(input, *, out=None): - res = jax.scipy.special.digamma(input).astype(jnp.float32) - # replace indices where input == 0 with -inf in res - return jnp.where(jnp.equal(input, jnp.zeros(input.shape)), -jnp.inf, res) - - -@op(torch.ops.aten.igamma) -def _aten_igamma(input, other): - return jax.scipy.special.gammainc(input, other) - - -@op(torch.ops.aten.lgamma) -def _aten_lgamma(input, *, out=None): - return jax.scipy.special.gammaln(input).astype(jnp.float32) - - -@op(torch.ops.aten.mvlgamma) -def _aten_mvlgamma(input, p, *, out=None): - input = input.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return jax.scipy.special.multigammaln(input, p) - - -@op(torch.ops.aten.linalg_eig) -def _aten_linalg_eig(A): - return jnp.linalg.eig(A) - - -@op(torch.ops.aten._linalg_eigh) -def _aten_linalg_eigh(A, UPLO='L'): - return jnp.linalg.eigh(A, UPLO) - - -@op(torch.ops.aten.linalg_lstsq) -def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'): - input_dtype = A.dtype - - m = A.shape[-2] - n = A.shape[-1] - - is_batched = A.ndim > 2 - - if is_batched: - - batch_shape = jnp.broadcast_shapes(A.shape[:-2], B.shape[:-2]) - batch_size = int(np.prod(batch_shape)) - A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) - B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) - - X, residuals, rank, singular_values = jax.vmap( - jnp.linalg.lstsq, in_axes=(0, - 0))(A_reshaped, B_reshaped, rcond=rcond) - - X = X.reshape(batch_shape + X.shape[-2:]) - - if driver in ['gelsd', 'gelsy', 'gelss']: - rank = rank.reshape(batch_shape) - else: - rank = jnp.array([], dtype=jnp.int64) - - full_rank = jnp.all(rank == n) - if driver == 'gelsy' or m <= n or (not full_rank): - residuals = jnp.array([], dtype=input_dtype) - else: - residuals = residuals.reshape(batch_shape + residuals.shape[-1:]) - - if driver in ['gelsd', 'gelss']: - singular_values = singular_values.reshape(batch_shape + - singular_values.shape[-1:]) - else: - singular_values = jnp.array([], dtype=input_dtype) - - else: - - X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond) - - if driver not in ['gelsd', 'gelsy', 'gelss']: - rank = jnp.array([], dtype=jnp.int64) - - rank_value = None - if rank.size > 0: - rank_value = int(rank.item()) - rank = jnp.array(rank_value, dtype=jnp.int64) - - # When driver is ‘gels’, assume that A is full-rank. - full_rank = driver == 'gels' or rank_value == n - if driver == 'gelsy' or m <= n or (not full_rank): - residuals = jnp.array([], dtype=input_dtype) - - if driver not in ['gelsd', 'gelss']: - singular_values = jnp.array([], dtype=input_dtype) - - return X, residuals, rank, singular_values - - -@op(torch.ops.aten.linalg_ldl_factor_ex) -def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False): - # TODO: Replace with native LDL when available: - # https://github.com/jax-ml/jax/issues/12779 - # TODO: Not tested for complex inputs. Does not support hermitian=True - pivots = jnp.broadcast_to( - jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1]) - info = jnp.zeros(A.shape[:-2], jnp.int32) - C = jnp.linalg.cholesky(A) - if C.size == 0: - return C, pivots, info - - # Fill diagonals of stacked matrices - @functools.partial(jnp.vectorize, signature='(k,k),(k,k)->(k,k)') - def fill_diagonal_batch(x, y): - return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) - - D = C * jnp.eye(C.shape[-1], dtype=A.dtype) - LD = C @ jnp.linalg.inv(D) - LD = fill_diagonal_batch(LD, D * D) - return LD, pivots, info - - -@op(torch.ops.aten.linalg_lu) -def _aten_linalg_lu(A, pivot=True, out=None): - dtype = A.dtype - - *_, m, n = A.shape - k = jnp.minimum(m, n) - - lu, _, permutation = jax.lax.linalg.lu(A) - - L = jnp.tril(lu[..., :, :k], k=-1) - eye_L = jnp.eye(m, k, dtype=dtype) - L = L + eye_L - - U = jnp.triu(lu[..., :k, :]) - - def perm_to_P(perm): - m = perm.shape[-1] - P = jnp.eye(m, dtype=dtype)[perm].T - return P - - if permutation.ndim > 1: - num_batch_dims = permutation.ndim - 1 - for _ in range(num_batch_dims): - perm_to_P = jax.vmap(perm_to_P, in_axes=0) - - P = perm_to_P(permutation) - - return P, L, U - - -@op(torch.ops.aten.linalg_lu_factor_ex) -def _aten_linalg_lu_factor_ex(A, pivot=True, check_errors=False): - lu, pivots, _ = jax.lax.linalg.lu(A) - # PT pivots vector is 1-indexed - pivots = pivots + 1 - info = jnp.zeros(A.shape[:-2], jnp.int32) - return lu, pivots, info - - -@op(torch.ops.aten.linalg_lu_solve) -def _aten_linalg_lu_solve(LU, pivots, B, left=True, adjoint=False): - # JAX pivots are offset by 1 compared to torch - pivots = pivots - 1 - if not left: - # XA = B is same as A'X = B' - trans = 0 if adjoint else 2 - x = jax.scipy.linalg.lu_solve((LU, pivots), jnp.matrix_transpose(B), trans) - x = jnp.matrix_transpose(x) - else: - trans = 2 if adjoint else 0 - x = jax.scipy.linalg.lu_solve((LU, pivots), B, trans) - return x - - -@op(torch.ops.aten.gcd) -def _aten_gcd(input, other): - return jnp.gcd(input, other) - - -# aten.lcm -@op(torch.ops.aten.lcm) -def _aten_lcm(input, other): - return jnp.lcm(input, other) - - -# aten.isinf -@op(torch.ops.aten.isinf) -def _aten_isinf(input): - return jnp.isinf(input) - - -# aten.isnan -@op(torch.ops.aten.isnan) -def _aten_isnan(input): - return jnp.isnan(input) - - -@op(torch.ops.aten.le) -def _aten_le(self, other): - return self <= other - - -# aten.leaky_relu -@op(torch.ops.aten.leaky_relu) -def _aten_leaky_relu(x, negative_slope=0.01): - return jax.nn.leaky_relu(x, negative_slope) - - -# aten.log -@op(torch.ops.aten.log) -@op_base.promote_int_input -def _aten_log(x): - return jnp.log(x) - - -# aten.log10 -@op(torch.ops.aten.log10) -@op_base.promote_int_input -def _aten_log10(x): - return jnp.log10(x) - - -# aten.log1p -@op(torch.ops.aten.log1p) -@op_base.promote_int_input -def _aten_log1p(x): - return jnp.log1p(x) - - -# aten.log2 -@op(torch.ops.aten.log2) -@op_base.promote_int_input -def _aten_log2(x): - return jnp.log2(x) - - -# aten.logical_and -@op(torch.ops.aten.logical_and) -@op(torch.ops.aten.__and__) -def _aten_logical_and(self, other): - return jnp.logical_and(self, other) - - -# aten.logical_or -@op(torch.ops.aten.logical_or) -@op(torch.ops.aten.__or__) -def _aten_logical_or(self, other): - return jnp.logical_or(self, other) - - -# aten.logical_not -@op(torch.ops.aten.logical_not) -def _aten_logical_not(self): - return jnp.logical_not(self) - - -# aten.log_softmax -@op(torch.ops.aten._log_softmax) -def _aten_log_softmax(self, axis=-1, half_to_float=False): - if self.shape == (): - return jnp.astype(0.0, self.dtype) - return jax.nn.log_softmax(self, axis) - - -# aten.logaddexp -@op(torch.ops.aten.logaddexp) -def _aten_logaddexp(self, other): - return jnp.logaddexp(self, other) - - -# aten.logaddexp2 -@op(torch.ops.aten.logaddexp2) -def _aten_logaddexp2(self, other): - return jnp.logaddexp2(self, other) - - -# aten.logcumsumexp -@op(torch.ops.aten.logcumsumexp) -def _aten_logcumsumexp(self, dim=None): - if self.shape == (): - return self - return jax.lax.cumlogsumexp(self, axis=dim) - - -# aten.max_pool3d_backward -# aten.logical_xor -@op(torch.ops.aten.logical_xor) -@op(torch.ops.aten.__xor__) -def _aten_logical_xor(self, other): - return jnp.logical_xor(self, other) - - -# aten.max_pool2d_with_indices_backward -# aten.native_dropout -# aten.native_group_norm_backward -# aten.neg -@op(torch.ops.aten.neg) -def _aten_neg(x): - return -1 * x - - -@op(torch.ops.aten.nextafter) -def _aten_nextafter(input, other, *, out=None): - return jnp.nextafter(input, other) - - -@op(torch.ops.aten.nonzero_static) -def _aten_nonzero_static(input, size, fill_value=-1): - indices = jnp.argwhere(input) - - if size < indices.shape[0]: - indices = indices[:size] - elif size > indices.shape[0]: - padding = jnp.full((size - indices.shape[0], indices.shape[1]), - fill_value, - dtype=indices.dtype) - indices = jnp.concatenate((indices, padding)) - - return indices - - -# aten.nonzero -@op(torch.ops.aten.nonzero) -def _aten_nonzero(x, as_tuple=False): - if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0): - return torch.empty(0, 0, dtype=torch.int64) - if jnp.ndim( - x - ) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) - res = torch.empty(1, 0, dtype=torch.int64) - return jnp.array(res.numpy()) - index_tuple = jnp.nonzero(x) - index_tuple = [jnp.expand_dims(p, -1) for p in index_tuple] - return jnp.concatenate(index_tuple, axis=-1) - - -# aten.prod -@op(torch.ops.aten.prod) -def _aten_prod(input, dim=None, keepdim=False, *, dtype=None): - if dtype: - input = input.astype(mappings.t2j_dtype(dtype)) - return _with_reduction_scalar(jnp.prod, input, dim, keepdim) - - -@op(torch.ops.aten.put) -def _aten_put(self, index, source, accumulate=False): - expanded = False - res = None - - if self.ndim == 0: - expanded = True - self = jnp.expand_dims(self, 0) - - if accumulate: - tmp = jnp.zeros(self.shape) - tmp = jnp.put(tmp, index, source, inplace=False) - res = jnp.add(self, tmp).astype(self.dtype) - else: - res = jnp.put(self, index, source, inplace=False) - - if expanded: - res = res.squeeze() - - return res - - -# aten.randperm -# randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -@op(torch.ops.aten.randperm, needs_env=True) -def _aten_randperm(n, - *, - generator=None, - dtype=None, - layout=None, - device=None, - pin_memory=None, - env=None): - """ - Generates a random permutation of integers from 0 to n-1. - - Args: - n: The upper bound (exclusive) of the permutation range. - generator: A PRNGKey used as the random key. If None, a new key is created. - dtype: The desired data type of the output array. Default is jnp.int64. - layout: The desired layout of the output array (e.g., 'row-major', 'column-major'). - device: The desired device on which to place the output array (e.g., jax.devices()[0]). - pin_memory: Whether to pin the output array's memory to the host. - - Returns: - A DeviceArray containing a random permutation of integers from 0 to n-1. - """ - if dtype: - dtype = mappings.t2j_dtype(dtype) - else: - dtype = jnp.int64.dtype - key = env.get_and_rotate_prng_key(generator) - indices = jnp.arange(n, dtype=dtype) - permutation = jax.random.permutation(key, indices) - return permutation - - -# aten.reflection_pad3d - - -# aten.remainder -@op(torch.ops.aten.remainder) -def _aten_remainder(inputs, other): - return inputs % other - - -# aten.repeat -@op(torch.ops.aten.repeat) -def _aten_repeat(x, reps): - return jnp.tile(x, reps) - - -# aten.replication_pad2d -# aten.replication_pad3d -# aten.roll -@op(torch.ops.aten.roll) -def _aten_roll(input, shifts, dims=None): - return jnp.roll(input, shifts, dims) - - -# aten.slice_scatter -@op(torch.ops.aten.slice_scatter) -def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): - input_index = [] - for x in range(len(input.shape)): - if x == dim: - input_index.append(slice(start, end, step)) - else: - input_index.append(slice(None, None, None)) - return input.at[tuple(input_index)].set(src) - - -# aten.sort -# torch.sort(input, dim=-1, descending=False, stable=False, *, out=None) -@op(torch.ops.aten.sort) -def _aten_sort(a, dim=-1, descending=False, stable=False): - if a.shape == (): - return (a, jnp.astype(0, 'int64')) - return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), - ) - - -# aten.sym_size - - -# aten.topk -@op(torch.ops.aten.topk) -def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): - """JAX top-k implementation using jax.lax.top_k for improved efficiency. - - Args: - input: The input JAX array. - k: The number of top elements to return. - dim: The dimension along which to find the top-k. If None, operates on the - flattened array. - largest: If True, returns the largest k elements. Otherwise, smallest k. - sorted: If True, returns the elements in sorted order. - - Returns: - A tuple (values, indices) containing: - - values: The top k values. - - indices: The indices of the top k values in the original array. - """ - if dim is None: - # last dim is chosen - dim = input.ndim - 1 - - if dim < 0: - dim = dim + input.ndim - - if not largest: - input = -input # Find top-k of negated input if we want the smallest - - if input.ndim == 0: - return input, jnp.array(0, dtype=jnp.int64.dtype) - - transpose_shape = None - if dim != -1 and dim != len(input.shape) - 1: - transpose_shape = list(range(len(input.shape))) - transpose_shape[dim], transpose_shape[-1] = ( - transpose_shape[-1], - transpose_shape[dim], - ) - input = jnp.transpose(input, transpose_shape) - - values, indices = jax.lax.top_k(input, k) - - if sorted: - values = jnp.sort(values, descending=True) - indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) - - if not largest: - values = -values # Negate values back if we found smallest - - if transpose_shape is not None: - values = jnp.transpose(values, transpose_shape) - indices = jnp.transpose(indices, transpose_shape) - - return values, indices - - -# aten.tril_indices -#tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -@op(torch.ops.aten.tril_indices) -def _aten_tril_indices(row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None): - a, b = jnp.tril_indices(row, offset, col) - return jnp.stack((a, b)) - - -# aten.tril_indices -#tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -@op(torch.ops.aten.triu_indices) -def _aten_triu_indices(row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None): - a, b = jnp.triu_indices(row, offset, col) - return jnp.stack((a, b)) - - -@op(torch.ops.aten.unbind_copy) -def _aten_unbind(a, dim=0): - return [ - jax.lax.index_in_dim(a, i, dim, keepdims=False) - for i in range(a.shape[dim]) - ] - - -# aten.unique_dim -# -# NOTE: Like the CUDA and CPU implementations, this implementation always sorts -# the tensor regardless of the `sorted` argument passed to `torch.unique`. -@op(torch.ops.aten.unique_dim) -def _aten_unique_dim(input_tensor, - dim, - sort=True, - return_inverse=False, - return_counts=False): - result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=return_counts, - axis=dim, - equal_nan=False) - result_list = ( - list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple) - else [result_tensor_or_tuple]) - - if not return_inverse: - result_list.insert(1, None) - elif _jax_version < (0, 4, 31) and dim is not None: - result_list[1] = result_list[1].flatten() - - if not return_counts: - result_list.insert(2, None) - - # [result, None, None] if return_inverse=False and return_counts=False - # [result, inverse, None] if return_inverse=True and return_counts=False - # [result, None, counts] if return_inverse=False and return_counts=True - # [result, inverse, counts] if return_inverse=True and return_counts=True - return result_list - - -# aten._unique -# -# NOTE: Like the CUDA and CPU implementations, this implementation always sorts -# the tensor regardless of the `sorted` argument passed to `torch.unique`. -@op(torch.ops.aten._unique) -def _aten_unique(input_tensor, sort=True, return_inverse=False): - result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=False, - axis=None, - equal_nan=False) - if return_inverse: - return result_tensor_or_tuple - else: - return (result_tensor_or_tuple, None) - - -# aten._unique2 -# -# NOTE: Like the CUDA and CPU implementations, this implementation always sorts -# the tensor regardless of the `sorted` argument passed to `torch.unique`. -@op(torch.ops.aten._unique2) -def _aten_unique2(input_tensor, - sort=True, - return_inverse=False, - return_counts=False): - return _aten_unique_dim( - input_tensor=input_tensor, - dim=None, - sort=sort, - return_inverse=return_inverse, - return_counts=return_counts) - - -# aten.unique_consecutive -@op(torch.ops.aten.unique_consecutive) -def _aten_unique_consecutive(input_tensor, - return_inverse=False, - return_counts=None, - dim=None): - # Explanation of computations (shown in 1D for simplicity): - # - # Input [a b b c c c d d d d e e e e e] - # Slice dropping final element (input[:-1]) [a b b c c c d d d d e e e e] - # Slice dropping first element (input[1:]) [b b c c c d d d d e e e e e] - # Boolean != operation on shifted slices [1 0 1 0 0 1 0 0 0 1 0 0 0 0] - # Prepend 1 to represent the first element [1 1 0 1 0 0 1 0 0 0 1 0 0 0 0] - # Filter input by the resulting bool array [a b c d e ] - # Output [a b c d e] - - if dim is None: - inverse_shape = input_tensor.shape - input_tensor = input_tensor.flatten() - ndim = 1 - dim = 0 - else: - inverse_shape = input_tensor.shape[dim] - ndim = input_tensor.ndim - if dim < 0: - dim += ndim - - nd_slice_0 = tuple( - slice(None, -1) if d == dim else slice(None) for d in range(ndim)) - nd_slice_1 = tuple( - slice(1, None) if d == dim else slice(None) for d in range(ndim)) - - axes_to_reduce = tuple(d for d in range(ndim) if d != dim) - - does_not_equal_prior = ( - jnp.any( - input_tensor[nd_slice_0] != input_tensor[nd_slice_1], - axis=axes_to_reduce, - keepdims=False)) - - if input_tensor.shape[dim] != 0: - # Prepend `True` to represent the first element of the input. - does_not_equal_prior = jnp.insert(does_not_equal_prior, 0, True) - - include_indices = jnp.argwhere(does_not_equal_prior)[:, 0] - - output_tensor = input_tensor[tuple( - include_indices if d == dim else slice(None) for d in range(ndim))] - - if return_inverse or return_counts: - counts = ( - jnp.append(include_indices[1:], input_tensor.shape[dim]) - - include_indices[:]) - - inverse = ( - jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape) - if return_inverse else None) - - return output_tensor, inverse, counts - - return output_tensor, None, None - - -# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d -# despite those being core aten ops, they also have decompositions. -# here we are using torch decompositions. - - -# aten.where -@op(torch.ops.aten.where) -@op(torch.ops.aten.where.self) -@op(torch.ops.aten.where.ScalarSelf) -@op(torch.ops.aten.where.ScalarOther) -@op(torch.ops.aten.where.Scalar) -def _aten_where(condition, x=None, y=None): - return jnp.where(condition, x, y) - - -# aten.to.dtype -# Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None -@op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, - dtype, - non_blocking=False, - copy=False, - memory_format=None): - if dtype: - jaxdtype = mappings.t2j_dtype(dtype) - return a.astype(jaxdtype) - - -@op(torch.ops.aten.to.dtype_layout) -def _aten_to_dtype_layout(a, - *, - dtype=None, - layout=None, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None): - return _aten_to_dtype( - a, - dtype, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format) - - -# aten.to.device - - -# Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False -@op(torch.ops.aten.var_mean.correction) -def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False): - # The internal API technically has a default `correction` argument of `None`, - # but the public API has a default argument of 1. Therefore, we simply set our - # default argument to 1. However, since the argument is officially supposed to - # be nullable, we still need to check for `None` per the API contract. - if correction is None: - correction = 1 - mean = jnp.mean(tensor, axis=dim, keepdims=keepdim) - # TODO: Pass in the `mean=mean` argument once `jax.numpy.var` supports it. - var = jnp.var(tensor, axis=dim, ddof=correction, keepdims=keepdim) - return var, mean - - -@op(torch.ops.aten.scalar_tensor) -@op_base.convert_dtype() -def _aten_scalar_tensor(s, - dtype=None, - layout=None, - device=None, - pin_memory=None): - return jnp.array(s, dtype=dtype) - - -@op(torch.ops.aten.to.device) -def _aten_to_device(x, device, dtype): - return x - - -@op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, - stride, padding, dilation, - ceil_mode, indices): - """ - Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. - - Args: - grad_output: The gradient tensor from the preceding layer. - self: The input tensor on which the original max pooling was performed. - kernel_size: The size of the pooling window. - stride: The stride of the pooling window. - padding: The padding applied during max pooling. - dilation: The dilation factor for the pooling operation. - ceil_mode: Whether to use ceil or floor when calculating output shapes. - indices: The indices of the maximum values, as produced by max_pool2d_with_indices. - - Returns: - The calculated gradient with respect to the input (grad_input). - """ - - kH, kW = kernel_size - dH, dW = stride - padH, padW = padding - dilH, dilW = dilation - - # Calculate output shape (may need adjustment based on ceil_mode) - out_shape = jnp.array(self.shape) - grad_input = jnp.zeros_like(self) - - # Iterate over the flattened input and output tensors - for i, idx in enumerate(indices.flatten()): - # Calculate input coordinates corresponding to the maximum value - out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] - in_y = out_y * dH - padH + out_y * (dilH - 1) - in_x = out_x * dW - padW + out_x * (dilW - 1) - - # Scatter the gradient to the appropriate input locations (handling potential overlaps) - for y in range(in_y, in_y + kH): - for x in range(in_x, in_x + kW): - if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: - grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) - - return grad_input - - -@op(torch.ops.aten._local_scalar_dense) -def _aten_local_scalar_dense(x): - return x.item() - - -@op(torch.ops.aten.tensor_split.sections) -def _aten_tensor_split(ary, indices_or_sections, axis=0): - return jnp.array_split(ary, indices_or_sections, axis) - - -@op(torch.ops.aten.randn, needs_env=True) -@op_base.convert_dtype() -def _aten_randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, -): - shape = size - if len(shape) == 1 and isinstance(shape[0], (list, tuple)): - shape = shape[0] - key = env.get_and_rotate_prng_key(generator) - res = jax.random.normal(key, shape) - if dtype is not None: - res = res.astype(dtype) - return res - - -@op(torch.ops.aten.bernoulli.p, needs_env=True) -def _aten_bernoulli( - self, - p=0.5, - *, - generator=None, - env=None, -): - key = env.get_and_rotate_prng_key(generator) - res = jax.random.uniform(key, self.shape) < p - return res - - -@op(torch.ops.aten.geometric, needs_env=True) -def geometric(self, p, *, generator=None, env=None): - key = env.get_and_rotate_prng_key(generator) - res = jax.random.geometric(key, p, self.shape) - return res - - -@op(torch.ops.aten.randn_like, needs_env=True) -@op_base.convert_dtype() -def _aten_randn_like( - x, - *, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=torch.preserve_format, - env=None, -): - key = env.get_and_rotate_prng_key() - return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape) - - -@op(torch.ops.aten.rand, needs_env=True) -@op_base.convert_dtype() -def _rand( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, -): - shape = size - if len(shape) == 1 and isinstance(shape[0], (list, tuple)): - shape = shape[0] - key = env.get_and_rotate_prng_key(generator) - res = jax.random.uniform(key, shape) - if dtype is not None: - res = res.astype(dtype) - return res - - -@op(torch.ops.aten.outer) -def _aten_outer(a, b): - return jnp.outer(a, b) - - -@op(torch.ops.aten.allclose) -def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) - - -@op(torch.ops.aten.native_batch_norm) -def _aten_native_batch_norm(input, - weight, - bias, - running_mean, - running_var, - training=False, - momentum=0.1, - eps=1e-5): - - if running_mean is None: - running_mean = jnp.zeros( - input.shape[1], dtype=input.dtype) # Initialize running mean if None - if running_var is None: - running_var = jnp.ones( - input.shape[1], - dtype=input.dtype) # Initialize running variance if None - - if training: - return _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps) - else: - return _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps) - - -@op(torch.ops.aten.normal, needs_env=True) -def _aten_normal(self, mean=0, std=1, generator=None, env=None): - shape = self.shape - res = _aten_randn(*shape, generator=generator, env=env) - return res * std + mean - - -# TODO: not clear what this function should actually do -# https://github.com/pytorch/pytorch/blob/d96c80649f301129219469d8b4353e52edab3b78/aten/src/ATen/native/native_functions.yaml#L7933-L7940 -@op(torch.ops.aten.lift_fresh) -def _aten_lift_fresh(self): - return self - - -@op(torch.ops.aten.uniform, needs_env=True) -def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): - assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})' - shape = self.shape - res = _rand(*shape, generator=generator, env=env) - return res * (to - from_) + from_ - - -#func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - - -@op(torch.ops.aten.randint, needs_env=True) -@op_base.convert_dtype(use_default_dtype=False) -def _aten_randint( - *args, - generator=None, - dtype=None, - env=None, - **kwargs, -): - if len(args) == 3: - # low, high, size - low, high, size = args - elif len(args) == 2: - high, size = args - low = 0 - else: - raise AssertionError( - f'Expected at 2 or 3 args for Aten::randint, got {len(args)}') - - key = env.get_and_rotate_prng_key(generator) - res = jax.random.randint(key, size, low, high) - if dtype is not None: - res = res.astype(dtype) - return res - - -@op(torch.ops.aten.randint_like, - torch.ops.aten.randint.generator, - needs_env=True) -@op_base.convert_dtype(use_default_dtype=False) -def _aten_randint_like( - input, - *args, - generator=None, - dtype=None, - env=None, - **kwargs, -): - if len(args) == 2: - low, high = args - elif len(args) == 1: - high = args[0] - low = 0 - else: - raise AssertionError( - f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}') - - shape = input.shape - dtype = dtype or input.dtype - key = env.get_and_rotate_prng_key(generator) - res = jax.random.randint(key, shape, low, high) - if dtype is not None: - res = res.astype(dtype) - return res - - -@op(torch.ops.aten.dim, is_jax_function=False) -def _aten_dim(self): - return len(self.shape) - - -@op(torch.ops.aten.copysign) -def _aten_copysign(input, other, *, out=None): - result = jnp.copysign(input, other) - # torch.copysign(x, y) returns float32 for integer x and y, - # regardless of their exact integer dtype, whereas jax.copysign returns - # float64 when one or both of them is int64. - if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype( - other.dtype, jnp.integer): - result = result.astype(jnp.float32) - return result - - -@op(torch.ops.aten.i0) -@op_base.promote_int_input -def _aten_i0(self): - return jax.scipy.special.i0(self) - - -@op(torch.ops.aten.special_i0e) -@op_base.promote_int_input -def _aten_i0e(self): - return jax.scipy.special.i0e(self) - - -@op(torch.ops.aten.special_i1) -@op_base.promote_int_input -def _aten_special_i1(self): - return jax.scipy.special.i1(self) - - -@op(torch.ops.aten.special_i1e) -@op_base.promote_int_input -def _aten_special_i1e(self): - return jax.scipy.special.i1e(self) - - -@op(torch.ops.aten.special_laguerre_polynomial_l) -@op_base.promote_int_input -def _aten_special_laguerre_polynomial_l(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3106-L3134 - - @jnp.vectorize - def vectorized(x, n_i): - - def negative_n(x): - return jnp.zeros_like(x) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return jnp.ones_like(x) - x - - def zero_abs(x): - return jnp.ones_like(x) - - def default(x): - - def f(k, carry): - p, q = carry - return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1)) - - _, q = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, jnp.ones_like(x) - x)) - return q - - return jnp.piecewise( - x, [n_i == 1, n_i == 0, - jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], - [one_n, zero_n, zero_abs, negative_n, default]) - - return vectorized(self, n.astype(jnp.int64)) - - -@op(torch.ops.aten.special_modified_bessel_i0) -@op_base.promote_int_input -def _aten_special_modified_bessel_i0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3182-L3268 - - def small(x): - A = jnp.array( - [ - -4.41534164647933937950e-18, - 3.33079451882223809783e-17, - -2.43127984654795469359e-16, - 1.71539128555513303061e-15, - -1.16853328779934516808e-14, - 7.67618549860493561688e-14, - -4.85644678311192946090e-13, - 2.95505266312963983461e-12, - -1.72682629144155570723e-11, - 9.67580903537323691224e-11, - -5.18979560163526290666e-10, - 2.65982372468238665035e-09, - -1.30002500998624804212e-08, - 6.04699502254191894932e-08, - -2.67079385394061173391e-07, - 1.11738753912010371815e-06, - -4.41673835845875056359e-06, - 1.64484480707288970893e-05, - -5.75419501008210370398e-05, - 1.88502885095841655729e-04, - -5.76375574538582365885e-04, - 1.63947561694133579842e-03, - -4.32430999505057594430e-03, - 1.05464603945949983183e-02, - -2.37374148058994688156e-02, - 4.93052842396707084878e-02, - -9.49010970480476444210e-02, - 1.71620901522208775349e-01, - -3.04682672343198398683e-01, - 6.76795274409476084995e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, ((x / 2.0) - 2.0) * q - p + val), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return jnp.exp(x) * (0.5 * (a - p)) - - def default(x): - B = jnp.array( - [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - 4.46562142029675999901e-17, - 3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - 1.77256013305652638360e-15, - 3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - 1.54008621752140982691e-14, - 3.85277838274214270114e-13, - 7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - 1.18891471078464383424e-11, - 4.94060238822496958910e-10, - 3.39623202570838634515e-09, - 2.26666899049817806459e-08, - 2.04891858946906374183e-07, - 2.89137052083475648297e-06, - 6.88975834691682398426e-05, - 3.36911647825569408990e-03, - 8.04490411014108831608e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (32.0 / x - 2.0) * q - p + val), None - - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x) - - self = jnp.abs(self) - return jnp.piecewise(self, [self <= 8], [small, default]) - - -@op(torch.ops.aten.special_modified_bessel_i1) -@op_base.promote_int_input -def _aten_special_modified_bessel_i1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3271-L3364 - - def small(x): - A = jnp.array( - [ - 2.77791411276104639959e-18, - -2.11142121435816608115e-17, - 1.55363195773620046921e-16, - -1.10559694773538630805e-15, - 7.60068429473540693410e-15, - -5.04218550472791168711e-14, - 3.22379336594557470981e-13, - -1.98397439776494371520e-12, - 1.17361862988909016308e-11, - -6.66348972350202774223e-11, - 3.62559028155211703701e-10, - -1.88724975172282928790e-09, - 9.38153738649577178388e-09, - -4.44505912879632808065e-08, - 2.00329475355213526229e-07, - -8.56872026469545474066e-07, - 3.47025130813767847674e-06, - -1.32731636560394358279e-05, - 4.78156510755005422638e-05, - -1.61760815825896745588e-04, - 5.12285956168575772895e-04, - -1.51357245063125314899e-03, - 4.15642294431288815669e-03, - -1.05640848946261981558e-02, - 2.47264490306265168283e-02, - -5.29459812080949914269e-02, - 1.02643658689847095384e-01, - -1.76416518357834055153e-01, - 2.52587186443633654823e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return jax.lax.cond( - x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), - lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))) - - def default(x): - B = jnp.array( - [ - 7.51729631084210481353e-18, - 4.41434832307170791151e-18, - -4.65030536848935832153e-17, - -3.20952592199342395980e-17, - 2.96262899764595013876e-16, - 3.30820231092092828324e-16, - -1.88035477551078244854e-15, - -3.81440307243700780478e-15, - 1.04202769841288027642e-14, - 4.27244001671195135429e-14, - -2.10154184277266431302e-14, - -4.08355111109219731823e-13, - -7.19855177624590851209e-13, - 2.03562854414708950722e-12, - 1.41258074366137813316e-11, - 3.25260358301548823856e-11, - -1.89749581235054123450e-11, - -5.58974346219658380687e-10, - -3.83538038596423702205e-09, - -2.63146884688951950684e-08, - -2.51223623787020892529e-07, - -3.88256480887769039346e-06, - -1.10588938762623716291e-04, - -9.76109749136146840777e-03, - 7.78576235018280120474e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None - - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jax.lax.cond( - x < 0, lambda: -(jnp.exp(jnp.abs(x)) * - (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), - lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))) - - return jnp.piecewise(self, [self <= 8], [small, default]) - - -@op(torch.ops.aten.special_modified_bessel_k0) -@op_base.promote_int_input -def _aten_special_modified_bessel_k0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3367-L3441 - - def zero(x): - return jnp.array(jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - A = jnp.array( - [ - 1.37446543561352307156e-16, - 4.25981614279661018399e-14, - 1.03496952576338420167e-11, - 1.90451637722020886025e-09, - 2.53479107902614945675e-07, - 2.28621210311945178607e-05, - 1.26461541144692592338e-03, - 3.59799365153615016266e-02, - 3.44289899924628486886e-01, - -5.35327393233902768720e-01, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - return (p, q, (x * x - 2.0) * q - p + val), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return 0.5 * (a - p) - jnp.log( - 0.5 * x) * _aten_special_modified_bessel_i0(x) - - def default(x): - B = jnp.array( - [ - 5.30043377268626276149e-18, - -1.64758043015242134646e-17, - 5.21039150503902756861e-17, - -1.67823109680541210385e-16, - 5.51205597852431940784e-16, - -1.84859337734377901440e-15, - 6.34007647740507060557e-15, - -2.22751332699166985548e-14, - 8.03289077536357521100e-14, - -2.98009692317273043925e-13, - 1.14034058820847496303e-12, - -4.51459788337394416547e-12, - 1.85594911495471785253e-11, - -7.95748924447710747776e-11, - 3.57739728140030116597e-10, - -1.69753450938905987466e-09, - 8.57403401741422608519e-09, - -4.66048989768794782956e-08, - 2.76681363944501510342e-07, - -1.83175552271911948767e-06, - 1.39498137188764993662e-05, - -1.28495495816278026384e-04, - 1.56988388573005337491e-03, - -3.14481013119645005427e-02, - 2.44030308206595545468e+00, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - return (p, q, (8.0 / x - 2.0) * q - p + val), None - - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - - return jnp.piecewise(self, [self <= 2, self < 0, self == 0], - [small, negative, zero, default]) - - -@op(torch.ops.aten.special_modified_bessel_k1) -@op_base.promote_int_input -def _aten_special_modified_bessel_k1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3444-L3519 - - def zero(x): - return jnp.array(jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - A = jnp.array( - [ - -7.02386347938628759343e-18, - -2.42744985051936593393e-15, - -6.66690169419932900609e-13, - -1.41148839263352776110e-10, - -2.21338763073472585583e-08, - -2.43340614156596823496e-06, - -1.73028895751305206302e-04, - -6.97572385963986435018e-03, - -1.22611180822657148235e-01, - -3.53155960776544875667e-01, - 1.52530022733894777053e+00, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, a = carry - p, q = q, a - a = (x * x - 2.0) * q - p + val - return (p, q, a), None - - (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) - - return jnp.log( - 0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x - - def default(x): - B = jnp.array( - [ - -5.75674448366501715755e-18, - 1.79405087314755922667e-17, - -5.68946255844285935196e-17, - 1.83809354436663880070e-16, - -6.05704724837331885336e-16, - 2.03870316562433424052e-15, - -7.01983709041831346144e-15, - 2.47715442448130437068e-14, - -8.97670518232499435011e-14, - +3.34841966607842919884e-13, - -1.28917396095102890680e-12, - 5.13963967348173025100e-12, - -2.12996783842756842877e-11, - 9.21831518760500529508e-11, - -4.19035475934189648750e-10, - 2.01504975519703286596e-09, - -1.03457624656780970260e-08, - 5.74108412545004946722e-08, - -3.50196060308781257119e-07, - 2.40648494783721712015e-06, - -1.93619797416608296024e-05, - 1.95215518471351631108e-04, - -2.85781685962277938680e-03, - 1.03923736576817238437e-01, - 2.72062619048444266945e+00, - ], - dtype=self.dtype, - ) - - def f(carry, val): - p, q, b = carry - p, q = q, b - b = (8.0 / x - 2.0) * q - p + val - return (p, q, b), None - - (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) - - return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - - return jnp.piecewise(self, [self <= 2, self < 0, self == 0], - [small, negative, zero, default]) - - -@op(torch.ops.aten.polygamma) -def _aten_polygamma(x, n): - if n.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: - n = n.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return jax.lax.polygamma(jnp.float32(x), n) - - -@op(torch.ops.aten.special_ndtri) -@op_base.promote_int_input -def _aten_special_ndtri(self): - return jax.scipy.special.ndtri(self) - - -@op(torch.ops.aten.special_bessel_j0) -@op_base.promote_int_input -def _aten_special_bessel_j0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2379-L2489 - - def very_small(x): - return 1.0 - x * x / 4.0 - - def small(x): - RP = jnp.array( - [ - -4.79443220978201773821e09, - 1.95617491946556577543e12, - -2.49248344360967716204e14, - 9.70862251047306323952e15, - ], - dtype=self.dtype, - ) - RQ = jnp.array( - [ - 4.99563147152651017219e02, - 1.73785401676374683123e05, - 4.84409658339962045305e07, - 1.11855537045356834862e10, - 2.11277520115489217587e12, - 3.10518229857422583814e14, - 3.18121955943204943306e16, - 1.71086294081043136091e18, - ], - dtype=self.dtype, - ) - - rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) - rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - - return ((x * x - 5.78318596294678452118e00) * - (x * x - 3.04712623436620863991e01) * rp / rq) - - def default(x): - PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, - ) - - pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i) - pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i) - qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i) - qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i) - - return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) - - 5.0 / x * - (qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) - - self = jnp.abs(self) - # Last True condition in `piecewise` takes priority, but last function is - # default. See https://github.com/numpy/numpy/issues/16475 - return jnp.piecewise(self, [self <= 5.0, self < 0.00001], - [small, very_small, default]) - - -@op(torch.ops.aten.special_bessel_j1) -@op_base.promote_int_input -def _aten_special_bessel_j1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2491-L2597 - - def small(x): - RP = jnp.array( - [ - -8.99971225705559398224e08, - 4.52228297998194034323e11, - -7.27494245221818276015e13, - 3.68295732863852883286e15, - ], - dtype=self.dtype, - ) - RQ = jnp.array( - [ - 6.20836478118054335476e02, - 2.56987256757748830383e05, - 8.35146791431949253037e07, - 2.21511595479792499675e10, - 4.74914122079991414898e12, - 7.84369607876235854894e14, - 8.95222336184627338078e16, - 5.32278620332680085395e18, - ], - dtype=self.dtype, - ) - - rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) - rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - - return (rp / rq * x * (x * x - 1.46819706421238932572e01) * - (x * x - 4.92184563216946036703e01)) - - def default(x): - PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, - ) - - pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i) - pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i) - qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i) - qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i) - - return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) - - 5.0 / x * - (qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) - - # If x < 0, bessel_j1(x) = -bessel_j1(-x) - sign = jnp.sign(self) - self = jnp.abs(self) - return sign * jnp.piecewise( - self, - [self <= 5.0], - [small, default], - ) - - -@op(torch.ops.aten.special_bessel_y0) -@op_base.promote_int_input -def _aten_special_bessel_y0(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2599-L2712 - - def zero(x): - return jnp.array(-jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - YP = jnp.array( - [ - 1.55924367855235737965e04, - -1.46639295903971606143e07, - 5.43526477051876500413e09, - -9.82136065717911466409e11, - 8.75906394395366999549e13, - -3.46628303384729719441e15, - 4.42733268572569800351e16, - -1.84950800436986690637e16, - ], - dtype=self.dtype, - ) - YQ = jnp.array( - [ - 1.04128353664259848412e03, - 6.26107330137134956842e05, - 2.68919633393814121987e08, - 8.64002487103935000337e10, - 2.02979612750105546709e13, - 3.17157752842975028269e15, - 2.50596256172653059228e17, - ], - dtype=self.dtype, - ) - - yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) - yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - - return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) * - _aten_special_bessel_j0(x)) - - def default(x): - PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, - ) - - factor = 25.0 / (x * x) - pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) - pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) - qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) - qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - - return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) + - 5.0 / x * - (qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) - - return jnp.piecewise( - self, - [self <= 5.0, self < 0., self == 0.], - [small, negative, zero, default], - ) - - -@op(torch.ops.aten.special_bessel_y1) -@op_base.promote_int_input -def _aten_special_bessel_y1(self): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2714-L2826 - - def zero(x): - return jnp.array(-jnp.inf, x.dtype) - - def negative(x): - return jnp.array(jnp.nan, x.dtype) - - def small(x): - YP = jnp.array( - [ - 1.26320474790178026440e09, - -6.47355876379160291031e11, - 1.14509511541823727583e14, - -8.12770255501325109621e15, - 2.02439475713594898196e17, - -7.78877196265950026825e17, - ], - dtype=self.dtype, - ) - YQ = jnp.array( - [ - 5.94301592346128195359e02, - 2.35564092943068577943e05, - 7.34811944459721705660e07, - 1.87601316108706159478e10, - 3.88231277496238566008e12, - 6.20557727146953693363e14, - 6.87141087355300489866e16, - 3.97270608116560655612e18, - ], - dtype=self.dtype, - ) - - yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) - yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - - return (x * (yp / yq) + - (0.636619772367581343075535053490057448 * - (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x))) - - def default(x): - PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, - ) - PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, - ) - QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, - ) - QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, - ) - - factor = 25.0 / (x * x) - pp = op_base.foreach_loop(PP, lambda carry, pp_i: carry * factor + pp_i) - pq = op_base.foreach_loop(PQ, lambda carry, pq_i: carry * factor + pq_i) - qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) - qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - - return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) + - 5.0 / x * - (qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) - - return jnp.piecewise( - self, - [self <= 5.0, self < 0., self == 0.], - [small, negative, zero, default], - ) - - -@op(torch.ops.aten.special_chebyshev_polynomial_t) -@op_base.promote_int_input -def _aten_special_chebyshev_polynomial_t(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2828-L2865 - - @jnp.vectorize - def vectorized(x, n_i): - - def negative_n(x): - return jnp.zeros_like(x) - - def one_x(x): - return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x), - -jnp.ones_like(x)) - - def large_n_small_x(x): - return jnp.cos(n_i * jnp.acos(x)) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return x - - def default(x): - - def f(_, carry): - p, q = carry - return (q, 2 * x * q - p) - - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x)) - return r - - return jnp.piecewise(x, [ - n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1), - jnp.abs(x) == 1., n_i < 0 - ], [one_n, zero_n, large_n_small_x, one_x, negative_n, default]) - - # Explcicitly vectorize since we must vectorizes over both self and n - return vectorized(self, n.astype(jnp.int64)) - - -@op(torch.ops.aten.special_chebyshev_polynomial_u) -@op_base.promote_int_input -def _aten_special_chebyshev_polynomial_u(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L2872-L2913 - - @jnp.vectorize - def vectorized(x, n_i): - - def negative_n(x): - return jnp.zeros_like(x) - - def one_x(x): - return jnp.where((x > 0) | (n_i % 2 == 0), n_i + 1, -(n_i + 1)) - - def large_n_small_x(x): - sin_acos_x = jnp.sin(jnp.acos(x)) - return jnp.where( - sin_acos_x != 0, - jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, - (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, - ) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return 2 * x - - def default(x): - - def f(_, carry): - p, q = carry - return (q, 2 * x * q - p) - - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, 2 * x)) - return r - - return jnp.piecewise( - x, - [ - n_i == 1, - n_i == 0, - (n_i > 8) & (jnp.abs(x) < 1), - jnp.abs(x) == 1.0, - n_i < 0, - ], - [one_n, zero_n, large_n_small_x, one_x, negative_n, default], - ) - - return vectorized(self, n.astype(jnp.int64)) - - -@op(torch.ops.aten.special_erfcx) -@op_base.promote_int_input -def _aten_special_erfcx(x): - return jnp.exp(x * x) * jax.lax.erfc(x) - - -@op(torch.ops.aten.erfc) -@op_base.promote_int_input -def _aten_erfcx(x): - return jax.lax.erfc(x) - - -@op(torch.ops.aten.special_hermite_polynomial_h) -@op_base.promote_int_input -def _aten_special_hermite_polynomial_h(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3036-L3061 - - @jnp.vectorize - def vectorized(x, n_i): - - def negative_n(x): - return jnp.zeros_like(x) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return 2 * x - - def default(x): - - def f(k, carry): - p, q = carry - return (q, 2 * x * q - 2 * k * p) - - _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x)) - return r - - return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0], - [one_n, zero_n, negative_n, default]) - - return vectorized(self, n.astype(jnp.int64)) - - -@op(torch.ops.aten.special_hermite_polynomial_he) -@op_base.promote_int_input -def _aten_special_hermite_polynomial_he(self, n): - # Adapted from https://github.com/pytorch/pytorch/blob/f8f41dcb24cb4f4e87a51bb04847942dd835e496/aten/src/ATen/native/Math.h#L3073-L3098 - - @jnp.vectorize - def vectorized(x, n_i): - - def negative_n(x): - return jnp.zeros_like(x) - - def zero_n(x): - return jnp.ones_like(x) - - def one_n(x): - return x - - def default(x): - - def f(k, carry): - p, q = carry - return (q, x * q - k * p) - - _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x)) - return r - - return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0], - [one_n, zero_n, negative_n, default]) - - return vectorized(self, n.astype(jnp.int64)) - - -@op(torch.ops.aten.multinomial, needs_env=True) -def _aten_multinomial(input, - num_samples, - replacement=False, - *, - generator=None, - out=None, - env=None): - assert num_samples <= input.shape[ - -1] or replacement, "cannot take a larger sample than population when replacement=False" - key = env.get_and_rotate_prng_key(generator) - if input.ndim == 1: - return jax.random.choice( - key, input.shape[-1], (num_samples,), replace=replacement, p=input) - else: - return jnp.array([ - jax.random.choice( - key, - input.shape[-1], (num_samples,), - replace=replacement, - p=input[i, :]) for i in range(input.shape[0]) - ]) - - -@op(torch.ops.aten.narrow) -@op(torch.ops.aten.narrow_copy) -def _aten_narrow(input, dim, start, length): - return jax.lax.dynamic_slice_in_dim(input, start, length, axis=dim) - - -@op(torch.ops.aten.flatten) -def _aten_flatten(x, start_dim=0, end_dim=-1): - """ - Flattens a JAX array (similar to torch.flatten). - - Args: - x: The JAX array to be flattened. - start_dim: The first dimension to include in the flattening. - end_dim: The last dimension to include in the flattening. - - Returns: - A flattened JAX array. - """ - shape = x.shape - - if end_dim < 0: - end_dim += len(shape) # Handle negative indexing - - new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1:]) - return jnp.reshape(x, new_shape) - - -@op(torch.ops.aten.new_empty) -def _new_empty(self, size, **kwargs): - dtype = kwargs.get('dtype') - if dtype is not None: - dtype = mappings.t2j_dtype(dtype) - else: - dtype = self.dtype - return jnp.empty(size, dtype=dtype) - - -@op(torch.ops.aten.new_empty_strided) -def _new_empty_strided(self, size, stride, dtype=None, **kwargs): - # Ignore stride, since JAX and torch tensor doesn't share the same memory. - if not dtype: - return jnp.empty(size, dtype=self.dtype) - else: - jax_dtype = mappings.t2j_dtype(dtype) - return jnp.empty(size, dtype=jax_dtype) - - -@op(torch.ops.aten._unsafe_index_put) -def _aten_unsafe_index_put(self, indices, values, accumulate=False): - return _aten_index_put(self, indices, values, accumulate) - - -@op(torch.ops.aten.conj_physical, torch.ops.aten.conj, - torch.ops.aten._conj_physical, torch.ops.aten._conj) -def _aten_conj_physical(self): - return jnp.conjugate(self) - - -@op(torch.ops.aten.log_sigmoid) -def _aten_log_sigmoid(x): - return jax.nn.log_sigmoid(x) - - -# torch.qr -@op(torch.ops.aten.qr) -def _aten_qr(input, *args, **kwargs): - jax_mode = "reduced" - # torch bool param 'simple=True' corresponds to jax 'reduced' mode, - # and simple=False corresponds to jax 'complete' mode. - if kwargs.get("simple") is False: - jax_mode = "complete" - return jax.numpy.linalg.qr(input, mode=jax_mode) - - -# torch.linalg.qr -@op(torch.ops.aten.linalg_qr) -def _aten_linalg_qr(input, *args, **kwargs): - mode = kwargs.get("mode", "reduced") - return jax.numpy.linalg.qr(input, mode=mode) - - -# torch.linalg.matrix_exp -@op(torch.ops.aten.linalg_matrix_exp) -def _aten_linalg_matrix_exp(input): - return jax.scipy.linalg.expm(input) - - -# torch._linalg.slogdet -@op(torch.ops.aten._linalg_slogdet) -def _aten__linalg_slogdet(input): - res = jnp.linalg.slogdet(input) - return res.sign, res.logabsdet - - -# torch.linalg.svd -@op(torch.ops.aten._linalg_svd) -def _aten__linalg_svd(a, full_matrices=False, **kwargs): - return jnp.linalg.svd(a, full_matrices=full_matrices, **kwargs) - - -# torch.linalg.pinv -@op(torch.ops.aten.linalg_pinv.atol_rtol_tensor) -def _aten_linalg_pinv_atol_rtol_tensor(a, rtol=None, **kwargs): - return jnp.linalg.pinv(a, rtol, hermitian=False) - - -# torch.linalg.solve -@op(torch.ops.aten._linalg_solve_ex) -def _aten__linalg_solve_ex(a, b): - batched = False - if b.ndim > 1 and b.shape[-1] == a.shape[-1]: - batched = True - b = b[..., None] - res = jnp.linalg.solve(a, b) - if batched: - res = res.squeeze(-1) - info_shape = a.shape[:-2] - info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) - return res, info - - -# torch.linalg.solve_triangular -@op(torch.ops.aten.linalg_solve_triangular) -def _aten_linalg_solve_triangular(a, - b, - *, - upper=True, - left=True, - unitriangular=False): - if left is False: - a = jnp.matrix_transpose(a) - b = jnp.matrix_transpose(b) - upper = not upper - res = jax.scipy.linalg.solve_triangular( - a, b, lower=not upper, unit_diagonal=unitriangular) - if left is False: - res = jnp.matrix_transpose(res) - return res - - -@op(torch.ops.aten.linalg_inv_ex) -def _aten_linalg_inv_ex(a): - ainv = jnp.linalg.inv(a) - info = jnp.zeros(a.shape[:-2], jnp.int32) - return ainv, info - - -@op(torch.ops.aten._linalg_check_errors) -def _aten__linalg_check_errors(*args, **kwargs): - pass - - -@op(torch.ops.aten.median) -def _aten_median(self, dim=None, keepdim=False): - output = _with_reduction_scalar( - functools.partial(jnp.quantile, q=0.5, method='lower'), - self, - dim=dim, - keepdim=keepdim).astype(self.dtype) - if dim is None: - return output - else: - index = _with_reduction_scalar(_get_median_index, self, dim, - keepdim).astype(jnp.int64) - return output, index - - -@op(torch.ops.aten.nanmedian) -def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None): - output = _with_reduction_scalar( - functools.partial(jnp.nanquantile, q=0.5, method='lower'), - input, - dim=dim, - keepdim=keepdim).astype(input.dtype) - if dim is None: - return output - else: - index = _with_reduction_scalar(_get_median_index, input, dim, - keepdim).astype(jnp.int64) - return output, index - - -def _get_median_index(x, axis=None, keepdims=False): - sorted_arg = jnp.argsort(x, axis=axis) - n = x.shape[axis] if axis is not None else x.size - if n % 2 == 1: - index = n // 2 - else: - index = (n // 2) - 1 - if axis is None: - median_index = sorted_arg[index] - else: - median_index = jnp.take(sorted_arg, index, axis=axis) - if keepdims and axis is not None: - median_index = jnp.expand_dims(median_index, axis) - return median_index - - -@op(torch.ops.aten.triangular_solve) -def _aten_triangular_solve(b, - a, - upper=True, - transpose=False, - unittriangular=False): - return (jax.lax.linalg.triangular_solve( - a, - b, - left_side=True, - lower=not upper, - transpose_a=transpose, - unit_diagonal=unittriangular), a) - - -# func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor -@op(torch.ops.aten._fft_c2c) -def _aten__fft_c2c(self, dim, normalization, forward): - if forward: - norm = [ - 'backward', - 'ortho', - 'forward', - ][normalization] - return jnp.fft.fftn(self, axes=dim, norm=norm) - else: - norm = [ - 'forward', - 'ortho', - 'backward', - ][normalization] - return jnp.fft.ifftn(self, axes=dim, norm=norm) - - -@op(torch.ops.aten._fft_r2c) -def _aten__fft_r2c(self, dim, normalization, onesided): - norm = [ - 'backward', - 'ortho', - 'forward', - ][normalization] - if onesided: - return jnp.fft.rfftn(self, axes=dim, norm=norm) - else: - return jnp.fft.fftn(self, axes=dim, norm=norm) - - -@op(torch.ops.aten._fft_c2r) -def _aten__fft_c2r(self, dim, normalization, last_dim_size): - norm = [ - 'forward', - 'ortho', - 'backward', - ][normalization] - if len(dim) == 1: - s = [last_dim_size] - else: - s = None - return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s) - - -@op(torch.ops.aten._trilinear) -def _aten_trilinear(i1, - i2, - i3, - expand1, - expand2, - expand3, - sumdim, - unroll_dim=1): - return _aten_sum( - jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) * - jnp.expand_dims(i3, expand3), sumdim) - - -@op(torch.ops.aten.max_unpool2d) -@op(torch.ops.aten.max_unpool3d) -def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): - if output_size is None: - raise ValueError( - "output_size value is not set correctly. It cannot be None or empty.") - - output_size = [input.shape[0], input.shape[1]] + output_size - output = jnp.zeros(output_size, dtype=input.dtype) - - for idx in np.ndindex(input.shape): - max_index = indices[idx] - spatial_dims = output_size[2:] # (D, H, W) - unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims) - full_idx = idx[:2] + unpooled_spatial_idx - output = output.at[full_idx].set(input[idx]) - - return output - - -def _aten_upsample(input, - output_size, - align_corners, - antialias, - method, - scale_factors=None, - scales_h=None, - scales_w=None): - # input: is of type jaxlib.xla_extension.ArrayImpl - image = input - - # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html - # Resize does not distinguish batch, channel size. - # We need to leave them as is - # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions - # pytorch image shape is (C,H,W) or (N,C,H,W) - # N - batch size - # C - no of channels - # H,W - heigth, width - - shape = list(image.shape) - # overriding output_size - if scale_factors: - shape[-1] = int(math.floor(shape[-1] * scale_factors[-1])) - shape[-2] = int(math.floor(shape[-2] * scale_factors[-2])) - if scales_h: - shape[-2] = int(math.floor(shape[-2] * scales_h)) - if scales_w: - shape[-1] = int(math.floor(shape[-1] * scales_w)) - # output_size overrides scale_factors, scales_* - if output_size: - shape[-1] = output_size[-1] - shape[-2] = output_size[-2] - - # pytorch upsample_bilinear returns the input as is when the shape is the same as input - if shape == list(image.shape): - return image - - spatial_dims = (2, 3) - if len(shape) == 3: - spatial_dims = (1, 2) - - scale = list([shape[i] / image.shape[i] for i in spatial_dims]) - if scale_factors: - scale = scale_factors - if scales_h: - scale[0] = scales_h - if scales_w: - scale[1] = scales_w - scale = jnp.array(scale) - - # align_corners is not supported in resize() - # https://github.com/jax-ml/jax/issues/11206 - if align_corners: - scale = jnp.array([ - (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims - ]) - - translation = jnp.array([0 for i in spatial_dims]) - - return jax_reimplement.scale_and_translate( - image, - shape, - method=method, - scale=scale, - spatial_dims=spatial_dims, - translation=translation, - antialias=antialias, - ) - - -@op(torch.ops.aten._upsample_bilinear2d_aa) -def _aten_upsample_billinear_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): - return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bilinear", # method - scale_factors, - scales_h, - scales_w) - - -@op(torch.ops.aten._upsample_bicubic2d_aa) -def _aten_upsample_bicubic2d_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): - return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bicubic", # method - scale_factors, - scales_h, - scales_w) - - -@op(torch.ops.aten.polar) -def _aten_polar(abs, angle, *, out=None): - return jax.lax.complex(abs * jnp.cos(angle), abs * jnp.sin(angle)) - - -@op(torch.ops.aten.cdist) -def _aten_cdist(x1, - x2, - p=2.0, - compute_mode='use_mm_for_euclid_dist_if_necessary'): - x1 = x1.astype(jnp.float32) - x2 = x2.astype(jnp.float32) - - if p == 0.0: - # For p = 0, use Hamming-like distance multiplied by the number of elements - return _hamming_distance(x1, x2).astype(jnp.float32) - elif p == 2.0: - # Use optimized Euclidean distance calculation - if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and ( - x1.shape[-2] > 25 or x2.shape[-2] > 25): - return _euclidean_mm(x1, x2) - elif compute_mode == 'use_mm_for_euclid_dist': - return _euclidean_mm(x1, x2) - else: - return _euclidean_direct(x1, x2) - else: - # General p-norm distance calculation - diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)) - return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p) - - -def _hamming_distance(x1, x2): - """ - Computes the Hamming-like distance for p=0. - - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) - - Returns: - JAX array of shape (..., P, R) representing pairwise Hamming distances. - """ - diff = jnp.not_equal(jnp.expand_dims(x1, -2), jnp.expand_dims(x2, -3)) - - hamming_dist = jnp.sum(diff, axis=-1).astype(jnp.float32) - - return hamming_dist - - -def _euclidean_mm(x1, x2): - """ - Computes the Euclidean distance using matrix multiplication. - - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) - - Returns: - JAX array of shape (..., P, R) representing pairwise Euclidean distances. - """ - x1_sq = jnp.sum(x1**2, axis=-1, keepdims=True).astype(jnp.float32) - x2_sq = jnp.sum(x2**2, axis=-1, keepdims=True).astype(jnp.float32) - - x2_sq = jnp.swapaxes(x2_sq, -2, -1) - - dot_product = jnp.matmul(x1, jnp.swapaxes(x2, -1, -2)) - - dist_sq = x1_sq + x2_sq - 2 * dot_product - dist_sq = jnp.maximum(dist_sq, 0.0) - dist = jnp.sqrt(dist_sq).astype(jnp.float32) - - return dist - - -def _euclidean_direct(x1, x2): - """ - Computes the Euclidean distance directly without matrix multiplication. - - Args: - x1: JAX array of shape (..., P, M) - x2: JAX array of shape (..., R, M) - - Returns: - JAX array of shape (..., P, R) representing pairwise Euclidean distances. - """ - diff = jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3) - - dist_sq = jnp.sum(diff**2, axis=-1).astype(jnp.float32) - - dist_sq = jnp.maximum(dist_sq, 0.0) - - dist = jnp.sqrt(dist_sq).astype(jnp.float32) - - return dist - - -@op(torch.ops.aten.lu_unpack) -def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): - # lu_unpack doesnt exist in jax. - # Get commonly used data shape variables - n = LU_data.shape[-2] - m = LU_data.shape[-1] - dim = min(n, m) - - ### Compute the Lower and Upper triangle - if unpack_data: - # Extract lower triangle - L = jnp.tril(LU_data, k=-1) - - #emulate pytorch behavior: Add ones to the diagonal of L - eye = jnp.eye(n, m, dtype=LU_data.dtype) - L = L + eye - - # emulate pytorch behavior: Reshape lower triangle to match pivot - start_indices = jnp.zeros(len(LU_data.shape), dtype=int) - limit_indices = list(LU_data.shape) - limit_indices[-1] = dim - L = jax.lax.slice(L, start_indices, limit_indices) - - # Extract upper triangle - U = jnp.triu(LU_data) - - # emulate pytorch behavior: Reshape upper triangle to match pivot - start_indices = jnp.zeros(len(LU_data.shape), dtype=int) - limit_indices = list(LU_data.shape) - limit_indices[-2] = dim - U = jax.lax.slice(U, start_indices, limit_indices) - else: - # emulate pytroch behavior: return empty tensors - L = torch.empty(torch.Size([0])) - U = torch.empty(torch.Size([0])) - - ### Compute the Permutation matrix - if unpack_pivots: - # We should return a permutation matrix (2D) for each pivot array (1D) - # The shape of the final Permutation matrix depends on the shape of the input - # data and the pivots - - # start with a 2D identity matrix and tile it to the other dims of input data - identity2d = jnp.identity(n, dtype=jnp.float32) - tile_shape = list(LU_data.shape) - tile_shape[-1] = 1 - tile_shape[-2] = 1 - P = jnp.tile(identity2d, tile_shape) - - # closure to be called for each input 2D matrix. - def _lu_unpack_2d(p, pivot): - _pivot = pivot - 1 # pivots are offset by 1 in jax - indices = jnp.array([*range(n)], dtype=jnp.int32) - - def update_indices(i, _indices): - tmp = _indices[i] - _indices = _indices.at[i].set(_indices[_pivot[i]]) - _indices = _indices.at[_pivot[i]].set(tmp) - return _indices - - indices = jax.lax.fori_loop(0, _pivot.size, update_indices, indices) - p = p[jnp.array(indices)] - p = jnp.transpose(p) - return p - - if len(LU_pivots.shape) == 1: - # if we are dealing with a simple 2D input and 1D pivot, call the closure directly - P = _lu_unpack_2d(P, LU_pivots) - else: - # We are dealing with >=3D inputs. Flatten inputs to 3D and use vmap to call the - # closure for each 2D matrix. Finally unflatten the result to match the input data - # shape. - - # reshape permutation matrix to 3d - dim_size = jnp.prod(jnp.array(P.shape[:-2])) - newPshape = (dim_size, P.shape[-2], P.shape[-1]) - reshapedP = P.reshape(newPshape) - - # reshape pivots to 3d - dim_size = jnp.prod(jnp.array(LU_pivots.shape[:-1])) - newPivotshape = (dim_size, LU_pivots.shape[-1]) - reshapedPivot = LU_pivots.reshape(newPivotshape) - - # vmap the reshaped 3d tensors - v_lu_unpack_2d = jax.vmap(_lu_unpack_2d, in_axes=(0, 0)) - unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) - - # reshape result back to P's shape - newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1]) - P = unpackedP.reshape(newRetshape) - else: - # emulate pytroch behavior: return empty tensors - P = torch.empty(torch.Size([0])) - - return P, L, U - - -@op(torch.ops.aten.linear) -def linear(input, weight, bias=None): - res = input @ jnp.transpose(weight) - if bias is not None: - res += bias - return res - - -@op(torch.ops.aten.kthvalue) -def kthvalue(input, k, dim=None, keepdim=False, *, out=None): - if input.ndim == 0: - return input, jnp.array(0) - dimension = -1 - if dim is not None: - dimension = dim - while dimension < 0: - dimension = dimension + input.ndim - values = jax.lax.index_in_dim( - jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim) - indices = jax.lax.index_in_dim( - jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1, - dimension, keepdim) - return values, indices - - -@op(torch.ops.aten.take) -def _aten_take(self, index): - return self.flatten()[index] - - -# func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor -@op(torch.ops.aten.pad) -def _aten_pad(self, pad, mode='constant', value=None): - if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0: - raise ValueError("Padding must be a sequence of even length.") - - num_dims = self.ndim - if len(pad) > 2 * num_dims: - raise ValueError( - f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." - ) - - # JAX's pad function expects padding for each dimension as a tuple of (low, high) - # We need to reverse the pad sequence and group them for JAX. - # pad = [p_l0, p_r0, p_l1, p_r1, ...] - # becomes ((..., ..., (p_l1, p_r1), (p_l0, p_r0))) - jax_pad_width = [] - # Iterate in reverse pairs - for i in range(len(pad) // 2): - jax_pad_width.append((pad[(2 * i)], pad[(2 * i + 1)])) - - # Pad any leading dimensions with (0, 0) if the pad sequence is shorter - # than the number of dimensions. - for _ in range(num_dims - len(pad) // 2): - jax_pad_width.append((0, 0)) - - # Reverse the jax_pad_width list to match the dimension order - jax_pad_width.reverse() - - if mode == "constant": - if value is None: - value = 0.0 - return jnp.pad( - self, pad_width=jax_pad_width, mode="constant", constant_values=value) - elif mode == "reflect": - return jnp.pad(self, pad_width=jax_pad_width, mode="reflect") - elif mode == "edge": - return jnp.pad(self, pad_width=jax_pad_width, mode="edge") - else: - raise ValueError( - f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." - ) - - -@op(torch.ops.aten.is_nonzero) -def _aten_is_nonzero(a): - a = jnp.squeeze(a) - if a.shape == (0,): - raise RuntimeError('bool value of Tensor with no values is ambiguous') - if a.ndim != 0: - raise RuntimeError( - 'bool value of Tensor with more than one value is ambiguous') - return a.item() != 0 - - -@op(torch.ops.aten.logit) -def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array: - """ - Computes the logit function of the input tensor. - - logit(p) = log(p / (1 - p)) - - Args: - self: Input tensor. - eps: A small value to clip the input tensor to avoid log(0) or division by zero. - If None, no clipping is performed. - - Returns: - A tensor with the logit of each element of the input. - """ - if eps is not None: - self = jnp.clip(self, eps, 1.0 - eps) - res = jnp.log(self / (1.0 - self)) - res = res.astype(mappings.t2j_dtype(torch.get_default_dtype())) - return res - - -@op(torch.ops.aten.floor_divide) -def _aten_floor_divide(x, y): - res = jnp.floor_divide(x, y) - return res - - -@op(torch.ops.aten._assert_tensor_metadata) -@op(torch.ops.aten._assert_scalar) -def _aten__assert_tensor_metadata(*args, **kwargs): - pass - - -mutation_ops_to_functional = { - torch.ops.aten.add_: - op_base.InplaceOp(torch.ops.aten.add), - torch.ops.aten.sub_: - op_base.InplaceOp(torch.ops.aten.sub), - torch.ops.aten.mul_: - op_base.InplaceOp(torch.ops.aten.mul), - torch.ops.aten.div_: - op_base.InplaceOp(torch.ops.aten.div), - torch.ops.aten.pow_: - op_base.InplaceOp(torch.ops.aten.pow), - torch.ops.aten.lt_: - op_base.InplaceOp(torch.ops.aten.lt), - torch.ops.aten.le_: - op_base.InplaceOp(torch.ops.aten.le), - torch.ops.aten.gt_: - op_base.InplaceOp(torch.ops.aten.gt), - torch.ops.aten.ge_: - op_base.InplaceOp(torch.ops.aten.ge), - torch.ops.aten.eq_: - op_base.InplaceOp(torch.ops.aten.eq), - torch.ops.aten.ne_: - op_base.InplaceOp(torch.ops.aten.ne), - torch.ops.aten.bernoulli_: - op_base.InplaceOp(torch.ops.aten.bernoulli.p), - torch.ops.aten.bernoulli_.float: - op_base.InplaceOp(_aten_bernoulli, is_jax_func=True), - torch.ops.aten.geometric_: - op_base.InplaceOp(torch.ops.aten.geometric), - torch.ops.aten.normal_: - op_base.InplaceOp(torch.ops.aten.normal), - torch.ops.aten.random_: - op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.uniform_: - op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.relu_: - op_base.InplaceOp(torch.ops.aten.relu), - # squeeze_ is expected to change tensor's shape. So replace with new value - torch.ops.aten.squeeze_: - op_base.InplaceOp(torch.ops.aten.squeeze, True), - torch.ops.aten.sqrt_: - op_base.InplaceOp(torch.ops.aten.sqrt), - torch.ops.aten.clamp_: - op_base.InplaceOp(torch.ops.aten.clamp), - torch.ops.aten.clamp_min_: - op_base.InplaceOp(torch.ops.aten.clamp_min), - torch.ops.aten.sigmoid_: - op_base.InplaceOp(torch.ops.aten.sigmoid), - torch.ops.aten.tanh_: - op_base.InplaceOp(torch.ops.aten.tanh), - torch.ops.aten.ceil_: - op_base.InplaceOp(torch.ops.aten.ceil), - torch.ops.aten.logical_not_: - op_base.InplaceOp(torch.ops.aten.logical_not), - torch.ops.aten.unsqueeze_: - op_base.InplaceOp(torch.ops.aten.unsqueeze), - torch.ops.aten.transpose_: - op_base.InplaceOp(torch.ops.aten.transpose), - torch.ops.aten.log_normal_: - op_base.InplaceOp(torch.ops.aten.log_normal), - torch.ops.aten.scatter_add_: - op_base.InplaceOp(torch.ops.aten.scatter_add), - torch.ops.aten.scatter_reduce_.two: - op_base.InplaceOp(torch.ops.aten.scatter_reduce), - torch.ops.aten.scatter_: - op_base.InplaceOp(torch.ops.aten.scatter), - torch.ops.aten.bitwise_or_: - op_base.InplaceOp(torch.ops.aten.bitwise_or), - torch.ops.aten.floor_divide_: - op_base.InplaceOp(torch.ops.aten.floor_divide), - torch.ops.aten.remainder_: - op_base.InplaceOp(torch.ops.aten.remainder), - torch.ops.aten.index_put_: - op_base.InplaceOp(torch.ops.aten.index_put), -} - -# Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`. -_jax_version = tuple(int(v) for v in jax.version._version.split(".")) - -mutation_needs_env = { - torch.ops.aten.bernoulli_, - torch.ops.aten.bernoulli_.float, -} - -for operator, mutation in mutation_ops_to_functional.items(): - ops_registry.register_torch_dispatch_op( - operator, - mutation, - is_jax_function=False, - is_view_op=True, - needs_env=(operator in mutation_needs_env)) diff --git a/torchax/torchax/ops/jax_reimplement.py b/torchax/torchax/ops/jax_reimplement.py deleted file mode 100644 index d9acc3be51ab..000000000000 --- a/torchax/torchax/ops/jax_reimplement.py +++ /dev/null @@ -1,171 +0,0 @@ -from collections.abc import Sequence -from jax._src.numpy.util import promote_dtypes_inexact -import numpy as np -import jax -from jax import numpy as jnp -from jax._src.util import canonicalize_axis -from jax._src import core -from jax._src.image.scale import _kernels, ResizeMethod -from jax import lax -from typing import Callable - -# TODO: This block of code needs to be revisited based on https://github.com/jax-ml/jax/issues/24106 -# START ----------------- JAX code copied for fixing scale_and_translate ----------------------------- - -# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52 - - -def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize, - scale, translation, kernel: Callable, antialias: bool): - dtype = jnp.result_type(scale, translation) - inv_scale = 1. / scale - # When downsampling the kernel should be scaled since we want to low pass - # filter and interpolate, but when upsampling it should not be since we only - # want to interpolate. - kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1. - sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale - - translation * inv_scale - 0.5) - x = ( - jnp.abs(sample_f[jnp.newaxis, :] - - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) / - kernel_scale) - weights = kernel(x) - - total_weight_sum = jnp.sum(weights, axis=0, keepdims=True) - weights = jnp.where( - jnp.abs(total_weight_sum) > 1000. * float(np.finfo(np.float32).eps), - jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, - 1)), 0) - # Zero out weights where the sample location is completely outside the input - # range. - # Note sample_f has already had the 0.5 removed, hence the weird range below. - - # (barney-s) -------------- returning weights without zeroing --------------------- - return weights - input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 - return jnp.where( - jnp.logical_and(sample_f >= -0.5, sample_f - <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0) - # (barney-s) -------------- END returning weights without zeroing --------------------- - - -# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86 - - -def _scale_and_translate(x, output_shape: core.Shape, - spatial_dims: Sequence[int], scale, translation, - kernel, antialias: bool, precision): - input_shape = x.shape - assert len(input_shape) == len(output_shape) - assert len(spatial_dims) == len(scale) - assert len(spatial_dims) == len(translation) - if len(spatial_dims) == 0: - return x - contractions = [] - in_indices = list(range(len(output_shape))) - out_indices = list(range(len(output_shape))) - for i, d in enumerate(spatial_dims): - d = canonicalize_axis(d, x.ndim) - m = input_shape[d] - n = output_shape[d] - w = compute_weight_mat(m, n, scale[i], translation[i], kernel, - antialias).astype(x.dtype) - contractions.append(w) - contractions.append([d, len(output_shape) + i]) - out_indices[d] = len(output_shape) + i - contractions.append(out_indices) - return jnp.einsum(x, in_indices, *contractions, precision=precision) - - -# JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L172 - - -# scale and translation here are scalar elements of an np.array, what is the -# correct type annotation? -def scale_and_translate( - image, - shape: core.Shape, - spatial_dims: Sequence[int], - scale, - translation, - # (barney-s) use string - method: str, #(barney-s) | ResizeMethod, - antialias: bool = True, - precision=lax.Precision.HIGHEST): - """Apply a scale and translation to an image. - - Generates a new image of shape 'shape' by resampling from the input image - using the sampling method corresponding to method. For 2D images, this - operation transforms a location in the input images, (x, y), to a location - in the output image according to:: - - (x * scale[1] + translation[1], y * scale[0] + translation[0]) - - (Note the *inverse* warp is used to generate the sample locations.) - Assumes half-centered pixels, i.e the pixel at integer location ``row, col`` - has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input - image dimensions. - - If an output location(pixel) maps to an input sample location that is outside - the input boundaries then the value for the output location will be set to - zero. - - The ``method`` argument expects one of the following resize methods: - - ``ResizeMethod.LINEAR``, ``"linear"``, ``"bilinear"``, ``"trilinear"``, - ``"triangle"`` `Linear interpolation`_. If ``antialias`` is ``True``, uses a - triangular filter when downsampling. - - ``ResizeMethod.CUBIC``, ``"cubic"``, ``"bicubic"``, ``"tricubic"`` - `Cubic interpolation`_, using the Keys cubic kernel. - - ``ResizeMethod.LANCZOS3``, ``"lanczos3"`` - `Lanczos resampling`_, using a kernel of radius 3. - - ``ResizeMethod.LANCZOS5``, ``"lanczos5"`` - `Lanczos resampling`_, using a kernel of radius 5. - - .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation - .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation - .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling - - Args: - image: a JAX array. - shape: the output shape, as a sequence of integers with length equal to the - number of dimensions of `image`. - spatial_dims: A length K tuple specifying the spatial dimensions that the - passed scale and translation should be applied to. - scale: A [K] array with the same number of dimensions as image, containing - the scale to apply in each dimension. - translation: A [K] array with the same number of dimensions as image, - containing the translation to apply in each dimension. - method: the resizing method to use; either a ``ResizeMethod`` instance or a - string. Available methods are: LINEAR, LANCZOS3, LANCZOS5, CUBIC. - antialias: Should an antialiasing filter be used when downsampling? Defaults - to ``True``. Has no effect when upsampling. - - Returns: - The scale and translated image. - """ - shape = core.canonicalize_shape(shape) - if len(shape) != image.ndim: - msg = ('shape must have length equal to the number of dimensions of x; ' - f' {shape} vs {image.shape}') - raise ValueError(msg) - if isinstance(method, str): - method = ResizeMethod.from_string(method) - if method == ResizeMethod.NEAREST: - # Nearest neighbor is currently special-cased for straight resize, so skip - # for now. - raise ValueError('Nearest neighbor resampling is not currently supported ' - 'for scale_and_translate.') - assert isinstance(method, ResizeMethod) - - kernel = _kernels[method] - image, = promote_dtypes_inexact(image) - scale, translation = promote_dtypes_inexact(scale, translation) - return _scale_and_translate(image, shape, spatial_dims, scale, translation, - kernel, antialias, precision) - - -# END ----------------- END JAX code copied for testing ----------------------------- diff --git a/torchax/torchax/ops/jc10d.py b/torchax/torchax/ops/jc10d.py deleted file mode 100644 index 79544943f918..000000000000 --- a/torchax/torchax/ops/jc10d.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import jax -import jax.numpy as jnp - -from torchax.ops import ops_registry - - -def op(*aten, **kwargs): - - def inner(func): - for a in aten: - ops_registry.register_torch_dispatch_op(a, func, **kwargs) - return func - - return inner - - -@op(torch.ops._c10d_functional.all_gather_into_tensor) -def _c10d_all_gather(input, group_size: int, group_name: str): - return jax.lax.all_gather(input, "torch_dist") - - -@op(torch.ops._c10d_functional.all_reduce) -def _c10d_all_reduce(self, reduceOp: str, group_name: str): - - if reduceOp == "sum": - res = jax.lax.psum(self, axis_name="torch_dist") - elif reduceOp == "avg": - res = jax.lax.pmean(self, axis_name="torch_dist") - elif reduceOp == "min": - res = jax.lax.pmin(self, axis_name="torch_dist") - elif reduceOp == "max": - res = jax.lax.pmax(self, axis_name="torch_dist") - else: - raise RuntimeError(f"Reduce op {reduceOp} not implemented") - return res - - -@op(torch.ops._c10d_functional.broadcast) -def _c10d_broadcast(self, src: int, group_name: str): - masked = jnp.where( - jax.lax.axis_index("torch_dist") == src, - self, - jnp.zeros_like(self), - ) - return jax.lax.psum(masked, "torch_dist") - - -@op(torch.ops._c10d_functional.wait_tensor) -def _c10d_wait_tensor(tensor): - # Async tensor is aleady `wait`ed by dispatcher - return tensor diff --git a/torchax/torchax/ops/jimage.py b/torchax/torchax/ops/jimage.py deleted file mode 100644 index 947be0a5e3f0..000000000000 --- a/torchax/torchax/ops/jimage.py +++ /dev/null @@ -1,113 +0,0 @@ -import jax -import jax.numpy as jnp - - -def cubic_kernel(x, a=-0.75): - """Cubic kernel with a = -0.75 (PyTorch-like Keys kernel)""" - absx = jnp.abs(x) - x2 = absx * absx - x3 = x2 * absx - cond1 = (absx <= 1) - cond2 = (absx > 1) & (absx < 2) - f1 = (a + 2) * x3 - (a + 3) * x2 + 1 - f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a - return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) - - -def compute_contribs(in_size, - out_size, - scale, - support=2.0, - align_corners=False, - dtype=None): - if align_corners: - if out_size == 1: - in_coords = jnp.zeros((1,), dtype=dtype) - else: - in_coords = jnp.linspace(0, in_size - 1, out_size, dtype=dtype) - else: - out_coords = jnp.arange(out_size, dtype=dtype) + 0.5 - in_coords = out_coords / scale - 0.5 - - left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1 - idxs = left_idx[:, None] + jnp.arange(4) - - dx = in_coords[:, None] - idxs - - weights = cubic_kernel(dx) - - weights = weights / jnp.sum(weights, axis=1, keepdims=True) - return idxs, weights - - -def gather_weights(img, idxs, axis): - """Safely gather with boundary handling""" - idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) - return jnp.take(img, idxs, axis=axis) - - -def interpolate_along_axis_bchw(img, idxs, weights, axis): - """ - Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). - idxs: (out_size, 4) int32 indices - weights: (out_size, 4) float32 weights - """ - assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" - out_size = idxs.shape[0] - k = idxs.shape[1] # Typically 4 for cubic - - # Clip to input bounds - idxs = jnp.clip(idxs, 0, img.shape[axis] - 1) # (out_size, 4) - - def gather_and_weight(i): - idx = idxs[i] # (4,) - w = weights[i] # (4,) - - def gather_one(offset): - return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) - - gathered = jnp.stack([gather_one(o) for o in range(k)], - axis=0) # (4, B, C, H, W) - weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) - return weighted - - out = jax.vmap(gather_and_weight)( - jnp.arange(out_size)) # (out_size, B, C, H, W) - - # Move the interpolated axis back into place - if axis == 2: # interpolated over H - return jnp.moveaxis(out, 0, 2) # (B, C, out_H, W) - else: # axis == 3, interpolated over W - return jnp.moveaxis(out, 0, 3) # (B, C, H, out_W) - - -def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False): - h, w = img.shape[-2:] - if align_corners and out_h > 1: - scale_y = (h - 1) / (out_h - 1) - else: - scale_y = out_h / h - - if align_corners and out_w > 1: - scale_x = (w - 1) / (out_w - 1) - else: - scale_x = out_w / w - - idxs_y, weights_y = compute_contribs( - h, - out_h, - scale_y, - align_corners=align_corners, - dtype=img.dtype, - ) - tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) - - idxs_x, weights_x = compute_contribs( - w, - out_w, - scale_x, - align_corners=align_corners, - dtype=img.dtype, - ) - out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) - return out diff --git a/torchax/torchax/ops/jlibrary.py b/torchax/torchax/ops/jlibrary.py deleted file mode 100644 index 17cdb161c3c3..000000000000 --- a/torchax/torchax/ops/jlibrary.py +++ /dev/null @@ -1,80 +0,0 @@ -"""The `jlibrary` module has functions which help to preserve torch.library ops -during export. This includes aten ops, and custom operations. -""" - -import torch -import torch.nn as nn -import torchax -from torchax.ops import jaten -import jax -import functools - - -def _jit_composite_impl(composite_name, jaxpr_impl, **jit_args): - """Wrap a jaxpr in a jitted function with the proper composite name - TODO: Wrap JIT in a `stablehlo.composite` op, instead of generating a call op. - """ - - def composite_impl(*args): - return jaxpr_impl(*args) - - composite_impl.__name__ = composite_name - composite_impl.__qualname__ = composite_name - return jax.jit(composite_impl, **jit_args) - - -def register_jax_composite(composite_name, impl, *ops, **jit_args): - """Register a composite using a JAX implementation. - composite_name - The name of the library op to use in the exported composite - impl - A JAX lowering for the library operation - *ops - Variadic torch.ops to lower using `impl`. - **jit_args - Additional parameters to forward to JAX jit. - - This is used to register custom lowerings with an explicit jaxpr - implementation, such as preserving a specific aten op using a jaten impl. - - For custom torch op registration with a decomposition written in torch, - use `register_torch_composite`. - - For jit params and troubleshooting see: - https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html - """ - - @jaten.op(*ops) - def _composite_impl(*args): - return _jit_composite_impl(composite_name, impl, **jit_args)(*args) - - -def register_torch_composite(composite_name, impl, *ops, **jit_args): - """Register a torch decomposition as a composite. - This is useful for registerring custom torch op libraries as composite ops. - - The `impl` can be the `@impl` used to define the torch custom library op. - This must be a function or module impl that provides the decompositions, and - not an instance of the custom op. - - TODO: Better error handling, or can we make this an instance of the op as a param? - - For jit params and troubleshooting see: - https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html - """ - - @jaten.op(*ops) - def _composite_impl(*args): - - class ImplWrapper(torch.nn.Module): - - def __init__(self): - super().__init__() - - def forward(self, *args): - return impl(*args) - - # Note: avoid refactoring to share code with register_jaxpr_composite. - # The `extract_jax` call must live in the `@jaten.op` handler. If called - # outside of the handler, we would build the jaxpr representation of the - # module once during registration, potentially missing op registrations that - # come after. I.e. may miss nested abstractions if we build jaxpr AoT. - state, jfn = torchax.extract_jax(ImplWrapper()) - jaxpr_impl = lambda *args: jfn(state, tuple([*args])) - return _jit_composite_impl(composite_name, jaxpr_impl, **jit_args)(*args) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py deleted file mode 100644 index ac2042a7511e..000000000000 --- a/torchax/torchax/ops/jtorch.py +++ /dev/null @@ -1,576 +0,0 @@ -"""Tensor constructor overrides""" - -import math -import collections.abc -import functools -from typing import Optional, Sequence, Tuple -import numpy as np - -import jax -import jax.numpy as jnp -from jax.experimental.pallas.ops.tpu import flash_attention -from jax.experimental.shard_map import shard_map - -import torch -from torchax.ops.ops_registry import register_torch_function_op -from torchax.ops import op_base, mappings, jaten, jimage -import torchax.tensor -from torchax.view import View, NarrowInfo -import torch.utils._pytree as pytree - - -def register_function(torch_func, **kwargs): - return functools.partial(register_torch_function_op, torch_func, **kwargs) - - -@register_function(torch.as_tensor, is_jax_function=False, needs_env=True) -@op_base.convert_dtype( - use_default_dtype=False) # Attempt to infer type from elements -def _as_tensor(data, dtype=None, device=None, env=None): - if isinstance(data, torch.Tensor): - return env._to_copy(data, dtype, device) - if isinstance(data, np.ndarray): - jax_res = jnp.asarray(data) - else: - jax_res = _tensor(data, dtype=dtype) - return torchax.tensor.Tensor(jax_res, env) - - -@register_function(torch.tensor) -@op_base.convert_dtype( - use_default_dtype=False) # Attempt to infer type from elements -def _tensor(data, *, dtype=None, **kwargs): - python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, - } - if not dtype: - leaves = jax.tree_util.tree_leaves(data) - if len(leaves) > 0: - dtype = python_types_to_torch_types.get(type(leaves[0])) - - return jnp.array( - data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())) - - -@register_function(torch.allclose) -def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.allclose(input, other, rtol, atol, equal_nan) - - -@register_function(torch.angle) -def _torch_angle(input): - if input.dtype.name == "int64": - input = input.astype(jnp.dtype("float32")) - return jnp.angle(input) - - -@register_function(torch.argsort) -def _torch_argsort(input, dim=-1, descending=False, stable=False): - expanded = False - if input.ndim == 0: - # for self of rank 0: - # torch.any(x, 0), torch.any(x, -1) works; - # torch.any(x, 1) throws out of bounds, so it's - # behavior is the same as a jnp array of rank 1 - expanded = True - input = jnp.expand_dims(input, 0) - res = jnp.argsort(input, axis=dim, descending=descending, stable=stable) - if expanded: - res = res.squeeze() - return res - - -@register_function(torch.diag) -def _diag(input, diagonal=0): - return jnp.diag(input, k=diagonal) - - -@register_function(torch.einsum) -@register_function(torch.ops.aten.einsum) -def _einsum(equation, *operands): - - def get_params(*a): - inner_list = a[0] - if not isinstance(inner_list, jax.Array): - if len(inner_list) == 1: - A = inner_list - return A - elif len(inner_list) == 2: - A, B = inner_list - return A, B - return operands - - assert isinstance(equation, str), "Only accept str equation" - filtered_operands = get_params(*operands) - return jnp.einsum(equation, *filtered_operands) - - -def _sdpa_reference( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, -) -> torch.Tensor: - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - if is_causal: - assert attn_mask is None - temp_mask = torch.ones( - L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) - else: - attn_bias += attn_mask - if enable_gqa: - key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) - value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) - - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p > 0: - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - return attn_weight @ value - - -from jax.sharding import PartitionSpec - - -def _tpu_flash_attention(query, key, value, env): - fsdp_partition = PartitionSpec("fsdp") - - def wrap_flash_attention(query, key, value): - block_sizes = flash_attention.BlockSizes( - block_b=min(2, query.shape[0]), - block_q=min(512, query.shape[2]), - block_k_major=min(512, key.shape[2]), - block_k=min(512, key.shape[2]), - block_q_major_dkv=min(512, query.shape[2]), - block_k_major_dkv=min(512, key.shape[2]), - block_k_dkv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_k_major_dq=min(512, key.shape[2]), - block_k_dq=min(256, key.shape[2]), - block_q_dq=min(1024, query.shape[2]), - ) - return flash_attention.flash_attention( - query, key, value, causal=True, block_sizes=block_sizes) - - if env.config.shmap_flash_attention: - wrap_flash_attention = shard_map( - wrap_flash_attention, - mesh=env._mesh, - in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), - out_specs=fsdp_partition, - check_rep=False, - ) - # return flash_attn_mapped(query, key, value) - return wrap_flash_attention(query, key, value) - - -@register_function(torch.nn.functional.one_hot) -def one_hot(tensor, num_classes=-1): - if num_classes == -1: - num_classes = jnp.max(tensor) + 1 - return jax.nn.one_hot(tensor, num_classes).astype(jnp.int64) - - -@register_function(torch.nn.functional.pad) -def pad(tensor, pad, mode="constant", value=None): - # For padding modes that have different names between Torch and NumPy, this - # dict provides a Torch-to-NumPy translation. Any string not in this dict will - # be passed through as-is. - MODE_NAME_TRANSLATION = { - "circular": "wrap", - "replicate": "edge", - } - - numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode) - - num_prefix_dims = tensor.ndim - len(pad) // 2 - - numpy_pad_width = [(0, 0)] * num_prefix_dims - nd_slice = [slice(None)] * num_prefix_dims - - for i in range(len(pad) - 2, -1, -2): - pad_start, pad_end = pad[i:i + 2] - slice_start, slice_end = None, None - - if pad_start < 0: - slice_start = -pad_start - pad_start = 0 - - if pad_end < 0: - slice_end = pad_end - pad_end = 0 - - numpy_pad_width.append((pad_start, pad_end)) - nd_slice.append(slice(slice_start, slice_end)) - - nd_slice = tuple(nd_slice) - - # `jax.numpy.pad` complains if we provide an irrelevant `constant_values` arg, - # even if the value we pass in is `None`. (It treats `None` as `nan`.) - kwargs = dict() - if mode == "constant" and value is not None: - kwargs["constant_values"] = value - - # The "replicate" mode pads first and then slices, whereas the "circular" mode - # slices first and then pads. The latter approach deals with smaller tensors, - # so we default to that option in modes where the order of operations doesn't - # affect the result. - if mode == "replicate": - return jnp.pad(tensor, numpy_pad_width, mode=numpy_mode, **kwargs)[nd_slice] - else: - return jnp.pad(tensor[nd_slice], numpy_pad_width, mode=numpy_mode, **kwargs) - - -@register_function( - torch.nn.functional.scaled_dot_product_attention, - is_jax_function=False, - needs_env=True, -) -@register_function( - torch.ops.aten.scaled_dot_product_attention, - is_jax_function=False, - needs_env=True) -def scaled_dot_product_attention( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, - env=None, -) -> torch.Tensor: - - if env.config.use_tpu_flash_attention: - jquery, jkey, jvalue = env.t2j_iso((query, key, value)) - res = _tpu_flash_attention(jquery, jkey, jvalue, env) - return env.j2t_iso(res) - - return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, - scale, enable_gqa) - - -@register_function( - torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True) -def getitem(self, indexes): - - if isinstance(indexes, list) and isinstance(indexes[0], int): - # list of int, i.e. x[[1, 2]] NOT x[1, 2] (the second would be tuple of int) - indexes = (indexes,) - elif isinstance(indexes, list): - indexes = tuple(indexes) - - def is_narrow_slicing(): - tensor_free = not pytree.tree_any( - lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), - indexes) - list_free = not isinstance(indexes, tuple) or all( - [False if isinstance(x, list) else True for x in indexes]) - return tensor_free and list_free - - if is_narrow_slicing(): - return View(self, view_info=NarrowInfo(indexes), env=self._env) - - indexes = self._env.t2j_iso(indexes) - return torchax.tensor.Tensor(self._elem[indexes], self._env) - - -@register_function(torch.corrcoef) -def _corrcoef(x): - if x.dtype.name == "int64": - return jnp.corrcoef(x).astype(jnp.float32) - return jnp.corrcoef(x) - - -@register_function(torch.sparse.mm, is_jax_function=False) -def _sparse_mm(mat1, mat2, reduce="sum"): - return torch.mm(mat1, mat2) - - -@register_function(torch.isclose) -def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): - return jnp.isclose(input, other, rtol, atol, equal_nan) - - -@register_function(torch.linalg.det) -def linalg_det(input): - return jnp.linalg.det(input) - - -@register_function(torch.ones) -def _ones(*size: int, dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._ones(size, dtype=dtype) - - -@register_function(torch.zeros, is_jax_function=True) -def _zeros(*size: int, dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._zeros(size, dtype=dtype) - - -@register_function(torch.eye) -@op_base.convert_dtype() -def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): - return jnp.eye(n, m, dtype=dtype) - - -@register_function(torch.full) -@op_base.convert_dtype() -def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): - # TODO: handle torch.Size - return jnp.full(size, fill_value, dtype=dtype) - - -@register_function(torch.empty) -@op_base.convert_dtype() -def empty(*size: Sequence[int], dtype=None, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jnp.empty(size, dtype=dtype) - - -@register_function(torch.arange, is_jax_function=True) -def arange( - start, - end=None, - step=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=None, -): - if end is None: - end = start - start = 0 - if step is None: - step = 1 - return jaten._aten_arange(start, end, step, dtype=dtype) - - -@register_function(torch.empty_strided, is_jax_function=True) -def empty_strided( - size, - stride, - *, - dtype=None, - layout=None, - device=None, - requires_grad=False, - pin_memory=False, -): - return empty(size, dtype=dtype, requires_grad=requires_grad) - - -@register_function(torch.unravel_index) -def unravel_index(indices, shape): - return jnp.unravel_index(indices, shape) - - -@register_function(torch.rand, is_jax_function=True, needs_env=True) -def rand(*size, **kwargs): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._rand(size, **kwargs) - - -@register_function(torch.randn, is_jax_function=True, needs_env=True) -def randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, -): - if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): - size = size[0] - return jaten._aten_randn(size, generator=generator, dtype=dtype, env=env) - - -@register_function(torch.randint, is_jax_function=False, needs_env=True) -def randint(*args, **kwargs): - return jaten._aten_randint(*args, **kwargs) - - -@register_function(torch.logdet) -def logdet(input): - _, logabsdet = jaten._aten__linalg_slogdet(input) - return logabsdet - - -@register_function(torch.linalg.slogdet) -def linalg_slogdet(input): - sign, logabsdet = jaten._aten__linalg_slogdet(input) - return torch.return_types.slogdet((sign, logabsdet)) - - -@register_function(torch.tensor_split) -def tensor_split(input, indices_or_sections, dim=0): - return jnp.array_split(input, indices_or_sections, axis=dim) - - -@register_function(torch.linalg.solve) -def linalg_solve(a, b): - res, _ = jaten._aten__linalg_solve_ex(a, b) - return res - - -@register_function(torch.linalg.solve_ex) -def linalg_solve_ex(a, b): - res, info = jaten._aten__linalg_solve_ex(a, b) - return res, info - - -@register_function(torch.linalg.svd) -def linalg_svd(a, full_matrices=True): - return jaten._aten__linalg_svd(a, full_matrices=full_matrices) - - -@register_function(torch.linalg.matrix_power) -def matrix_power(A, n, *, out=None): - return jnp.linalg.matrix_power(A, n) - - -@register_function(torch.svd) -def svd(a, some=True, compute_uv=True): - if not compute_uv: - S = jaten._aten__linalg_svd(a, full_matrices=False)[1] - U = jnp.zeros((a.shape[-2], a.shape[-2]), dtype=a.dtype) - V = jnp.zeros((a.shape[-1], a.shape[-1]), dtype=a.dtype) - return U, S, V - U, S, V = jaten._aten__linalg_svd(a, full_matrices=not some) - return U, S, jnp.matrix_transpose(V) - - -@register_function(torch.cdist) -def _cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): - return jaten._aten_cdist(x1, x2, p, compute_mode) - - -@register_function(torch.lu) -def lu(A, **kwargs): - lu, pivots, _ = jax.lax.linalg.lu(A) - # JAX pivots are offset by 1 compared to torch - _pivots = pivots + 1 - info_shape = pivots.shape[:-1] - info = jnp.zeros(info_shape, dtype=mappings.t2j_dtype(torch.int32)) - if kwargs["get_infos"] == True: - return lu, _pivots, info - return lu, _pivots - - -@register_function(torch.lu_solve) -def lu_solve(b, LU_data, LU_pivots, **kwargs): - # JAX pivots are offset by 1 compared to torch - _pivots = LU_pivots - 1 - x = jax.scipy.linalg.lu_solve((LU_data, _pivots), b) - return x - - -@register_function(torch.linalg.tensorsolve) -def linalg_tensorsolve(A, b, dims=None): - # examples: - # A = torch.randn(2, 3, 6), b = torch.randn(3, 2) - # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6]) - # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3]) - # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6]) - # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3]) - # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6]) - # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6]) - - # torch allows b to be shaped differently. - # especially when axes are moved using dims. - # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3) - # So we are handling the moveaxis and forcing b's shape to match what jax expects - if dims is not None: - A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,)) - dims = None - if A.shape[:b.ndim] != b.shape: - b = jnp.reshape(b, A.shape[:b.ndim]) - return jnp.linalg.tensorsolve(A, b, axes=dims) - - -@register_function(torch.nn.functional.linear) -def functional_linear(self, weights, bias=None): - res = jnp.einsum("...a,ba->...b", self, weights) - if bias is not None: - res += bias - return res - - -@register_function(torch.nn.functional.interpolate) -def functional_interpolate( - input, - size: Tuple[int, int], - scale_factor: Optional[float], - mode: str, - align_corners: bool, - recompute_scale_factor: bool, - antialias: bool, -): - supported_methods = ( - "nearest", - "linear", - "bilinear", - "trilinear", - "cubic", - "bicubic", - "tricubic", - "lanczos3", - "lanczos5", - ) - is_jax_supported = mode in supported_methods - if not is_jax_supported: - raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" - ) - # None check - antialias = antialias or False - align_corners = align_corners or False - - if mode in ('cubic', 'bicubic', - 'tricubic') and not antialias and size is not None: - return jimage.interpolate_bicubic_no_aa( - input, - size[0], - size[1], - align_corners, - ) - else: - # fallback - raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" - ) - - -@register_function(torch.Tensor.repeat_interleave) -def torch_Tensor_repeat_interleave(self, - repeats, - dim=None, - *, - output_size=None): - return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size) diff --git a/torchax/torchax/ops/jtorchvision_nms.py b/torchax/torchax/ops/jtorchvision_nms.py deleted file mode 100644 index 57832b560b03..000000000000 --- a/torchax/torchax/ops/jtorchvision_nms.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py -""" - -import functools -from typing import List, Union, Optional, Tuple - -import torch -from jax import lax -import jax.numpy as jnp -from . import ops_registry - -_NMS_TILE_SIZE = 256 - - -def _bbox_overlap(boxes, gt_boxes): - """Find Bounding box overlap. - - Args: - boxes: first set of bounding boxes - gt_boxes: second set of boxes to compute IOU - - Returns: - iou: Intersection over union matrix of all input bounding boxes - """ - bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( - ary=boxes, indices_or_sections=4, axis=2) - gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( - ary=gt_boxes, indices_or_sections=4, axis=2) - - # Calculates the intersection area. - i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) - i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1])) - i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1])) - i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1])) - i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0) - - # Calculates the union area. - bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min) - gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min) - # Adds a small epsilon to avoid divide-by-zero. - u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8 - - # Calculates IoU. - iou = i_area / u_area - - return iou - - -def _self_suppression(in_args): - iou, _, iou_sum = in_args - batch_size = iou.shape[0] - can_suppress_others = jnp.reshape( - jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype) - iou_suppressed = jnp.reshape( - (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype( - iou.dtype), [batch_size, -1, 1]) * iou - iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) - return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new - - -def _cross_suppression(in_args): - boxes, box_slice, iou_threshold, inner_idx = in_args - batch_size = boxes.shape[0] - new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4]) - iou = _bbox_overlap(new_slice, box_slice) - ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype( - box_slice.dtype), 2) * box_slice - return boxes, ret_slice, iou_threshold, inner_idx + 1 - - -def _suppression_loop_body(in_args): - """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE). - - Args: - in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx - - Returns: - boxes: updated boxes. - iou_threshold: pass down iou_threshold to the next iteration. - output_size: the updated output_size. - idx: the updated induction variable. - """ - boxes, iou_threshold, output_size, idx = in_args - num_tiles = boxes.shape[1] // _NMS_TILE_SIZE - batch_size = boxes.shape[0] - - # Iterates over tiles that can possibly suppress the current tile. - box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4]) - - def _loop_cond(in_args): - _, _, _, inner_idx = in_args - return inner_idx < idx - - _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression, - (boxes, box_slice, iou_threshold, 0)) - - # Iterates over the current tile to compute self-suppression. - iou = _bbox_overlap(box_slice, box_slice) - mask = jnp.expand_dims( - jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) - > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0) - iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) - - def _loop_cond2(in_args): - _, loop_condition, _ = in_args - return loop_condition - - suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression, - (iou, True, jnp.sum(iou, [1, 2]))) - suppressed_box = jnp.sum(suppressed_iou, 1) > 0 - box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2) - - # Uses box_slice to update the input boxes. - mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles), - idx)).astype(boxes.dtype), [1, -1, 1, 1]) - boxes = jnp.tile(jnp.expand_dims( - box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape( - boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask) - boxes = jnp.reshape(boxes, [batch_size, -1, 4]) - - # Updates output_size. - output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1]) - return boxes, iou_threshold, output_size, idx + 1 - - -def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold): - """A wrapper that handles non-maximum suppression. - - Assumption: - * The boxes are sorted by scores unless the box is a dot (all coordinates - are zero). - * Boxes with higher scores can be used to suppress boxes with lower scores. - - The overal design of the algorithm is to handle boxes tile-by-tile: - - boxes = boxes.pad_to_multiply_of(tile_size) - num_tiles = len(boxes) // tile_size - output_boxes = [] - for i in range(num_tiles): - box_tile = boxes[i*tile_size : (i+1)*tile_size] - for j in range(i - 1): - suppressing_tile = boxes[j*tile_size : (j+1)*tile_size] - iou = _bbox_overlap(box_tile, suppressing_tile) - # if the box is suppressed in iou, clear it to a dot - box_tile *= _update_boxes(iou) - # Iteratively handle the diagnal tile. - iou = _box_overlap(box_tile, box_tile) - iou_changed = True - while iou_changed: - # boxes that are not suppressed by anything else - suppressing_boxes = _get_suppressing_boxes(iou) - # boxes that are suppressed by suppressing_boxes - suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes) - # clear iou to 0 for boxes that are suppressed, as they cannot be used - # to suppress other boxes any more - new_iou = _clear_iou(iou, suppressed_boxes) - iou_changed = (new_iou != iou) - iou = new_iou - # remaining boxes that can still suppress others, are selected boxes. - output_boxes.append(_get_suppressing_boxes(iou)) - if len(output_boxes) >= max_output_size: - break - - Args: - scores: a tensor with a shape of [batch_size, anchors]. - boxes: a tensor with a shape of [batch_size, anchors, 4]. - max_output_size: a scalar integer `Tensor` representing the maximum number - of boxes to be selected by non max suppression. - iou_threshold: a float representing the threshold for deciding whether boxes - overlap too much with respect to IOU. - Returns: - nms_scores: a tensor with a shape of [batch_size, anchors]. It has same - dtype as input scores. - nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has - same dtype as input boxes. - """ - batch_size = boxes.shape[0] - num_boxes = boxes.shape[1] - pad = int(jnp.ceil( - float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes - boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) - scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) - num_boxes += pad - - def _loop_cond(in_args): - unused_boxes, unused_threshold, output_size, idx = in_args - return jnp.logical_and( - jnp.min(output_size) < max_output_size, idx - < num_boxes // _NMS_TILE_SIZE) - - selected_boxes, _, output_size, _ = lax.while_loop( - _loop_cond, _suppression_loop_body, - (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0)) - idx = num_boxes - lax.top_k( - jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) * - jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), - max_output_size)[0].astype(jnp.int32) - idx = jnp.minimum(idx, num_boxes - 1) - idx = jnp.reshape( - idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1]) - - return idx - boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx], - [batch_size, max_output_size, 4]) - boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) - < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype) - scores = jnp.reshape( - jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size]) - scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1]) - < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype) - return scores, boxes - - -# registry: - - -def nms(boxes, scores, iou_threshold): - max_output_size = boxes.shape[0] - boxes = boxes.reshape((1, *boxes.shape)) - scores = scores.reshape((1, *scores.shape)) - res = non_max_suppression_padded(scores, boxes, max_output_size, - iou_threshold) - return res - - -try: - import torch - import torchvision - ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms) -except Exception: - pass diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py deleted file mode 100644 index 4eb7c6996159..000000000000 --- a/torchax/torchax/ops/mappings.py +++ /dev/null @@ -1,147 +0,0 @@ -from jax import dlpack as jaxdl -import jax.numpy as jnp -import numpy -import torch -import torch.func -import torch.utils.dlpack as torchdl -import torch.utils._mode_utils as mode_utils - -NUMPY_UNSUPPORTED_DTYPES = { - torch.bfloat16: jnp.bfloat16, - torch.float8_e4m3fn: jnp.float8_e4m3fn, - torch.float8_e4m3fnuz: jnp.float8_e4m3fnuz, - torch.float8_e5m2: jnp.float8_e5m2, - torch.float8_e5m2fnuz: jnp.float8_e5m2fnuz, -} - - -def t2j(t, use_dlpack=True): - is_bool = False - if t.dtype == torch.bool: - is_bool = True - t = t.to(torch.int8) - - t = t.to_dense() - - if not t.is_contiguous(): - t = t.contiguous() - - res = None - if use_dlpack: - try: - res = jaxdl.from_dlpack(t) - except Exception: - pass - - if res is None: - # https://github.com/google/jax/issues/7657 - # https://github.com/google/jax/issues/17784 - if t.dtype in NUMPY_UNSUPPORTED_DTYPES: - nparray = (t.cpu().detach().to(torch.float32).numpy() - ) # handle dtypes not supported by numpy - else: - nparray = t.cpu().detach().numpy() - res = jnp.asarray(nparray) - if t.dtype in NUMPY_UNSUPPORTED_DTYPES: - res = res.astype(NUMPY_UNSUPPORTED_DTYPES[t.dtype]) - - if is_bool: - res = res.astype(jnp.bool_) - return res - - -def j2t(x, use_dlpack=True): - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - res = None - if use_dlpack: - try: - dl = jaxdl.to_dlpack(x) - res = torchdl.from_dlpack(dl) - except Exception: - res = None - - orig_dtype = None - if res is None: - orig_dtype = None - if x.dtype == jnp.bfloat16.dtype: - orig_dtype = x.dtype - x = x.astype(jnp.float32.dtype) - res = torch.from_numpy(numpy.asarray(x)) - - if x.dtype == jnp.bool_: - res = res.to(torch.bool) - - if orig_dtype is not None: - res = res.to(j2t_dtype(orig_dtype)) - return res - - -TORCH_DTYPE_TO_JAX = { - # NO_MAPPING : jnp.float0.dtype (signless scalar int), - torch.bool: - jnp.bool_.dtype, - # NO_MAPPING : jnp.int4.dtype, - torch.int8: - jnp.int8.dtype, - torch.int16: - jnp.int16.dtype, - torch.int32: - jnp.int32.dtype, - torch.int64: - jnp.int64.dtype, - torch.long: - jnp.int64.dtype, - # NO_MAPPING : jnp.uint4 - torch.uint8: - jnp.uint8.dtype, - torch.uint16: - jnp.uint16.dtype, - torch.uint32: - jnp.uint32.dtype, - torch.uint64: - jnp.uint64.dtype, - # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype, - torch.float8_e4m3fn: - jnp.float8_e4m3fn.dtype, - # NO_MAPPING : jnp.float8_e4m3fnuz.dtype, - torch.float8_e5m2: - jnp.float8_e5m2.dtype, - # NO_MAPPING : jnp.float8_e5m2fnuz.dtype, - torch.bfloat16: - jnp.bfloat16.dtype, - torch.half: - jnp.float16.dtype, - torch.float16: - jnp.float16.dtype, - torch.float32: - jnp.float32.dtype, - torch.float64: - jnp.float64.dtype, - torch.double: - jnp.double.dtype, - torch.complex64: - jnp.complex64.dtype, - torch.complex128: - jnp.complex128.dtype, - None: - None, -} - -JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} -# Add imprecise mappings for some JAX dtypes which don't have torch analogues -JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8 -JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8 - - -def t2j_dtype(dtype): - if dtype not in TORCH_DTYPE_TO_JAX: - raise RuntimeError( - f'Attempting to convert unknown type: {dtype} to jax type,') - return TORCH_DTYPE_TO_JAX[dtype] - - -def j2t_dtype(dtype): - if dtype not in JAX_DTYPE_TO_TORCH: - raise RuntimeError( - f'Attempting to convert unknown type: {dtype} to torch type,') - return JAX_DTYPE_TO_TORCH[dtype] diff --git a/torchax/torchax/ops/op_base.py b/torchax/torchax/ops/op_base.py deleted file mode 100644 index d69e85ae50a6..000000000000 --- a/torchax/torchax/ops/op_base.py +++ /dev/null @@ -1,131 +0,0 @@ -import functools -import jax -import jax.numpy as jnp -import numpy as np -import torch -from torchax.ops import mappings -from torchax.view import View -from torchax import types -import sys - -from typing import Callable, Optional, ParamSpec, Concatenate - - -class InplaceOp: - - def __init__(self, - functional_op, - replace=False, - position_to_mutate=0, - is_jax_func=False): - self.functional = functional_op - self.replace = replace - self.position_to_mutate = position_to_mutate - self.is_jax_func = is_jax_func - - def __call__(self, *args, **kwargs): - to_mutate = args[self.position_to_mutate] - view_value = to_mutate - if isinstance(to_mutate, View): - view_value = to_mutate.torch() - # Convert the target View to a Tensor, and - # leave the rest args as is. If other args are - # also View, they will be converted to tensors - # in the self.functional dispatch. - env = view_value._env - if self.is_jax_func: - view_value, args, kwargs = env.t2j_iso((view_value, args, kwargs)) - new_value_jax = self.functional(view_value, *args[1:], **kwargs) - new_value = env.j2t_iso(new_value_jax) - else: - new_value = self.functional(view_value, *args[1:], **kwargs) - - if isinstance(to_mutate, View): - to_mutate.update(new_value) - else: - if self.replace: - to_mutate._elem = new_value._elem - else: - to_mutate.copy_(new_value) - return to_mutate - - -class OutVariant: - - def __call__(self, *args, **kwargs): - to_mutate = kwargs['out'] - del kwargs['out'] - to_mutate._elem = self.functional(*args, **kwargs)._elem - return to_mutate - - -P = ParamSpec('P') - - -def convert_dtype(use_default_dtype: bool = True): - """Converts `dtype` kwarg of function from torch to JAX. - - Args: - use_default_dtype: Whether to use torch default dtype if none is provided. - - Returns: - A decorator that wraps a JAX implementation of a torch function. - """ - - def decorator(func: types.TorchCallable): - - @functools.wraps(func) - def wrapper(*args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs): - if not dtype and use_default_dtype: - dtype = torch.get_default_dtype() - if isinstance(dtype, torch.dtype): - jax_dtype = mappings.t2j_dtype(dtype) - else: - jax_dtype = dtype - - return func(*args, dtype=jax_dtype, **kwargs) - - return wrapper - - return decorator - - -def maybe_convert_constant_dtype(val: Optional[types.JaxValue], - dtype: Optional[jnp.dtype]): - """Optionally converts scalar constant's dtype using `numpy` - - Use in cases where you require a constant and can't handle a traced array. - """ - if val and dtype: - if isinstance(val, jax.Array): - return maybe_convert_constant_dtype(val.item(), dtype) - - return np.array(val, dtype) - - return val - - -def promote_int_input(f: Callable[Concatenate[jax.Array, P], types.JaxValue]): - """If the first argument is an int array, promote it to float32.""" - - @functools.wraps(f) - def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): - if x.dtype in [jnp.int8, jnp.int16, jnp.int32, jnp.int64]: - x = x.astype(mappings.t2j_dtype(torch.get_default_dtype())) - - return f(x, *args, **kwargs) - - return wrapper - - -def foreach_loop(seq: jax.Array, - fn: Callable[[jax.Array, jax.Array], jax.Array], - init_val=0.0): - """Run `fn` for each element of 1D array `seq`. - - Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" - assert len(seq.shape) == 1 - return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), - init_val) diff --git a/torchax/torchax/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py deleted file mode 100644 index aa0d61cbb491..000000000000 --- a/torchax/torchax/ops/ops_registry.py +++ /dev/null @@ -1,55 +0,0 @@ -import dataclasses -import logging -from torchax.types import JaxCallable, TorchCallable - -from typing import Union, Dict - - -@dataclasses.dataclass -class Operator: - torch_op: TorchCallable - func: Union[TorchCallable, JaxCallable] - is_jax_function: bool - is_user_defined: bool - needs_env: bool - is_view_op: bool - - -all_aten_ops: Dict[TorchCallable, Operator] = {} -all_torch_functions: Dict[TorchCallable, Operator] = {} - - -def register_torch_dispatch_op(aten_op, - impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False): - op = Operator( - aten_op, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op) - if aten_op in all_aten_ops: - logging.warning(f'Duplicate op registration for {aten_op}') - all_aten_ops[aten_op] = op - return impl_callable - - -def register_torch_function_op(torch_func, - impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False): - op = Operator( - torch_func, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op) - all_torch_functions[torch_func] = op - return impl_callable diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py deleted file mode 100644 index a325c51dfc10..000000000000 --- a/torchax/torchax/tensor.py +++ /dev/null @@ -1,711 +0,0 @@ -import threading -import logging -import sys -import contextlib -from typing import Optional, Any -import jax -import jax.numpy as jnp -import numpy -import itertools -import torch -import torch.distributed._functional_collectives -import torch.func -import torch.utils._mode_utils as mode_utils -import torch.utils._python_dispatch as torch_dispatch -import torch.utils._pytree as torch_pytree -from torchax.view import View -from torchax import config -from torchax.ops import mappings, ops_registry -from torchax import amp - -logger = logging.getLogger(__name__) - - -class OperatorNotFound(Exception): - pass - - -@contextlib.contextmanager -def log_nested(env, message): - if env.config.debug_print_each_op: - print((" " * log_nested.level) + message, file=sys.stderr) - log_nested.level += 1 - yield - log_nested.level -= 1 - - -log_nested.level = 0 - - -class Tensor(torch.Tensor): - - @staticmethod - def __new__(cls, elem, env, requires_grad=False): - dtype = mappings.j2t_dtype(elem.dtype) - shape = list(elem.shape) - for i, s in enumerate(shape): - if not isinstance(s, int): - shape[i] = 1 - if dtype is None: - dtype = torch.float32 - #dispatch_keys = torch.DispatchKeySet(torch._C.DispatchKey.PrivateUse1).add(torch._C.DispatchKey.AutogradPrivateUse1) - if not (dtype.is_floating_point or dtype.is_complex): - requires_grad = False - - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - dtype=dtype, - device='meta', - requires_grad=requires_grad, - ) - - def __init__(self, elem: jax.Array, env: "Environment", requires_grad=False): - super().__init__() - self._elem = elem - self._env = env - - def __str__(self): - return "Tensor({} {})".format(str(type(self._elem)), str(self._elem)) - - __repr__ = __str__ - - @property - def shape(self): - return torch.Size(self._elem.shape) - - @property - def ndim(self): - return len(self._elem.shape) - - def flatten(self, start_dim=0, end_dim=-1): - if end_dim == -1: - end_dim = self.ndim - new_shape = ( - self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:]) - new_elem = jnp.reshape(self._elem, new_shape) - return Tensor(new_elem, self._env) - # return torch.reshape(self, new_shape) - - def __setitem__(self, key, val): - key, val = self._env.t2j_iso((key, val)) - self._elem = self._elem.at[key].set(val) - - def type_as(self, other): - self._elem = self._elem.astype(other._elem.dtype) - return self - - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # TODO(hanq): figure out why is dispatch mode not sufficient - if func == torch.ops._c10d_functional.wait_tensor.default: - return args[0]._env.dispatch(func, types, args, kwargs) - if func == torch.ops.prim.device.default: - return torch.device('privateuseone', 0) - raise AssertionError( - 'torchax Tensors can only do math within the torchax environment.' - 'Please wrap your code with `with torchax.default_env()` or ' - 'call torchax.enable_globally() before.') - - def detach(self): - return Tensor(jax.lax.stop_gradient(self.jax()), self._env) - - def numpy(self) -> numpy.ndarray: - import numpy as np - - return np.array(self._elem) - - def jax(self) -> jax.Array: - return self._elem - - def torch(self) -> torch.Tensor: - return self._env.j2t_copy(self.jax()) - - @property - def dtype(self): - return mappings.j2t_dtype(self._elem.dtype) - - def dim(self): - return self.ndim - - @property - def device(self): - return torch.device("jax:0") - - @property - def jax_device(self): - return self._elem.device - - @property - def data(self): - logger.warning( - "In-place to .data modifications still results a copy on TPU") - return self - - @data.setter - def data(self, other): - if isinstance(other, Tensor): - self._elem = other._elem - - def apply_jax(self, jax_function, *args, **kwargs): - # Call a jax function on _elem - res = jax_function(self._elem, *args, **kwargs) - return self._env.j2t_iso(res) - - def apply_jax_(self, jax_function, *args, **kwargs): - self._elem = jax_function(self._elem, *args, **kwargs) - return self - - def tolist(self): - return self._elem.tolist() - - def shard_(self, sharding): - self.apply_jax_(jax.lax.with_sharding_constraint, sharding) - - -def debug_accuracy(func, args, kwargs, current_output): - args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only( - torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)) - - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - if "device" in kwargs_torch: - kwargs_torch["device"] = "cpu" # do the torch native for comparison - expected_out = func(*args_torch, **kwargs_torch) - - flattened_current_out, _ = torch_pytree.tree_flatten(out_torch) - flattened_expected_out, _ = torch_pytree.tree_flatten(expected_out) - - for ex, real in zip(flattened_expected_out, flattened_current_out): - if isinstance(ex, torch.Tensor) and ex.dtype != real.dtype: - ex = ex.to(real.dtype) - try: - if isinstance(ex, torch.Tensor) and not torch.allclose( - ex, real, atol=1e-3, equal_nan=True): - import pdb - - pdb.set_trace() - except: - import pdb - - pdb.set_trace() - - return True - - -def _make_debug_msg(is_dispatch, log_args, func, args, kwargs): - - def _display(a): - if isinstance(a, torch.Tensor): - return f"Tensor of {type(a)}: {a.dtype}{a.shape}" - elif isinstance(a, jax.Array): - return f"Jax Array of {type(a)}: {a.dtype}{a.shape}" - else: - return str(a) - - kwargs = kwargs or {} - title = "DISPATCH" if is_dispatch else "FUNCTION" - args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else "" - kwargs_msg = ("kwargs: " + - ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) - if log_args else "") - return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}" - - -class XLAFunctionMode(torch.overrides.TorchFunctionMode): - """Context manager that dispatches torch function calls to JAX.""" - - def __init__(self, env): - self.env = env - - def __torch_function__(self, - func, - types, - args=(), - kwargs=None) -> torch.Tensor: - message = f"FUNCTION: {_name_of_func(func)}" - if self.env.config.debug_print_each_op_operands: - message = message + "f" - message = _make_debug_msg(False, - self.env.config.debug_print_each_op_operands, - func, args, kwargs) - with log_nested(self.env, message): - try: - return self.env.dispatch(func, types, args, kwargs) - except OperatorNotFound: - pass - if _name_of_func(func) in ( - "rot90"): # skip rot90 with k%4==0 due to no change - if len(args) >= 2 and type(args[1]) == int: - if (args[1]) % 4 == 0: - return args[0] - return func(*args, **(kwargs or {})) - - -class XLADispatchMode(torch_dispatch.TorchDispatchMode): - - def __init__(self, env): - self.env = env - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - message = _make_debug_msg(True, - self.env.config.debug_print_each_op_operands, - func, args, kwargs) - with log_nested(self.env, message): - if isinstance(func, torch._ops.OpOverloadPacket): - with self: - return func(*args, **kwargs) - # Only functions under these namespaces will be intercepted - if func.namespace not in ( - "aten", - "_c10d_functional", - "torchvision", - "xla", - ): - return func(*args, **kwargs) - return self.env.dispatch(func, types, args, kwargs) - - -def _name_of_func(func): - if hasattr(func, "name"): - return func.name() - return func.__name__ - - -# Constructors that don't take other tensor as input -TENSOR_CONSTRUCTORS = { - torch.ones, - torch.zeros, - torch.empty, - torch.empty_strided, - torch.tensor, - torch.arange, - torch.eye, - torch.randn, - torch.rand, - torch.randint, - torch.full, - torch.as_tensor, -} - -# TODO(wen): use existing types, either from torch or jax -SUPPORTED_JAX_PLATFROM = ["cpu", "tpu"] - - -class RuntimeProperty: - mesh: Any - prng: Any - autocast_dtype: Any - - def __init__(self, mesh, prng, autocast_dtype): - self.mesh = mesh - self.prng = prng - self.autocast_dtype = autocast_dtype - - def override(self, **kwargs): - return OverrideProperty(self, kwargs) - - def get_and_rotate_prng_key(self): - old_key = self.prng - new_prng_key, next_key = jax.random.split(old_key) - self.prng = new_prng_key - return next_key - - -class OverrideProperty(RuntimeProperty): - - def __init__(self, parent, override): - self.parent = parent - self._override = dict(override) - - def __getattr__(self, name): - if name in self._override: - return self._override[name] - return getattr(self.parent, name) - - -class Environment(contextlib.ContextDecorator): - """This class holds a set of configurations and "globals" needed - - for executing torch program using jax. - Things included so far: - - op registry - PRNGKey - Configs - - Also helper functions to manipulate those. - """ - - def __init__(self, configuration=None): - self._function_mode = XLAFunctionMode(self) - self._dispatch_mode = XLADispatchMode(self) - - # name is torch callable - self._ops = {} - self._decomps = {} - - self.load_ops() - - _mesh = None - self.config = configuration or config.Configuration() - - self.enabled = False - - autocast_dtype = None - - _prng_key = jax.random.key(torch.initial_seed() % (1 << 63)) - self._property = threading.local() - self._property.content = [ - RuntimeProperty( - mesh=_mesh, prng=_prng_key, autocast_dtype=autocast_dtype) - ] - - @property - def param(self): - return self._property.content[-1] - - def manual_seed(self, key): - jax_key = jax.random.PRNGKey(key) - new_prop = self.param.override(prng=jax_key) - self._property.content.append(new_prop) - - @property - def prng_key(self): - return self.param.prng - - def _should_use_torchax_tensor(self, device): - if device is None: - device = torch.get_default_device() - - if isinstance(device, torch.device): - device = device.type - - if ':' in device: - device = device.split(':')[0] - - match device: - case 'cpu': - return False - case 'cuda': - return self.config.treat_cuda_as_jax_device - case 'jax': - return True - case 'privateuseone': - return True - case 'meta': - return self.enabled - return False - - def load_ops(self): - from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms - - for k, v in itertools.chain(ops_registry.all_aten_ops.items(), - ops_registry.all_torch_functions.items()): - if v.is_jax_function: - self._ops[k] = v - else: - self._decomps[k] = v - - from torchax.decompositions import DECOMPOSITIONS, MUTABLE_DECOMPOSITION - - for k, v in DECOMPOSITIONS.items(): - if k not in self._decomps: - self._decomps[k] = ops_registry.Operator( - k, - v, - is_jax_function=False, - is_user_defined=False, - needs_env=False, - is_view_op=k in MUTABLE_DECOMPOSITION, - ) - - def _get_op_or_decomp(self, func): - - def _get_from_dict(op_dict, op): - op = op_dict.get(func) - if op is None and isinstance(func, torch._ops.OpOverloadPacket): - op = op_dict.get(func.default) - if op is None and isinstance(func, torch._ops.OpOverload): - op = op_dict.get(func.overloadpacket) - return op - - op = _get_from_dict(self._ops, func) - - if op is None: - # fallback to decompose - op = _get_from_dict(self._decomps, func) - - if op is None: - raise OperatorNotFound( - f"Operator with name {_name_of_func(func)} has no lowering") - - return op - - def _is_same_device(self, the_tensor, new_device): - if new_device is None: - return True - if new_device == 'meta' and the_tensor.device.type == 'jax': - return True - if the_tensor.device.type != new_device: - if the_tensor.device.type == 'cuda': - return self.config.treat_cuda_as_jax_device - return False - return True - - def _to_copy(self, the_tensor, new_dtype, new_device): - if isinstance(the_tensor, View): - the_tensor = the_tensor.torch() - if isinstance(new_device, torch.device): - new_device = new_device.type - res = the_tensor - if not self._is_same_device(the_tensor, new_device): - if isinstance(the_tensor, Tensor): - torch_tensor = self.j2t_copy(the_tensor._elem) - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return torch_tensor.to(device=new_device, dtype=new_dtype) - else: - arr = self.t2j_copy(the_tensor) - res = Tensor(arr, self, the_tensor.requires_grad) - - if new_dtype is not None and new_dtype != res.dtype: - if isinstance(res, Tensor): - res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype)) - else: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return res.to(device=new_device, dtype=new_dtype) - return res - - def get_and_rotate_prng_key(self, - generator: Optional[torch.Generator] = None): - if generator is not None: - return jax.random.PRNGKey(generator.initial_seed() % (2**63)) - return self.param.get_and_rotate_prng_key() - - def _handle_tensor_constructor(self, func, args, kwargs): - device = kwargs.get("device") - if self._should_use_torchax_tensor(device): - # don't set default device, let caller set it - requires_grad = kwargs.get("requires_grad", False) - op = self._get_op_or_decomp(func) - if op.needs_env: - kwargs['env'] = self - if op.is_jax_function: - (args, kwargs) = self.t2j_iso((args, kwargs)) - res = op.func(*args, **kwargs) - if isinstance(res, jax.Array): - res = Tensor(res, self, requires_grad) - return res - else: - with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return func(*args, **kwargs) - - def _torch_Tensor_to(self, args, kwargs): - the_tensor = args[0] - args = args[1:] - if len(args) >= 1 and isinstance(args[0], torch.Tensor): - dtype = args[0].dtype - device = args[0].device - return self._to_copy(the_tensor, dtype, device) - device = kwargs.get("device") - dtype = kwargs.get("dtype") - # args like pin_memory etc that we will ignore - args = list(filter(lambda x: not isinstance(x, bool), args)) - if len(args) >= 2: - device, dtype, *_ = args - elif len(args) == 1 and isinstance(args[0], torch.dtype): - dtype = args[0] - elif len(args) == 1: - device = args[0] - return self._to_copy(the_tensor, dtype, device) - - def dispatch(self, func, types, args, kwargs): - kwargs = kwargs or {} - if func in TENSOR_CONSTRUCTORS: - return self._handle_tensor_constructor(func, args, kwargs) - if func in ( - torch.Tensor.to, - torch.ops.aten.lift_fresh.default, - torch.ops.aten._to_copy, - torch.ops.aten._to_copy.default, - ): - return self._torch_Tensor_to(args, kwargs) - - # If the func doesn't act on Tensor, and is not a tensor constructor, - # We should skip and let torch handle it. - - tensor_args = [ - t for t in torch_pytree.tree_flatten(args)[0] - if isinstance(t, torch.Tensor) - ] - - def is_not_torchax_tensor(x): - return not isinstance(x, Tensor) and not isinstance(x, View) - - if tensor_args and all(is_not_torchax_tensor(t) for t in tensor_args): - res = func(*args, **kwargs) - return res - - with jax.named_scope(_name_of_func(func)): - op = self._get_op_or_decomp(func) - - old_args, old_kwargs = args, kwargs - with self._dispatch_mode: - args, kwargs = torch_pytree.tree_map_only( - torch.distributed._functional_collectives.AsyncCollectiveTensor, - torch.distributed._functional_collectives.wait_tensor, - (args, kwargs), - ) - - try: - if not op.is_view_op: - args, kwargs = self.v2t_iso((args, kwargs)) - - with self: - if self.param.autocast_dtype is not None: - autocast_policy = amp.autocast_policy.get(func) - if autocast_policy is not None: - args, kwargs = amp.execute_policy(autocast_policy, args, kwargs, - self.param.autocast_dtype) - - if op.is_jax_function: - args, kwargs = self.t2j_iso((args, kwargs)) - except AssertionError: - if self.config.debug_mixed_tensor: - breakpoint() - else: - raise - - if op.needs_env: - kwargs["env"] = self - - if op.is_jax_function: - res = op.func(*args, **kwargs) - else: - # enable dispatch mode because this op could be a composite autograd op - # meaning, it will decompose in C++ - with self._dispatch_mode: - res = op.func(*args, **kwargs) - - if op.is_jax_function: - res = self.j2t_iso(res) - - if self.config.force_materialize_views and isinstance(res, View): - res = res.torch() - - if self.config.debug_accuracy_for_each_op: - debug_accuracy(func, old_args, old_kwargs, res) - return res - - def enable_torch_modes(self): - self._dispatch_mode.__enter__() - self._function_mode.__enter__() - self.enabled = True - - def disable_torch_modes(self, *exc): - if not exc: - exc = (None, None, None) - self._function_mode.__exit__(*exc) - self._dispatch_mode.__exit__(*exc) - self.enabled = False - - def __enter__(self): - self.enable_torch_modes() - return self - - def __exit__(self, *exc): - self.disable_torch_modes(*exc) - - def _move_one_value(self, val): - if isinstance(val, torch.nn.Module): - with self: - return val.to("jax") - if isinstance(val, Tensor): - return val - if isinstance(val, torch.Tensor): - return Tensor(self.t2j_copy(val), self) - return val - - def to_xla(self, torchvalues): - # tensors are torch.Tensors (not XLATensor) - res = torch_pytree.tree_map(self._move_one_value, torchvalues) - return res - - def t2j_iso(self, torchtensors): - """Convert torchax Tensor to jax array. - - This function will not copy, will just unwrap the inner jax array out. - Note: iso is short for "isomorphic" - """ - - def to_jax(x): - if self.config.allow_mixed_math_with_scalar_tensor and not isinstance( - x, Tensor): - if x.squeeze().ndim == 0: - return x.item() - if isinstance( - x, torch.distributed._functional_collectives.AsyncCollectiveTensor): - x = x.wait() - assert isinstance(x, Tensor) or isinstance(x, View), ( - f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor" - ) - return x.jax() - - res = torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors) - return res - - def v2t_iso(self, views): - - def to_tensor(x): - if isinstance(x, View): - return x.torch() - return x - - res = torch_pytree.tree_map_only(View, to_tensor, views) - return res - - def j2t_iso(self, jaxarray): - """Convert jax array to torchax Tensor. - - This function will not copy, will just wrap the jax array with a torchax Tensor - Note: iso is short for "isomorphic" - """ - return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), - jaxarray) - - def j2t_copy(self, args): - """Convert torch.Tensor in cpu to a jax array - - This might involves copying the data (depending if dlpack is enabled) - """ - return torch_pytree.tree_map_only( - jax.Array, - lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion), - args) - - def t2j_copy(self, args): - """Convert jax array to torch.Tensor in cpu. - - This might involves copying the data (depending if dlpack is enabled) - """ - return torch_pytree.tree_map_only( - torch.Tensor, - lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion), - args) - - def override_op_definition(self, op_to_override, op_impl): - self._ops[op_to_override] = ops_registry.Operator( - op_to_override, - op_impl, - is_jax_function=False, - is_user_defined=True, - needs_env=False, - ) - - @contextlib.contextmanager - def override_property(self, **kwargs): - new_prop = self.param.override(**kwargs) - self._property.content.append(new_prop) - yield - self._property.content.pop() diff --git a/torchax/torchax/train.py b/torchax/torchax/train.py deleted file mode 100644 index fb4e16fc48ee..000000000000 --- a/torchax/torchax/train.py +++ /dev/null @@ -1,117 +0,0 @@ -import collections -import functools -import torch -import jax -import torchax -from torchax import interop -from torchax.interop import torch_view, jax_view -import optax - -remat = torch_view(jax.remat) -mark_sharding = torch_view(jax.lax.with_sharding_constraint) - - -def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None): - """Make a function that do one train step given model and loss. - - model_fn: a function representing the model's forward: - i.e. has signature Callable[weights, buffers, args] -> result. Where, - weights is a pytree of trainable parameters - buffers is a pytree of non-trainable parameters / constants - args is the input data loaded from the data set - result is the return value of the model - loss_fn: a function to compute loss. - i.e. it has signature of Callable[result, label] -> loss - where, result is what model_fn returned - loss is loaded from the dataloader. - optax_optimizer: the optimizer from optax library. for example, optax.adam - remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how - to do gradient checkpointing. If None, then it means checkpoint everything. - """ - env = torchax.default_env() - - def loss(weights, buffers, args, label): # inputs are XLATensor - with env, jax.named_scope('compute_loss'): - res = model_fn(weights, buffers, args) - l = loss_fn(res, label) - return l - - loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy}) - grad_fn = interop.jax_value_and_grad(loss) - - def step(weights, buffers, opt_state, args, label): #inputs are array - with jax.named_scope('compute_gradient'): - loss, gradient = grad_fn(weights, buffers, args, label) - - with jax.named_scope("optimizer_updates"): - updates, opt_state = interop.call_jax(optax_optimizer.update, gradient, - opt_state, weights) - weights = interop.call_jax(optax.apply_updates, weights, updates) - return loss, weights, opt_state - - # TODO: apply jax.jit so the user don't have to. - return step - - -class Container: - pass - - -class ScannedModule(torch.nn.Module): - - def __init__(self, module_list, checkpoint_policy=None): - super().__init__() - - self.c = None - assert module_list - self.c = Container() - self.c.one_mod = module_list[0] - self.checkpoint_policy = checkpoint_policy - - weights = self._stack_layer_weights(module_list) - self.layer_weights_keys = list(self.c.one_mod.state_dict().keys()) - self.params = torch.nn.ParameterDict({ - self._param_name_new(k): v for k, v in weights.items() - }) - - def _stack_layer_weights(self, module_list): - # Create weights such that, for every [n, m] weights - # becomes [k, n, m] where k is number of layer - # i.e. stacking layer weights together - temp = collections.defaultdict(list) - for m in module_list: - for k, v in m.state_dict().items(): - temp[k].append(v) - res = {k: torch.stack(v) for k, v in temp.items()} - return res - - def _param_name_new(self, old): - return '___'.join(old.split('.')) - - def _param_name_old(self, new): - return '.'.join(new.split('___')) - - def forward(self, *args, **kwargs): - assert not kwargs - weights = { - k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys - } - scan = interop.torch_view(jax.lax.scan) - - def eval_one_layer(args, weight): - # unpack args - h, *rest = args - newh = torch.func.functional_call(self.c.one_mod, weight, args) - # next layer's input; and residual to be added to list - return (newh, *rest), None - - _eval_one_layer = interop.gradient_checkpoint( - eval_one_layer, - kwargs={'policy': self.checkpoint_policy}, - ) - h, _ = scan( - _eval_one_layer, - args, - weights, - ) - return h[0] diff --git a/torchax/torchax/types.py b/torchax/torchax/types.py deleted file mode 100644 index 72a2f678c961..000000000000 --- a/torchax/torchax/types.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Callable, Any, Union, ParamSpec, TypeAlias -import torch -import jax -import jax.numpy as jnp -import sys - -P = ParamSpec('P') - -TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] -TorchCallable: TypeAlias = Callable[P, TorchValue] -JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any] -JaxCallable: TypeAlias = Callable[P, JaxValue] \ No newline at end of file diff --git a/torchax/torchax/util.py b/torchax/torchax/util.py deleted file mode 100644 index e34f77119d6f..000000000000 --- a/torchax/torchax/util.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Any, Callable - - -def partition(original: list[Any], - func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]: - """Partitions elements into two parallel lists based on a predicate function. - - Iterates through the 'original' list, applying 'func' to each element 'a'. - - If `func(a)` returns True, 'a' is appended to the first list ('truthy') - and `None` is appended to the second list ('falsy'). - - If `func(a)` returns False, `None` is appended to the first list ('truthy') - and 'a' is appended to the second list ('falsy'). - - The result is two lists of the same length as the 'original' list, acting - as parallel representations of the partitioned elements, using `None` as - placeholders. - - This is useful when we want to mark a group of elements as static (via passing - static_argnums) or donated (via donate_argnums) when combining with jax.jit - and friends. - - Args: - original: The list of elements to partition. - func: A callable (function or lambda) that accepts an element from - 'original' and returns a boolean value (True or False). - - Returns: - A tuple containing two lists (`truthy`, `falsy`), both of the same - length as `original`: - - The first list contains elements `x` where `func(x)` was True, and - `None` otherwise. - - The second list contains elements `x` where `func(x)` was False, and - `None` otherwise. - - Example: - >>> def is_even(n): return n % 2 == 0 - >>> nums = [1, 2, 3, 4, 5, 6] - >>> truthy_list, falsy_list = partition(nums, is_even) - >>> truthy_list - [None, 2, None, 4, None, 6] - >>> falsy_list - [1, None, 3, None, 5, None] - """ - truthy = [] - falsy = [] - for a in original: - t, f = (a, None) if func(a) else (None, a) - truthy.append(t) - falsy.append(f) - return truthy, falsy - - -def merge(list1: list[Any], list2: list[Any]) -> list[Any]: - """Merges two lists element-wise, prioritizing non-None elements from list1. - - Creates a new list where each element is taken from the corresponding position - in 'list1', unless that element is None, in which case the element from the - corresponding position in 'list2' is used. Assumes both lists have the - same length. - - Invariant: merge(*partion(input_list, predicate)) == input_list for any predicate - - Args: - list1: The primary list. Its elements are preferred unless they are None. - list2: The secondary list. Its elements are used as fallbacks when the - corresponding element in list1 is None. - - Returns: - A new list representing the merged result. - - Raises: - AssertionError: If 'list1' and 'list2' do not have the same length. - - Example: - >>> l1 = [1, None, 3, None] - >>> l2 = [None, 2, None, 4] - >>> merge(l1, l2) - [1, 2, 3, 4] - >>> l3 = [None, 'b', None] - >>> l4 = ['a', None, 'c'] - >>> merge(l3, l4) - ['a', 'b', 'c'] - """ - assert len(list1) == len(list2) - res = [] - for a, b in zip(list1, list2): - res.append(b if a is None else a) - return res diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py deleted file mode 100644 index 040fa24ef9e8..000000000000 --- a/torchax/torchax/view.py +++ /dev/null @@ -1,377 +0,0 @@ -import torch -import torch.utils._pytree as torch_pytree -import jax -from enum import Enum -from typing import Union, List, Tuple, Optional, Any, cast -from abc import ABC, abstractmethod - -# Reference to original PyTorch native functions -# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml - - -class ViewInfoType(Enum): - INVALID = 0 - NARROW = 1 - NO_OP = 2 - PERMUTE = 3 - RESHAPE = 4 - RESIZE = 5 - SELECT = 6 - AS_STRIDED = 7 - DIAGONAL = 8 - - -class ViewInfo(ABC): - """ - Abstract base class for all view operations. - Defines the interface for applying and updating view transformations. - """ - - def __init__( - self, - view_info_type: ViewInfoType = ViewInfoType.INVALID, - ): - """ - Initialize a ViewInfo object. - - Args: - view_info_type: The type of view operation - """ - self.view_info_type = view_info_type - - @abstractmethod - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - """ - Apply this view transformation to a JAX array and update its value. - - Args: - new_value: The new values to set in the view - jax_array: The parent array to update - - Returns: - Updated array - """ - pass - - @abstractmethod - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - """ - Apply this view transformation to a JAX array. - - Args: - jax_array: The array to transform - - Returns: - Transformed array - """ - pass - - @abstractmethod - def calculate_output_shape(self, source: jax.Array) -> List[int]: - """ - Calculate the resulting shape after applying this view. - - Args: - source: Original jax array before transformation - - Returns: - Resulting shape after transformation - """ - pass - - -class NarrowInfo(ViewInfo): - """ - Represents a slicing operation on a tensor. - Handles operations like tensor[1:3, :, 2:5:2]. - """ - - def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: - """ - Args: - slices: The slice(s) to apply to the tensor. - E.g. jax_array.at[slices] will return the transformed tensor. - """ - super().__init__(ViewInfoType.NARROW) - self.slices = slices - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NarrowInfo): - return False - return self.slices == other.slices - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - try: - return jax_array[self.slices] - except IndexError as e: - raise IndexError("Invalid slice operation") from e - - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - return jax_array.at[self.slices].set(new_value) - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - return source[self.slices].shape - - -class SelectInfo(ViewInfo): - """ - Represents a selection operation on a tensor. - Typically used for indexing operations that select specific elements. - """ - - def __init__(self, - dim: int = 0, - start: int = 0, - end: int = 0, - stride: int = 0) -> None: - super().__init__(ViewInfoType.SELECT) - self.dim: int = dim - self.start: int = start - self.end: int = end - self.stride: int = stride - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SelectInfo): - return False - return (self.dim == other.dim and self.start == other.start and - self.end == other.end and self.stride == other.stride) - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("SelectInfo.apply not implemented") - - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("SelectInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "SelectInfo.calculate_output_shape not implemented") - - -class AsStridedInfo(ViewInfo): - """ - Information for as_strided operations. - """ - - def __init__(self, stride: List[int], offset: int = 0) -> None: - super().__init__(ViewInfoType.AS_STRIDED) - self.stride: List[int] = stride - self.offset: int = offset - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AsStridedInfo): - return False - return self.offset == other.offset and self.stride == other.stride - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("AsStridedInfo.apply not implemented") - - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("AsStridedInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "AsStridedInfo.calculate_output_shape not implemented") - - -class DiagonalInfo(ViewInfo): - """ - Information for diagonal operations. - Extracts diagonal elements from a tensor. - """ - - def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: - """ - Args: - offset: Offset from the main diagonal - dim1: First dimension for diagonal extraction - dim2: Second dimension for diagonal extraction - """ - super().__init__(ViewInfoType.DIAGONAL) - self.offset: int = offset - self.dim1: int = dim1 - self.dim2: int = dim2 - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DiagonalInfo): - return False - return (self.offset == other.offset and self.dim1 == other.dim1 and - self.dim2 == other.dim2) - - def transform_tensor(self, jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("DiagonalInfo.apply not implemented") - - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: - raise NotImplementedError("DiagonalInfo.update not implemented") - - def calculate_output_shape(self, source: jax.Array) -> List[int]: - raise NotImplementedError( - "DiagonalInfo.calculate_output_shape not implemented") - - -class View(torch.Tensor): - """ - A View is a reference to another Tensor or another View, - with a transformation applied to it. - """ - - @staticmethod - def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo, - env: Any) -> "View": - """ - Args: - parent: Parent tensor or view - view_info: Information about the view transformation - env: Environment for tensor operations - """ - shape = view_info.calculate_output_shape(parent.jax()) - return torch.Tensor._make_wrapper_subclass( - cls, - shape, - device="meta", - dtype=parent.dtype, - requires_grad=False, - ) - - def __init__(self, parent: Union["torchax.Tensor", "View"], - view_info: ViewInfo, env: Any) -> None: - super().__init__() - self.parent = parent - self.view_info = view_info - self._env = env - - def get_transformation_chain(self) -> List[ViewInfo]: - """ - Get all view transformations from the source tensor to this view. - """ - if isinstance(self.parent, View): - transformations = self.parent.get_transformation_chain() - transformations.append(self.view_info) - return transformations - else: - return [self.view_info] - - __torch_function__ = torch._C._disabled_torch_function_impl - - def source_jax(self) -> jax.Array: - """ - Returns the source tensor. - """ - if isinstance(self.parent, View): - return self.parent.source_jax() - else: - return self.parent.jax() - - def replace_source_jax(self, new_value: jax.Array) -> None: - """ - Update the source tensor with new values. - """ - if isinstance(self.parent, View): - self.parent.replace_source_jax(new_value) - else: - assert new_value.shape == self.parent._elem.shape - self.parent._elem = new_value - - def torch(self) -> "torchax.Tensor": - """ - Returns a Torchax tensor representing this view after all transformations - """ - from torchax.tensor import Tensor - - return Tensor(self.jax(), self._env) - - def update( - self, - new_values: Union[jax.Array, "View", "torchax.Tensor"], - view_infos: Optional[List[ViewInfo]] = None, - ) -> None: - """ - Update this view with new values, propagating changes back to source. - If view_infos is None, it will use the transformation chain - from the source tensor. - """ - if view_infos is None: - view_infos = self.get_transformation_chain() - - # Get the source JAX array - source_array = self.source_jax() - - # Get the new value - from torchax.tensor import Tensor - - if isinstance(new_values, View) or isinstance(new_values, Tensor): - new_values = new_values.jax() - - # Apply all view transformations to the source array - # And store intermediate values - intermediate_values = [source_array] - for view_info in view_infos[:-1]: - intermediate_values.append( - view_info.transform_tensor(intermediate_values[-1])) - - # TODO: Investigate efficiency of this algorithm - # Update the source array with the new value by - # applying inverse transformations in reverse order - for view_info, parent_array in zip( - reversed(view_infos), reversed(intermediate_values)): - # Apply the inverse transformation to propagate changes back - new_values = view_info.update_tensor(new_values, parent_array) - - # Update the source tensor with the new values - self.replace_source_jax(new_values) - - @classmethod - def __torch_dispatch__( - cls, - func: Any, - types: Tuple[Any, ...], - args: Tuple[Any, ...] = (), - kwargs: Optional[dict] = None, - ) -> Any: - raise AssertionError( - 'torchax Tensors can only do math within the torchax environment.' - 'Please wrap your code with `with torchax.default_env()` or ' - 'call torchax.enable_globally() before.') - - def create_sub_view(self, view_info: ViewInfo) -> "View": - """ - Create a new view that is a child of this view. - """ - return View(self, view_info, self._env) - - def __str__(self) -> str: - return f"View({self.torch()})" - - def jax(self) -> jax.Array: - """ - Returns a copy of the source tensor after transformations. - """ - result = self.source_jax() - for view_info in self.get_transformation_chain(): - result = view_info.transform_tensor(result) - return result - - def __setitem__(self, indexes, val): - view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] - self.update(view_infos=view_infos, new_values=val) - - def dim(self): - return self.ndim - - @property - def device(self): - return torch.device("jax:0") - - @property - def jax_device(self): - return self.jax().device - - @property - def ndim(self): - return len(self.shape) - - __repr__ = __str__