     TaylorSeerCalibratorConfig,
 )
 
+from cache_dit.platforms import current_platform
+
 logger = init_logger(__name__)
 
 
 class MemoryTracker:
     """Track peak GPU memory usage during execution."""
 
     def __init__(self, device=None):
-        self.device = device if device is not None else torch.cuda.current_device()
-        self.enabled = torch.cuda.is_available()
+        self.device = device if device is not None else current_platform.current_device()
+        self.enabled = current_platform.is_accelerator_available()
         self.peak_memory = 0
 
     def __enter__(self):
         if self.enabled:
-            torch.cuda.reset_peak_memory_stats(self.device)
-            torch.cuda.synchronize(self.device)
+            current_platform.reset_peak_memory_stats(self.device)
+            current_platform.synchronize(self.device)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self.enabled:
-            torch.cuda.synchronize(self.device)
-            self.peak_memory = torch.cuda.max_memory_allocated(self.device)
+            current_platform.synchronize(self.device)
+            self.peak_memory = current_platform.max_memory_allocated(self.device)
 
     def get_peak_memory_gb(self):
         """Get peak memory in GB."""
@@ -54,10 +56,10 @@ def report(self):
 
 def GiB():
     try:
-        if not torch.cuda.is_available():
+        if not current_platform.is_accelerator_available():
             return 0
-        total_memory_bytes = torch.cuda.get_device_properties(
-            torch.cuda.current_device(),
+        total_memory_bytes = current_platform.get_device_properties(
+            current_platform.current_device(),
         ).total_memory
         total_memory_gib = total_memory_bytes / (1024**3)
         return int(total_memory_gib)
@@ -1346,21 +1348,32 @@ def strify(args, pipe_or_stats):
 
 
 def get_rank_device():
+    available = current_platform.is_accelerator_available()
+    device_type = current_platform.device_type
     if dist.is_initialized():
         rank = dist.get_rank()
-        device = torch.device("cuda", rank % torch.cuda.device_count())
+        device = torch.device(device_type, rank % current_platform.device_count())
         return rank, device
-    return 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return 0, torch.device(device_type if available else "cpu")
 
 
 def maybe_init_distributed(args=None):
+    from cache_dit.platforms.platform import CpuPlatform
+
+    platform_full_backend = current_platform.full_dist_backend
+    cpu_full_backend = CpuPlatform.full_dist_backend
+    backend = (
+        f"{cpu_full_backend},{platform_full_backend}"
+        if args is not None and args.ulysses_anything
+        else current_platform.dist_backend
+    )
     if args is not None:
         if args.parallel_type is not None:
             dist.init_process_group(
-                backend="cpu:gloo,cuda:nccl" if args.ulysses_anything else "nccl",
+                backend=backend,
             )
             rank, device = get_rank_device()
-            torch.cuda.set_device(device)
+            current_platform.set_device(device)
             return rank, device
         else:
             # no distributed needed
@@ -1370,10 +1383,10 @@ def maybe_init_distributed(args=None):
     # always init distributed for other examples
     if not dist.is_initialized():
         dist.init_process_group(
-            backend=platform_full_backend,
+            backend=platform_full_backend,
         )
     rank, device = get_rank_device()
-    torch.cuda.set_device(device)
+    current_platform.set_device(device)
     return rank, device
 
 