Skip to content

Commit eb37dfd

Browse files
authored
Remove torchax from torchxla repo. (#9672)
It's now in a separate repo at google/torchax Optional dependency on torchax is handled the same way as JAX, (via pip dependency)
1 parent 11590c1 commit eb37dfd

File tree

103 files changed

+118
-24687
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+118
-24687
lines changed

.github/workflows/lintercheck.yml

Lines changed: 104 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,113 @@
11
name: Linter check
22
on:
3-
pull_request:
4-
push:
5-
branches:
6-
- master
7-
tags:
8-
- r[0-9]+.[0-9]+
3+
pull_request:
4+
push:
5+
branches:
6+
- master
7+
tags:
8+
- r[0-9]+.[0-9]+
99

1010
jobs:
11-
check_code_changes:
12-
name: Check Code Changes
13-
uses: ./.github/workflows/_check_code_changes.yml
14-
with:
15-
event_name: ${{ github.event_name }}
16-
# For pull_request, use PR's base and head. For push, use event's before and sha.
17-
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
18-
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
19-
linter_check:
20-
runs-on: ubuntu-24.04
21-
needs: [check_code_changes]
22-
steps:
23-
- name: Checkout repo
24-
if: needs.check_code_changes.outputs.has_code_changes == 'true'
25-
uses: actions/checkout@v3
26-
- name: Setup Python
27-
if: needs.check_code_changes.outputs.has_code_changes == 'true'
28-
uses: actions/setup-python@v4
11+
check_code_changes:
12+
name: Check Code Changes
13+
uses: ./.github/workflows/_check_code_changes.yml
2914
with:
30-
python-version: '3.10'
31-
cache: 'pip'
32-
- run: pip install yapf==0.40.2 # N.B.: keep in sync with `torchax/dev-requirements.txt`, `infra/ansible/config/pip.yaml`
15+
event_name: ${{ github.event_name }}
16+
# For pull_request, use PR's base and head. For push, use event's before and sha.
17+
base_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.base.sha || github.event.before }}
18+
head_sha: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
19+
linter_check:
20+
runs-on: ubuntu-24.04
21+
needs: [check_code_changes]
22+
steps:
23+
- name: Checkout repo
24+
if: needs.check_code_changes.outputs.has_code_changes == 'true'
25+
uses: actions/checkout@v3
26+
- name: Setup Python
27+
if: needs.check_code_changes.outputs.has_code_changes == 'true'
28+
uses: actions/setup-python@v4
29+
with:
30+
python-version: "3.10"
31+
cache: "pip"
32+
- run: pip install yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`
3333

34-
- name: Check no TORCH_PIN
35-
if: >
36-
(github.event_name == 'push' && github.event.ref == 'refs/heads/master') &&
37-
needs.check_code_changes.outputs.has_code_changes == 'true'
38-
shell: bash
39-
run: |
40-
TORCH_PIN=./.torch_pin
41-
if [[ -f "${TORCH_PIN}" ]]; then
42-
echo "Please remove ${TORCH_PIN} before landing."
43-
exit 1
44-
else
45-
echo "No ${TORCH_PIN} found, safe to land..."
46-
fi
47-
- name: Check .cc file extension
48-
shell: bash
49-
run: |
50-
# Find *.cc files recursively in the current directory, limiting to files only.
51-
found_files=$(find . -type f -name "*.cc")
34+
- name: Check no TORCH_PIN
35+
if: >
36+
(github.event_name == 'push' && github.event.ref == 'refs/heads/master') &&
37+
needs.check_code_changes.outputs.has_code_changes == 'true'
38+
shell: bash
39+
run: |
40+
TORCH_PIN=./.torch_pin
41+
if [[ -f "${TORCH_PIN}" ]]; then
42+
echo "Please remove ${TORCH_PIN} before landing."
43+
exit 1
44+
else
45+
echo "No ${TORCH_PIN} found, safe to land..."
46+
fi
47+
- name: Check .cc file extension
48+
shell: bash
49+
run: |
50+
# Find *.cc files recursively in the current directory, limiting to files only.
51+
found_files=$(find . -type f -name "*.cc")
5252
53-
# Check if any files were found.
54-
if [ -n "$found_files" ]; then
55-
echo "Found *.cc files:"
56-
echo "$found_files"
57-
echo "Please rename them to *.cpp for consistency."
58-
exit 1
59-
else
60-
echo "PASSED *.cc file extension check"
61-
fi
62-
- name: Run clang-format
63-
if: needs.check_code_changes.outputs.has_code_changes == 'true'
64-
shell: bash
65-
env:
66-
CLANG_FORMAT: clang-format-16
67-
run: |
68-
sudo apt-get update
69-
sudo apt install -y "${CLANG_FORMAT}"
70-
git_status=$(git status --porcelain)
71-
if [[ $git_status ]]; then
72-
echo "Checkout code is not clean"
73-
echo "${git_status}"
74-
exit 1
75-
fi
53+
# Check if any files were found.
54+
if [ -n "$found_files" ]; then
55+
echo "Found *.cc files:"
56+
echo "$found_files"
57+
echo "Please rename them to *.cpp for consistency."
58+
exit 1
59+
else
60+
echo "PASSED *.cc file extension check"
61+
fi
62+
- name: Run clang-format
63+
if: needs.check_code_changes.outputs.has_code_changes == 'true'
64+
shell: bash
65+
env:
66+
CLANG_FORMAT: clang-format-16
67+
run: |
68+
sudo apt-get update
69+
sudo apt install -y "${CLANG_FORMAT}"
70+
git_status=$(git status --porcelain)
71+
if [[ $git_status ]]; then
72+
echo "Checkout code is not clean"
73+
echo "${git_status}"
74+
exit 1
75+
fi
7676
77-
find . -name '*.cpp' -o -name '*.h' -o -name '*.cc' | xargs "${CLANG_FORMAT}" -i -style=file
78-
git_status=$(git status --porcelain)
79-
if [[ $git_status ]]; then
80-
git diff
81-
echo "${CLANG_FORMAT} recommends the changes above, please manually apply them OR automatically apply the changes "
82-
echo "by running \"${CLANG_FORMAT} -i -style=file /PATH/TO/foo.cpp\" to the following files"
83-
echo "${git_status}"
84-
exit 1
85-
else
86-
echo "PASSED C++ format"
87-
fi
88-
- name: Run yapf
89-
if: needs.check_code_changes.outputs.has_code_changes == 'true'
90-
shell: bash
91-
run: |
92-
git_status=$(git status --porcelain)
93-
if [[ $git_status ]]; then
94-
echo "Checkout code is not clean"
95-
echo "${git_status}"
96-
exit 1
97-
fi
77+
find . -name '*.cpp' -o -name '*.h' -o -name '*.cc' | xargs "${CLANG_FORMAT}" -i -style=file
78+
git_status=$(git status --porcelain)
79+
if [[ $git_status ]]; then
80+
git diff
81+
echo "${CLANG_FORMAT} recommends the changes above, please manually apply them OR automatically apply the changes "
82+
echo "by running \"${CLANG_FORMAT} -i -style=file /PATH/TO/foo.cpp\" to the following files"
83+
echo "${git_status}"
84+
exit 1
85+
else
86+
echo "PASSED C++ format"
87+
fi
88+
- name: Run yapf
89+
if: needs.check_code_changes.outputs.has_code_changes == 'true'
90+
shell: bash
91+
run: |
92+
git_status=$(git status --porcelain)
93+
if [[ $git_status ]]; then
94+
echo "Checkout code is not clean"
95+
echo "${git_status}"
96+
exit 1
97+
fi
9898
99-
yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/ torchax/
100-
git_status=$(git status --porcelain)
101-
if [[ $git_status ]]; then
102-
git diff
103-
echo "yapf recommends the changes above, please manually apply them OR automatically apply the changes "
104-
echo "by running `yapf -i /PATH/TO/foo.py` to the following files"
105-
echo "${git_status}"
106-
exit 1
107-
else
108-
echo "PASSED Python format"
109-
fi
110-
- name: Report no code changes
111-
if: needs.check_code_changes.outputs.has_code_changes == 'false'
112-
run: |
113-
echo "No code changes were detected that require running the full test suite."
99+
yapf -i -r *.py test/ scripts/ torch_xla/ benchmarks/
100+
git_status=$(git status --porcelain)
101+
if [[ $git_status ]]; then
102+
git diff
103+
echo "yapf recommends the changes above, please manually apply them OR automatically apply the changes "
104+
echo "by running `yapf -i /PATH/TO/foo.py` to the following files"
105+
echo "${git_status}"
106+
exit 1
107+
else
108+
echo "PASSED Python format"
109+
fi
110+
- name: Report no code changes
111+
if: needs.check_code_changes.outputs.has_code_changes == 'false'
112+
run: |
113+
echo "No code changes were detected that require running the full test suite."

.github/workflows/torchax.yml

Lines changed: 0 additions & 73 deletions
This file was deleted.

setup.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@
119119
_jaxlib_version = '0.8.0'
120120
_jax_date = '20251001' # Date for jax and jaxlib.
121121

122+
_torchax_version = '0.0.7' # likely stay the same
123+
122124
if USE_NIGHTLY:
123125
_libtpu_version += f".dev{_libtpu_date}+nightly"
124126
_jax_version += f'.dev{_jax_date}'
@@ -335,19 +337,6 @@ def build_extension(self, ext: Extension) -> None:
335337
# 1. Find `torch_xla` and its subpackages automatically from the root.
336338
packages_to_include = find_packages(include=['torch_xla', 'torch_xla.*'])
337339

338-
# 2. Explicitly find the contents of the nested `torchax` package.
339-
# Find all sub-packages within the torchax directory (e.g., 'ops').
340-
torchax_source_dir = 'torchax/torchax'
341-
torchax_subpackages = find_packages(where=torchax_source_dir)
342-
# Construct the full list of packages, starting with the top-level
343-
# 'torchax' and adding all the discovered sub-packages.
344-
packages_to_include.extend(['torchax'] +
345-
['torchax.' + pkg for pkg in torchax_subpackages])
346-
347-
# 3. The package_dir mapping explicitly tells setuptools where the 'torchax'
348-
# package's source code begins. `torch_xla` source code is inferred.
349-
package_dir_mapping = {'torchax': torchax_source_dir}
350-
351340

352341
class Develop(develop.develop):
353342
"""
@@ -372,7 +361,7 @@ def link_packages(self):
372361
and `.pth` files. setuptools uses `.egg-link` by default. However, `.egg-link`
373362
only supports linking a single directory containg one editable package.
374363
This function removes the `.egg-link` file and generates a `.pth` file that can
375-
be used to link multiple packages, in particular, `torch_xla` and `torchax`.
364+
be used to link multiple packages.
376365
377366
Note that this function is only relevant in the editable package development path
378367
(`python setup.py develop`). Nightly and release wheel builds work out of the box
@@ -409,18 +398,13 @@ def link_packages(self):
409398
pth_filename = os.path.join(target_dir, f"{dist_name}.pth")
410399

411400
project_root = os.path.dirname(os.path.abspath(__file__))
412-
paths_to_add = {
413-
project_root, # For `torch_xla`
414-
os.path.abspath(os.path.join(project_root, 'torchax')), # For `torchax`
415-
}
416-
417401
with open(pth_filename, "w", encoding='utf-8') as f:
418-
for path in sorted(paths_to_add):
419-
f.write(path + "\n")
402+
f.write(project_root + "\n")
420403

421404

422405
def _get_jax_install_requirements():
423406
return [
407+
f'torchax=={_torchax_version}',
424408
f'jaxlib=={_jaxlib_version}',
425409
f'jax=={_jax_version}',
426410
]
@@ -452,7 +436,6 @@ def _get_jax_install_requirements():
452436
],
453437
python_requires=">=3.10.0",
454438
packages=packages_to_include,
455-
package_dir=package_dir_mapping,
456439
ext_modules=[
457440
BazelExtension('//:_XLAC.so'),
458441
],

torchax/LICENSE

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)