Skip to content

Commit 156d913

Browse files
tengyifeibhavya01
andauthored
Support editable install with setuptools>=80.0.0 (#9428)
Co-authored-by: Bhavya Bahl <[email protected]>
1 parent 8999ba5 commit 156d913

File tree

6 files changed

+135
-55
lines changed

6 files changed

+135
-55
lines changed

.devcontainer/tpu-contributor/devcontainer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@
2323
]
2424
}
2525
}
26-
}
26+
}

.devcontainer/tpu-internal/devcontainer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
],
99
"containerEnv": {
1010
"BAZEL_REMOTE_CACHE": "1",
11-
"SILO_NAME": "cache-silo-${localEnv:USER}-tpuvm"
11+
"SILO_NAME": "cache-silo-${localEnv:USER}-tpuvm-312"
1212
},
1313
"initializeCommand": "docker pull us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:tpu",
1414
"customizations": {

pyproject.toml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
[build-system]
2+
# These are the packages required to run `setup.py` in an isolated environment.
3+
# Pip will install these *before* executing the build backend.
4+
requires = [
5+
"setuptools>=42",
6+
"wheel",
7+
"requests",
8+
"numpy",
9+
"pyyaml",
10+
]
11+
build-backend = "setuptools.build_meta"
12+
13+
[project]
14+
name = "torch-xla"
15+
description = "XLA bridge for PyTorch"
16+
readme = "README.md"
17+
authors = [
18+
{ name = "PyTorch/XLA Dev Team", email = "[email protected]" },
19+
]
20+
license = { file = "LICENSE" }
21+
requires-python = ">=3.10"
22+
classifiers = [
23+
"Development Status :: 5 - Production/Stable",
24+
"Intended Audience :: Developers",
25+
"Intended Audience :: Science/Research",
26+
"License :: OSI Approved :: BSD License",
27+
"Programming Language :: Python :: 3",
28+
"Programming Language :: Python :: 3.10",
29+
"Programming Language :: Python :: 3.11",
30+
"Programming Language :: Python :: 3.12",
31+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
32+
"Topic :: Software Development :: Libraries :: Python Modules",
33+
]
34+
keywords = ["pytorch", "xla", "tpu", "deep learning", "compiler"]
35+
36+
# This tells build tools to get this info from setup.py instead of this file.
37+
dynamic = [
38+
"version",
39+
"dependencies",
40+
"optional-dependencies",
41+
"entry-points",
42+
"scripts"
43+
]
44+
45+
[project.urls]
46+
Homepage = "https://github.com/pytorch/xla"
47+
Repository = "https://github.com/pytorch/xla"
48+
"Bug Tracker" = "https://github.com/pytorch/xla/issues"

scripts/build_developer.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ cd $_SCRIPT_DIR/..
8080
pip uninstall torch_xla torchax torch_xla2 -y
8181

8282
# Build the wheel too, which is useful for other testing purposes.
83+
rm -f torch_xla.egg-info/SOURCES.txt
8384
python3 setup.py bdist_wheel
8485

8586
# Link the source files for local development.

setup.py

Lines changed: 83 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import contextlib
5353
import distutils.ccompiler
5454
import distutils.command.clean
55+
import importlib.util
5556
import os
5657
import re
5758
import requests
@@ -61,7 +62,13 @@
6162
import tempfile
6263
import zipfile
6364

64-
import build_util
65+
# This gloop imports build_util.py such that it works in Python 3.12's isolated
66+
# build environment while also not contaminating sys.path which breaks bdist_wheel.
67+
_PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
68+
_build_util_path = os.path.join(_PROJECT_DIR, 'build_util.py')
69+
spec = importlib.util.spec_from_file_location('build_util', _build_util_path)
70+
build_util = importlib.util.module_from_spec(spec)
71+
spec.loader.exec_module(build_util)
6572

6673
import platform
6774

@@ -270,15 +277,21 @@ def __init__(self, bazel_target):
270277
class BuildBazelExtension(build_ext.build_ext):
271278
"""A command that runs Bazel to build a C/C++ extension."""
272279

273-
def run(self):
274-
for ext in self.extensions:
275-
self.bazel_build(ext)
276-
command.build_ext.build_ext.run(self) # type: ignore
280+
def build_extension(self, ext: Extension) -> None:
281+
"""
282+
This method is called by setuptools to build a single extension.
283+
We override it to implement our custom Bazel build logic.
284+
"""
285+
if not isinstance(ext, BazelExtension):
286+
# If it's not our custom extension type, let setuptools handle it.
287+
super().build_extension(ext)
288+
return
277289

278-
def bazel_build(self, ext):
290+
# 1. Ensure the temporary build directory exists
279291
if not os.path.exists(self.build_temp):
280292
os.makedirs(self.build_temp)
281293

294+
# 2. Prepare the Bazel command
282295
bazel_argv = [
283296
'bazel', 'build', ext.bazel_target,
284297
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
@@ -288,22 +301,31 @@ def bazel_build(self, ext):
288301
if build_cpp_tests:
289302
bazel_argv.append('//:cpp_tests')
290303

291-
import torch
292-
cxx_abi = os.getenv('CXX_ABI') or getattr(torch._C,
293-
'_GLIBCXX_USE_CXX11_ABI', None)
294-
if cxx_abi is not None:
295-
bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
304+
cxx_abi = os.getenv('CXX_ABI')
305+
if cxx_abi is None:
306+
try:
307+
import torch
308+
cxx_abi = getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
309+
except:
310+
pass
311+
if cxx_abi is None:
312+
# Default to building with C++11 ABI, which has been the case since PyTorch 2.7
313+
cxx_abi = "1"
314+
bazel_argv.append(f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
296315

297316
bazel_argv.extend(build_util.bazel_options_from_env())
298317

318+
# 3. Run the Bazel build
299319
self.spawn(bazel_argv)
300320

321+
# 4. Copy the output file to the location setuptools expects
301322
ext_bazel_bin_path = os.path.join(self.build_temp, 'bazel-bin', ext.relpath,
302323
ext.target_name)
303324
ext_dest_path = self.get_ext_fullpath(ext.name)
304325
ext_dest_dir = os.path.dirname(ext_dest_path)
305326
if not os.path.exists(ext_dest_dir):
306327
os.makedirs(ext_dest_dir)
328+
307329
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)
308330

309331

@@ -313,17 +335,28 @@ def bazel_build(self, ext):
313335
long_description = f.read()
314336

315337
# Finds torch_xla and its subpackages
316-
packages_to_include = find_packages(include=['torch_xla*'])
317-
# Explicitly add torchax
318-
packages_to_include.extend(find_packages(where='torchax', include=['torchax*']))
338+
# 1. Find `torch_xla` and its subpackages automatically from the root.
339+
packages_to_include = find_packages(include=['torch_xla', 'torch_xla.*'])
340+
341+
# 2. Explicitly find the contents of the nested `torchax` package.
342+
# Find all sub-packages within the torchax directory (e.g., 'ops').
343+
torchax_source_dir = 'torchax/torchax'
344+
torchax_subpackages = find_packages(where=torchax_source_dir)
345+
# Construct the full list of packages, starting with the top-level
346+
# 'torchax' and adding all the discovered sub-packages.
347+
packages_to_include.extend(['torchax'] +
348+
['torchax.' + pkg for pkg in torchax_subpackages])
319349

320-
# Map the top-level 'torchax' package name to its source location
321-
torchax_dir = os.path.join(cwd, 'torchax')
322-
package_dir_mapping = {'torch_xla': os.path.join(cwd, 'torch_xla')}
323-
package_dir_mapping['torchax'] = os.path.join(torchax_dir, 'torchax')
350+
# 3. The package_dir mapping explicitly tells setuptools where the 'torchax'
351+
# package's source code begins. `torch_xla` source code is inferred.
352+
package_dir_mapping = {'torchax': torchax_source_dir}
324353

325354

326355
class Develop(develop.develop):
356+
"""
357+
Custom develop command to build C++ extensions and create a .pth file
358+
for a multi-package editable install.
359+
"""
327360

328361
def run(self):
329362
# Build the C++ extension
@@ -348,44 +381,42 @@ def link_packages(self):
348381
(`python setup.py develop`). Nightly and release wheel builds work out of the box
349382
without egg-link/pth.
350383
"""
384+
import glob
385+
351386
# Ensure paths like self.install_dir are set
352387
self.ensure_finalized()
353388

354-
# Get the site-packages directory
355-
target_dir = self.install_dir
356-
357-
# Remove the standard .egg-link file
358-
# It's usually named based on the distribution name
359389
dist_name = self.distribution.get_name()
360-
egg_link_file = os.path.join(target_dir, dist_name + '.egg-link')
361-
if os.path.exists(egg_link_file):
362-
print(f"Removing default egg-link file: {egg_link_file}")
363-
try:
364-
os.remove(egg_link_file)
365-
except OSError as e:
366-
print(f"Warning: Could not remove {egg_link_file}: {e}")
367-
368-
# Create our custom .pth file with specific paths
369-
cwd = os.path.dirname(__file__)
370-
# Path containing 'torch_xla' package source: ROOT
371-
path_for_torch_xla = os.path.abspath(cwd)
372-
# Path containing 'torchax' package source: ROOT/torchax
373-
path_for_torchax = os.path.abspath(os.path.join(cwd, 'torchax'))
374-
375-
paths_to_add = {path_for_torch_xla, path_for_torchax}
376-
377-
# Construct a suitable .pth filename (PEP 660 style is good practice)
378-
version = self.distribution.get_version()
379-
# Sanitize name and version for filename (replace runs of non-alphanumeric chars with '-')
380-
sanitized_name = re.sub(r"[^a-zA-Z0-9.]+", "_", dist_name)
381-
sanitized_version = re.sub(r"[^a-zA-Z0-9.]+", "_", version)
382-
pth_filename = os.path.join(
383-
target_dir, f"__editable_{sanitized_name}_{sanitized_version}.pth")
384-
385-
# Ensure site-packages exists
386-
os.makedirs(target_dir, exist_ok=True)
387-
388-
# Write the paths to the .pth file, one per line
390+
install_cmd = self.get_finalized_command('install')
391+
target_dir = install_cmd.install_lib
392+
assert target_dir is not None
393+
394+
# Use glob to robustly find and remove the conflicting files.
395+
# This is safer than trying to guess the exact sanitized filename.
396+
safe_name_part = re.sub(r"[^a-zA-Z0-9]+", "_", dist_name)
397+
398+
for pattern in [
399+
# Remove `.pth` files generated in Python 3.12.
400+
f"__editable__.*{safe_name_part}*.pth",
401+
f"__editable___*{safe_name_part}*_finder.py",
402+
# Also remove the legacy egg-link format.
403+
f"{dist_name}.egg-link"
404+
]:
405+
for filepath in glob.glob(os.path.join(target_dir, pattern)):
406+
print(f"Cleaning up conflicting install file: {filepath}")
407+
with contextlib.suppress(OSError):
408+
os.remove(filepath)
409+
410+
# Finally, create our own simple, multi-path .pth file.
411+
# We name it simply, e.g., "torch_xla.pth".
412+
pth_filename = os.path.join(target_dir, f"{dist_name}.pth")
413+
414+
project_root = os.path.dirname(os.path.abspath(__file__))
415+
paths_to_add = {
416+
project_root, # For `torch_xla`
417+
os.path.abspath(os.path.join(project_root, 'torchax')), # For `torchax`
418+
}
419+
389420
with open(pth_filename, "w", encoding='utf-8') as f:
390421
for path in sorted(paths_to_add):
391422
f.write(path + "\n")

torch_xla/experimental/gru.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(self, *args, **kwargs):
9999
super().__init__(*args, **kwargs)
100100

101101
def forward(self, input, hx=None):
102-
"""
102+
r"""
103103
Args:
104104
input: Tensor of shape (seq_len, batch, input_size)
105105
hx: Optional initial hidden state of shape (num_layers, batch, hidden_size).

0 commit comments

Comments
 (0)