File tree Expand file tree Collapse file tree 2 files changed +19
-0
lines changed
tests/quantization/quanto Expand file tree Collapse file tree 2 files changed +19
-0
lines changed Original file line number Diff line number Diff line change 101101 mps_backend_registered = hasattr (torch .backends , "mps" )
102102 torch_device = "mps" if (mps_backend_registered and torch .backends .mps .is_available ()) else torch_device
103103
104+ from .torch_utils import get_torch_cuda_device_capability
105+
104106
105107def torch_all_close (a , b , * args , ** kwargs ):
106108 if not is_torch_available ():
@@ -282,6 +284,20 @@ def require_torch_gpu(test_case):
282284 )
283285
284286
def require_torch_cuda_compatibility(expected_compute_capability):
    """Decorator marking a test that requires a specific CUDA compute capability.

    If CUDA is unavailable, the test is skipped outright. Otherwise the test runs
    only when the detected compute capability equals ``expected_compute_capability``
    (both sides are compared as floats, so ``8.0`` matches a reported ``"8.0"``).

    Args:
        expected_compute_capability (`float` or `str`):
            The compute capability the current CUDA device must report.

    Returns:
        A decorator applicable to a test function or a ``unittest.TestCase`` class.
    """

    def decorator(test_case):
        if not torch.cuda.is_available():
            # `unittest.skip` expects a reason string and returns a decorator that
            # must then be applied to the test item. Passing `test_case` directly to
            # `unittest.skip` only works by accident for plain functions (CPython
            # special-cases FunctionType) and silently replaces decorated *classes*
            # with the un-applied decorator, hiding them from test discovery.
            return unittest.skip("Test requires CUDA.")(test_case)
        current_compute_capability = get_torch_cuda_device_capability()
        # Apply the skipUnless decorator to the test item instead of returning the
        # bare decorator object; otherwise the decorated class/function would be
        # replaced by a plain function and never collected by the test runner.
        return unittest.skipUnless(
            float(current_compute_capability) == float(expected_compute_capability),
            "Test not supported for this compute capability.",
        )(test_case)

    return decorator
299+
300+
285301# These decorators are for accelerator-specific behaviours that are not GPU-specific
286302def require_torch_accelerator (test_case ):
287303 """Decorator marking a test that requires an accelerator backend and PyTorch."""
Original file line number Diff line number Diff line change 1010 numpy_cosine_similarity_distance ,
1111 require_accelerate ,
1212 require_big_gpu_with_torch_cuda ,
13+ require_torch_cuda_compatibility ,
1314 torch_device ,
1415)
1516
@@ -311,13 +312,15 @@ def get_dummy_init_kwargs(self):
311312 return {"weights_dtype" : "int8" }
312313
313314
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    """Quanto int4 weight-quantization tests for the Flux transformer.

    Inherits all test logic from `FluxTransformerQuantoMixin`; this subclass only
    pins the quantization dtype and the memory-reduction threshold the mixin checks.
    """

    # Minimum fraction of memory the int4-quantized model is expected to save.
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        """Return the quantization-config kwargs selecting int4 weights."""
        return {"weights_dtype": "int4"}
320322
323+ @require_torch_cuda_compatibility (8.0 )
321324class FluxTransformerInt2WeightsTest (FluxTransformerQuantoMixin , unittest .TestCase ):
322325 expected_memory_reduction = 0.65
323326
You can’t perform that action at this time.
0 commit comments