52
52
import contextlib
53
53
import distutils .ccompiler
54
54
import distutils .command .clean
55
+ import importlib .util
55
56
import os
56
57
import re
57
58
import requests
61
62
import tempfile
62
63
import zipfile
63
64
64
- import build_util
65
+ # This gloop imports build_util.py such that it works in Python 3.12's isolated
66
+ # build environment while also not contaminating sys.path which breaks bdist_wheel.
67
+ _PROJECT_DIR = os .path .dirname (os .path .abspath (__file__ ))
68
+ _build_util_path = os .path .join (_PROJECT_DIR , 'build_util.py' )
69
+ spec = importlib .util .spec_from_file_location ('build_util' , _build_util_path )
70
+ build_util = importlib .util .module_from_spec (spec )
71
+ spec .loader .exec_module (build_util )
65
72
66
73
import platform
67
74
@@ -270,15 +277,21 @@ def __init__(self, bazel_target):
270
277
class BuildBazelExtension (build_ext .build_ext ):
271
278
"""A command that runs Bazel to build a C/C++ extension."""
272
279
273
- def run (self ):
274
- for ext in self .extensions :
275
- self .bazel_build (ext )
276
- command .build_ext .build_ext .run (self ) # type: ignore
280
+ def build_extension (self , ext : Extension ) -> None :
281
+ """
282
+ This method is called by setuptools to build a single extension.
283
+ We override it to implement our custom Bazel build logic.
284
+ """
285
+ if not isinstance (ext , BazelExtension ):
286
+ # If it's not our custom extension type, let setuptools handle it.
287
+ super ().build_extension (ext )
288
+ return
277
289
278
- def bazel_build ( self , ext ):
290
+ # 1. Ensure the temporary build directory exists
279
291
if not os .path .exists (self .build_temp ):
280
292
os .makedirs (self .build_temp )
281
293
294
+ # 2. Prepare the Bazel command
282
295
bazel_argv = [
283
296
'bazel' , 'build' , ext .bazel_target ,
284
297
f"--symlink_prefix={ os .path .join (self .build_temp , 'bazel-' )} "
@@ -288,22 +301,31 @@ def bazel_build(self, ext):
288
301
if build_cpp_tests :
289
302
bazel_argv .append ('//:cpp_tests' )
290
303
291
- import torch
292
- cxx_abi = os .getenv ('CXX_ABI' ) or getattr (torch ._C ,
293
- '_GLIBCXX_USE_CXX11_ABI' , None )
294
- if cxx_abi is not None :
295
- bazel_argv .append (f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={ int (cxx_abi )} ' )
304
+ cxx_abi = os .getenv ('CXX_ABI' )
305
+ if cxx_abi is None :
306
+ try :
307
+ import torch
308
+ cxx_abi = getattr (torch ._C , '_GLIBCXX_USE_CXX11_ABI' , None )
309
+ except :
310
+ pass
311
+ if cxx_abi is None :
312
+ # Default to building with C++11 ABI, which has been the case since PyTorch 2.7
313
+ cxx_abi = "1"
314
+ bazel_argv .append (f'--cxxopt=-D_GLIBCXX_USE_CXX11_ABI={ int (cxx_abi )} ' )
296
315
297
316
bazel_argv .extend (build_util .bazel_options_from_env ())
298
317
318
+ # 3. Run the Bazel build
299
319
self .spawn (bazel_argv )
300
320
321
+ # 4. Copy the output file to the location setuptools expects
301
322
ext_bazel_bin_path = os .path .join (self .build_temp , 'bazel-bin' , ext .relpath ,
302
323
ext .target_name )
303
324
ext_dest_path = self .get_ext_fullpath (ext .name )
304
325
ext_dest_dir = os .path .dirname (ext_dest_path )
305
326
if not os .path .exists (ext_dest_dir ):
306
327
os .makedirs (ext_dest_dir )
328
+
307
329
shutil .copyfile (ext_bazel_bin_path , ext_dest_path )
308
330
309
331
@@ -313,17 +335,28 @@ def bazel_build(self, ext):
313
335
long_description = f .read ()
314
336
315
337
# Finds torch_xla and its subpackages
316
- packages_to_include = find_packages (include = ['torch_xla*' ])
317
- # Explicitly add torchax
318
- packages_to_include .extend (find_packages (where = 'torchax' , include = ['torchax*' ]))
338
+ # 1. Find `torch_xla` and its subpackages automatically from the root.
339
+ packages_to_include = find_packages (include = ['torch_xla' , 'torch_xla.*' ])
340
+
341
+ # 2. Explicitly find the contents of the nested `torchax` package.
342
+ # Find all sub-packages within the torchax directory (e.g., 'ops').
343
+ torchax_source_dir = 'torchax/torchax'
344
+ torchax_subpackages = find_packages (where = torchax_source_dir )
345
+ # Construct the full list of packages, starting with the top-level
346
+ # 'torchax' and adding all the discovered sub-packages.
347
+ packages_to_include .extend (['torchax' ] +
348
+ ['torchax.' + pkg for pkg in torchax_subpackages ])
319
349
320
- # Map the top-level 'torchax' package name to its source location
321
- torchax_dir = os .path .join (cwd , 'torchax' )
322
- package_dir_mapping = {'torch_xla' : os .path .join (cwd , 'torch_xla' )}
323
- package_dir_mapping ['torchax' ] = os .path .join (torchax_dir , 'torchax' )
350
+ # 3. The package_dir mapping explicitly tells setuptools where the 'torchax'
351
+ # package's source code begins. `torch_xla` source code is inferred.
352
+ package_dir_mapping = {'torchax' : torchax_source_dir }
324
353
325
354
326
355
class Develop (develop .develop ):
356
+ """
357
+ Custom develop command to build C++ extensions and create a .pth file
358
+ for a multi-package editable install.
359
+ """
327
360
328
361
def run (self ):
329
362
# Build the C++ extension
@@ -348,44 +381,42 @@ def link_packages(self):
348
381
(`python setup.py develop`). Nightly and release wheel builds work out of the box
349
382
without egg-link/pth.
350
383
"""
384
+ import glob
385
+
351
386
# Ensure paths like self.install_dir are set
352
387
self .ensure_finalized ()
353
388
354
- # Get the site-packages directory
355
- target_dir = self .install_dir
356
-
357
- # Remove the standard .egg-link file
358
- # It's usually named based on the distribution name
359
389
dist_name = self .distribution .get_name ()
360
- egg_link_file = os .path .join (target_dir , dist_name + '.egg-link' )
361
- if os .path .exists (egg_link_file ):
362
- print (f"Removing default egg-link file: { egg_link_file } " )
363
- try :
364
- os .remove (egg_link_file )
365
- except OSError as e :
366
- print (f"Warning: Could not remove { egg_link_file } : { e } " )
367
-
368
- # Create our custom .pth file with specific paths
369
- cwd = os .path .dirname (__file__ )
370
- # Path containing 'torch_xla' package source: ROOT
371
- path_for_torch_xla = os .path .abspath (cwd )
372
- # Path containing 'torchax' package source: ROOT/torchax
373
- path_for_torchax = os .path .abspath (os .path .join (cwd , 'torchax' ))
374
-
375
- paths_to_add = {path_for_torch_xla , path_for_torchax }
376
-
377
- # Construct a suitable .pth filename (PEP 660 style is good practice)
378
- version = self .distribution .get_version ()
379
- # Sanitize name and version for filename (replace runs of non-alphanumeric chars with '-')
380
- sanitized_name = re .sub (r"[^a-zA-Z0-9.]+" , "_" , dist_name )
381
- sanitized_version = re .sub (r"[^a-zA-Z0-9.]+" , "_" , version )
382
- pth_filename = os .path .join (
383
- target_dir , f"__editable_{ sanitized_name } _{ sanitized_version } .pth" )
384
-
385
- # Ensure site-packages exists
386
- os .makedirs (target_dir , exist_ok = True )
387
-
388
- # Write the paths to the .pth file, one per line
390
+ install_cmd = self .get_finalized_command ('install' )
391
+ target_dir = install_cmd .install_lib
392
+ assert target_dir is not None
393
+
394
+ # Use glob to robustly find and remove the conflicting files.
395
+ # This is safer than trying to guess the exact sanitized filename.
396
+ safe_name_part = re .sub (r"[^a-zA-Z0-9]+" , "_" , dist_name )
397
+
398
+ for pattern in [
399
+ # Remove `.pth` files generated in Python 3.12.
400
+ f"__editable__.*{ safe_name_part } *.pth" ,
401
+ f"__editable___*{ safe_name_part } *_finder.py" ,
402
+ # Also remove the legacy egg-link format.
403
+ f"{ dist_name } .egg-link"
404
+ ]:
405
+ for filepath in glob .glob (os .path .join (target_dir , pattern )):
406
+ print (f"Cleaning up conflicting install file: { filepath } " )
407
+ with contextlib .suppress (OSError ):
408
+ os .remove (filepath )
409
+
410
+ # Finally, create our own simple, multi-path .pth file.
411
+ # We name it simply, e.g., "torch_xla.pth".
412
+ pth_filename = os .path .join (target_dir , f"{ dist_name } .pth" )
413
+
414
+ project_root = os .path .dirname (os .path .abspath (__file__ ))
415
+ paths_to_add = {
416
+ project_root , # For `torch_xla`
417
+ os .path .abspath (os .path .join (project_root , 'torchax' )), # For `torchax`
418
+ }
419
+
389
420
with open (pth_filename , "w" , encoding = 'utf-8' ) as f :
390
421
for path in sorted (paths_to_add ):
391
422
f .write (path + "\n " )
0 commit comments