
Commit be27dbf

janeyx99 authored and pytorchmergebot committed
Enable CPP/CUDAExtension with py_limited_api for python agnosticism (pytorch#138088)
Getting tested with ao, but now there is a real test I added.

## What does this PR do?

We want to allow custom PyTorch extensions to build one wheel for multiple Python versions, in other words, to achieve python agnosticism. It turns out that setuptools/Python already provides such a way! Namely, if the user promises to use only the Python limited API in their extension, they can pass `py_limited_api` to their Extension class and to the bdist_wheel command (with a min python version) in order to build 1 wheel that will suffice across multiple Python versions.

Sounds lovely! Why don't people do that already with PyTorch? Well, 2 things. First, this workflow is hardly documented (even searching for python agnostic specifically does not reveal many answers), so I'd expect that people simply don't know about it. But even if they did, _PyTorch_ custom extensions would still not work, because we always link torch_python, which does not abide by py_limited_api rules.

So this is where this PR comes in! We respect when the user specifies py_limited_api and skip linking torch_python under that condition, allowing users to enroll in the functionality I just described.

## How do I know this PR works?

I manually tested my silly little ultra_norm locally (with `import python_agnostic`) and wrote a test case for the extension showing that:
- torch_python doesn't show up in the ldd tree
- no Py- symbols show up

It may be a little confusing that our test case is actually python-free (more clean than python-agnostic), but it is sufficient (and not necessary) towards showing that this change works.

Pull Request resolved: pytorch#138088
Approved by: https://github.com/ezyang, https://github.com/albanD
1 parent fb02b40 commit be27dbf
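
In outline, the recipe this commit enables looks like the sketch below. It mirrors the test extension added in this PR; the package name `my_ext` and its source path are hypothetical placeholders, not part of the commit.

```python
# Hedged sketch of a python-agnostic build, modeled on this PR's test setup.py.
# "my_ext" and "my_ext/csrc/ops.cpp" are illustrative placeholders.
from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="my_ext",
    ext_modules=[
        CppExtension(
            "my_ext._C",
            sources=["my_ext/csrc/ops.cpp"],
            # the promise: this extension only uses the Python limited API,
            # so torch_python must not (and, with this PR, will not) be linked
            py_limited_api=True,
        )
    ],
    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
    # tag the wheel cp39-abi3: one wheel for every CPython >= 3.9
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
```

With such a setup, `python setup.py bdist_wheel` should produce a single `*-cp39-abi3-*.whl` usable across Python versions.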

File tree

8 files changed: +233, −18 lines


test/cpp_extensions/open_registration_extension/setup.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def run(self):
     name=PACKAGE_NAME,
     version=version,
     author="PyTorch Core Team",
-    description="Example for PyTorch out of tree regitration",
+    description="Example for PyTorch out of tree registration",
     packages=find_packages(exclude=("test",)),
     package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]},
     install_requires=[
test/cpp_extensions/python_agnostic_extension/python_agnostic/__init__.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from pathlib import Path
+
+import torch
+
+
+so_files = list(Path(__file__).parent.glob("_C*.so"))
+assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
+torch.ops.load_library(so_files[0])
+
+from . import ops
+
+
+# ----------------------------------------------------------------------------- #
+# We've reached the end of what is normal in __init__ files.
+# The following is used to assert the ultra_norm op is properly loaded and
+# calculates correct results upon import of this extension.
+
+inputs = [
+    torch.tensor([1.0, 2.0, 3.0], device="cuda"),
+    torch.tensor([-4.0, -5.0, -6.0], device="cuda"),
+]
+
+assert torch.equal(
+    ops.ultra_norm(inputs),
+    torch.norm(torch.tensor([1.0, 2.0, 3.0, -4.0, -5.0, -6.0], device="cuda")),
+)
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#include <ATen/ops/_foreach_norm_native.h>
+#include <ATen/ops/cat_cuda_dispatch.h>
+#include <ATen/ops/norm_cuda_dispatch.h>
+#include <ATen/ops/unsqueeze.h>
+#include <torch/extension.h>
+
+at::Tensor ultra_norm(at::TensorList inputs) {
+  auto res = at::native::foreach_tensor_norm_cuda(inputs);
+  std::vector<at::Tensor> unsqueezed;
+  for (const auto& scalar_tensor : res) {
+    unsqueezed.push_back(at::unsqueeze(scalar_tensor, 0));
+  }
+  auto stacked = at::cuda::cat(unsqueezed);
+  return at::cuda::norm(stacked, 2, at::IntArrayRef{}, false);
+}
+
+TORCH_LIBRARY_IMPL(python_agnostic, CUDA, m) {
+  m.impl("python_agnostic::ultra_norm", &ultra_norm);
+}
test/cpp_extensions/python_agnostic_extension/python_agnostic/ops.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from typing import List
+
+import torch
+from torch import Tensor
+
+
+lib = torch.library._scoped_library("python_agnostic", "FRAGMENT")
+lib.define("ultra_norm(Tensor[] inputs) -> Tensor")
+
+
+def ultra_norm(inputs: List[Tensor]) -> Tensor:
+    """
+    Computes the ultra-L2-norm of a list of tensors via computing the norm of norms.
+
+    Assumes:
+    - inputs should not be empty
+    - all tensors in inputs should be on the same device and have the same dtype
+
+    Args:
+        inputs: list of torch.tensors
+
+    Returns:
+        Scalar torch.tensor of shape ()
+
+    """
+    return torch.ops.python_agnostic.ultra_norm.default(inputs)
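
For orientation, consuming the built extension could look like the following usage sketch. It assumes the abi3 wheel is installed and a CUDA device is available, since `ultra_norm` is registered only for CUDA and the package asserts on CUDA tensors at import time.

```python
import torch

import python_agnostic  # __init__ loads _C*.so via torch.ops.load_library

xs = [torch.ones(3, device="cuda"), 2 * torch.ones(3, device="cuda")]
out = python_agnostic.ops.ultra_norm(xs)  # 0-dim tensor: norm of per-tensor norms
```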
test/cpp_extensions/python_agnostic_extension/setup.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import distutils.command.clean
+import shutil
+from pathlib import Path
+
+from setuptools import setup
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+ROOT_DIR = Path(__file__).parent
+CSRC_DIR = ROOT_DIR / "python_agnostic" / "csrc"
+
+
+class clean(distutils.command.clean.clean):
+    def run(self):
+        # Run default behavior first
+        distutils.command.clean.clean.run(self)
+
+        # Remove extension
+        for path in (ROOT_DIR / "python_agnostic").glob("**/*.so"):
+            path.unlink()
+        # Remove build and dist and egg-info directories
+        dirs = [
+            ROOT_DIR / "build",
+            ROOT_DIR / "dist",
+            ROOT_DIR / "python_agnostic.egg-info",
+        ]
+        for path in dirs:
+            if path.exists():
+                shutil.rmtree(str(path), ignore_errors=True)
+
+
+def get_extension():
+    extra_compile_args = {
+        "cxx": ["-fdiagnostics-color=always"],
+    }
+
+    sources = list(CSRC_DIR.glob("**/*.cu"))
+
+    return [
+        CUDAExtension(
+            "python_agnostic._C",
+            sources=sorted(str(s) for s in sources),
+            py_limited_api=True,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=[],
+        )
+    ]
+
+
+setup(
+    name="python_agnostic",
+    version="0.0",
+    author="PyTorch Core Team",
+    description="Example of python agnostic extension",
+    ext_modules=get_extension(),
+    cmdclass={
+        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
+        "clean": clean,
+    },
+    options={"bdist_wheel": {"py_limited_api": "cp39"}},
+)
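
A quick way to confirm the built wheel really is version-agnostic is to inspect its tag after building. The snippet below is a minimal sketch of the check the new test performs, run from this extension's folder after `python setup.py bdist_wheel`:

```python
from pathlib import Path

wheels = list(Path("dist").glob("*.whl"))
assert len(wheels) == 1, wheels
# cp39-abi3 means: built against the limited API of CPython 3.9,
# installable on any later CPython as well
assert "cp39-abi3" in wheels[0].name, wheels[0].name
```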

test/run_test.py

Lines changed: 14 additions & 9 deletions
@@ -1031,18 +1031,23 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
     # Build the test cpp extensions modules
     shell_env = os.environ.copy()
     shell_env["USE_NINJA"] = str(1 if use_ninja else 0)
-    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
-    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env)
+    install_cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
+    wheel_cmd = [sys.executable, "setup.py", "bdist_wheel"]
+    return_code = shell(install_cmd, cwd=cpp_extensions_test_dir, env=shell_env)
     if return_code != 0:
         return return_code
     if sys.platform != "win32":
-        return_code = shell(
-            cmd,
-            cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"),
-            env=shell_env,
-        )
-        if return_code != 0:
-            return return_code
+        exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
+        if TEST_CUDA:
+            exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
+        for cmd, extension_dir in exts_to_build:
+            return_code = shell(
+                cmd,
+                cwd=os.path.join(cpp_extensions_test_dir, extension_dir),
+                env=shell_env,
+            )
+            if return_code != 0:
+                return return_code

     # "install" the test modules and run tests
     python_path = os.environ.get("PYTHONPATH", "")

test/test_cpp_extensions_aot.py

Lines changed: 46 additions & 0 deletions
@@ -2,8 +2,11 @@

 import os
 import re
+import subprocess
+import sys
 import unittest
 from itertools import repeat
+from pathlib import Path
 from typing import get_args, get_origin, Union

 import torch

@@ -13,6 +16,7 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (
     IS_WINDOWS,
+    shell,
     skipIfTorchDynamo,
     xfailIfTorchDynamo,
 )

@@ -164,6 +168,48 @@ def test_cuda_dlink_libs(self):
         test = cuda_dlink.add(a, b)
         self.assertEqual(test, ref)

+    @unittest.skipIf(not TEST_CUDA, "python_agnostic is a CUDA extension + needs CUDA")
+    @unittest.skipIf(not common.IS_LINUX, "test requires linux tools ldd and nm")
+    def test_python_agnostic(self):
+        # For this test, run_test.py will call `python setup.py bdist_wheel` in the
+        # cpp_extensions/python_agnostic_extension folder, where the extension and
+        # setup calls specify py_limited_api to `True`. To approximate that the
+        # extension is indeed python agnostic, we test
+        # a. The extension wheel name contains "cp39-abi3", meaning the wheel
+        #    should be runnable for any Python 3 version after and including 3.9
+        # b. The produced shared library does not have libtorch_python.so as a
+        #    dependency from the output of "ldd _C.so"
+        # c. The .so does not need any python related symbols. We approximate
+        #    this by running "nm -u _C.so" and grepping that nothing starts with "Py"
+
+        dist_root = os.path.join("cpp_extensions", "python_agnostic_extension", "dist")
+        matches = list(Path(dist_root).glob("*.whl"))
+        self.assertEqual(len(matches), 1, msg=str(matches))
+        whl_file = matches[0]
+        self.assertRegex(str(whl_file), r".*python_agnostic-0\.0-cp39-abi3-.*\.whl")
+
+        build_root = os.path.join(
+            "cpp_extensions", "python_agnostic_extension", "build"
+        )
+        matches = list(Path(build_root).glob("**/*.so"))
+        self.assertEqual(len(matches), 1, msg=str(matches))
+        so_file = matches[0]
+        lddtree = subprocess.check_output(["ldd", so_file]).decode("utf-8")
+        self.assertFalse("torch_python" in lddtree)
+
+        missing_symbols = subprocess.check_output(["nm", "-u", so_file]).decode("utf-8")
+        self.assertFalse("Py" in missing_symbols)
+
+        # finally, clean up the folder
+        cmd = [sys.executable, "setup.py", "clean"]
+        return_code = shell(
+            cmd,
+            cwd=os.path.join("cpp_extensions", "python_agnostic_extension"),
+            env=os.environ.copy(),
+        )
+        if return_code != 0:
+            return return_code


 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestPybindTypeCasters(common.TestCase):

torch/utils/cpp_extension.py

Lines changed: 34 additions & 8 deletions
@@ -812,11 +812,11 @@ def win_wrap_ninja_compile(sources,
         output_dir = os.path.abspath(output_dir)

         # Note [Absolute include_dirs]
-        # Convert relative path in self.compiler.include_dirs to absolute path if any,
-        # For ninja build, the build location is not local, the build happens
-        # in a in script created build folder, relative path lost their correctness.
+        # Convert relative path in self.compiler.include_dirs to absolute path if any.
+        # For ninja build, the build location is not local, but instead, the build happens
+        # in a script-created build folder. Thus, relative paths lose their correctness.
         # To be consistent with jit extension, we allow user to enter relative include_dirs
-        # in setuptools.setup, and we convert the relative path to absolute path here
+        # in setuptools.setup, and we convert the relative path to absolute path here.
         convert_to_absolute_paths_inplace(self.compiler.include_dirs)

         _, objects, extra_postargs, pp_opts, _ = \

@@ -964,6 +964,15 @@ def CppExtension(name, sources, *args, **kwargs):
     constructor. Full list arguments can be found at
     https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

+    .. note::
+        The PyTorch python API (as provided in libtorch_python) cannot be built
+        with the flag ``py_limited_api=True``. When this flag is passed, it is
+        the user's responsibility in their library to not use APIs from
+        libtorch_python (in particular pytorch/python bindings) and to only use
+        APIs from libtorch (aten objects, operators and the dispatcher). For
+        example, to give access to custom ops from python, the library should
+        register the ops through the dispatcher.
+
     Example:
         >>> # xdoctest: +SKIP
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)

@@ -994,7 +1003,9 @@ def CppExtension(name, sources, *args, **kwargs):
     libraries.append('c10')
     libraries.append('torch')
     libraries.append('torch_cpu')
-    libraries.append('torch_python')
+    if not kwargs.get('py_limited_api', False):
+        # torch_python uses more than the python limited api
+        libraries.append('torch_python')
     if IS_WINDOWS and platform.machine().lower() != "arm64":
         libraries.append("sleef")

@@ -1017,6 +1028,15 @@ def CUDAExtension(name, sources, *args, **kwargs):
     constructor. Full list arguments can be found at
     https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

+    .. note::
+        The PyTorch python API (as provided in libtorch_python) cannot be built
+        with the flag ``py_limited_api=True``. When this flag is passed, it is
+        the user's responsibility in their library to not use APIs from
+        libtorch_python (in particular pytorch/python bindings) and to only use
+        APIs from libtorch (aten objects, operators and the dispatcher). For
+        example, to give access to custom ops from python, the library should
+        register the ops through the dispatcher.
+
     Example:
         >>> # xdoctest: +SKIP
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)

@@ -1041,7 +1061,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
     By default the extension will be compiled to run on all archs of the cards visible during the
     building process of the extension, plus PTX. If down the road a new card is installed the
     extension may need to be recompiled. If a visible card has a compute capability (CC) that's
-    newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch
+    newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch
     will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
     support (see below for details on PTX).

@@ -1085,7 +1105,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
     An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore.
     `Relocatable device code` is less optimized so it needs to be used only on object files that need it.
     Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
-    help reduce the protentional perf degradation of `-rdc`.
+    helps reduce the protentional perf degradation of `-rdc`.
     Note that it needs to be used at both steps to be useful.

     If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.

@@ -1114,7 +1134,9 @@ def CUDAExtension(name, sources, *args, **kwargs):
     libraries.append('c10')
     libraries.append('torch')
     libraries.append('torch_cpu')
-    libraries.append('torch_python')
+    if not kwargs.get('py_limited_api', False):
+        # torch_python uses more than the python limited api
+        libraries.append('torch_python')
     if IS_HIP_EXTENSION:
         libraries.append('amdhip64')
         libraries.append('c10_hip')

@@ -1381,6 +1403,10 @@ def _get_pybind11_abi_build_flags():
     # that can cause a hard to debug segfaults.
     # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
     # captured during PyTorch native library compilation in torch/csrc/Module.cpp
+    #
+    # Note that these flags don't have side effects even if the PyTorch extension does not
+    # require nor use pybind, so we do not do anything differently for them in the py_limited_api
+    # case.

     abi_cflags = []
     for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
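
The behavioral change above is observable without compiling anything, since `CppExtension`/`CUDAExtension` only assemble the link line. Below is a minimal sketch; the name `demo._C` and its source file are hypothetical, and the source need not exist just to construct the Extension object:

```python
from torch.utils.cpp_extension import CppExtension

ext = CppExtension("demo._C", sources=["demo.cpp"], py_limited_api=True)
assert "torch_python" not in ext.libraries  # skipped under py_limited_api
assert "torch" in ext.libraries             # libtorch is still linked
```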
