
 from PIL import Image
 from tabulate import tabulate
+import torchvision

 # ImageNet normalization constants
 NORM_MEAN = [0.485, 0.456, 0.406]
@@ -165,6 +166,14 @@ def kornia_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:

 def run_benchmark(args) -> Dict[str, Any]:
     backend = args.backend.lower()
+
+    device = args.device.lower().replace('gpu', 'cuda')  # accept 'gpu' as an alias for 'cuda'
+    # Check device compatibility
+    if device == 'cuda' and backend not in ['tv', 'tv-compiled']:
+        raise RuntimeError(f"CUDA device not supported for {backend} backend. Only 'tv' and 'tv-compiled' support CUDA.")
+
+    if device == 'cuda' and not torch.cuda.is_available():
+        raise RuntimeError("CUDA not available. Install CUDA-enabled torch and torchvision, or use 'cpu' device.")

     if backend == "opencv" and not HAS_OPENCV:
         raise RuntimeError("OpenCV not available. Install with: pip install opencv-python")
@@ -176,7 +185,7 @@ def run_benchmark(args) -> Dict[str, Any]:
     if args.verbose:
         backend_display = args.backend.upper()
         print(f"\n=== {backend_display} ===")
-        print(f"Threads: {args.num_threads}, Batch size: {args.batch_size}")
+        print(f"Device: {device}, Threads: {args.num_threads}, Batch size: {args.batch_size}")

         memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
         print(f"Memory format: {'channels_last' if memory_format == torch.channels_last else 'channels_first'}")
@@ -208,6 +217,10 @@ def generate_test_images():
     memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
     if memory_format == torch.channels_last:
         images = images.to(memory_format=torch.channels_last)
+
+    # Move to device for torchvision backends
+    if backend in ['tv', 'tv-compiled']:
+        images = images.to(device)

     if args.batch_size == 1:
         images = images[0]
@@ -269,6 +282,59 @@ def print_comparison_table(results: List[Dict[str, Any]]) -> None:
     print(tabulate(table_data, headers="keys", tablefmt="grid"))


+def print_benchmark_info(args):
+    """Print benchmark configuration and library versions."""
+    device = args.device.lower()
+    if device in ['gpu', 'cuda']:
+        device = 'cuda'
+    else:
+        device = 'cpu'
+
+    memory_format = 'channels_last' if args.contiguity == 'CL' else 'channels_first'
+
+    config = [
+        ["Device", device],
+        ["Threads", args.num_threads],
+        ["Batch size", args.batch_size],
+        ["Memory format", memory_format],
+        ["Experiments", f"{args.num_exp} (+ {args.warmup} warmup)"],
+        ["Input → output size", f"{args.min_size}-{args.max_size} → {args.target_size}×{args.target_size}"],
+    ]
+
+    print(tabulate(config, headers=["Parameter", "Value"], tablefmt="simple"))
+    print()
+
+    versions = [
+        ["PyTorch", torch.__version__],
+        ["TorchVision", torchvision.__version__],
+    ]
+
+    if HAS_OPENCV:
+        versions.append(["OpenCV", cv2.__version__])
+    else:
+        versions.append(["OpenCV", "Not available"])
+
+    try:
+        versions.append(["PIL/Pillow", Image.__version__])
+    except AttributeError:
+        versions.append(["PIL/Pillow", "Version unavailable"])
+
+    if HAS_ALBUMENTATIONS:
+        versions.append(["Albumentations", A.__version__])
+    else:
+        versions.append(["Albumentations", "Not available"])
+
+    if HAS_KORNIA:
+        versions.append(["Kornia", K.__version__])
+    else:
+        versions.append(["Kornia", "Not available"])
+
+    print(tabulate(versions, headers=["Library", "Version"], tablefmt="simple"))
+
+    print("=" * 80)
+    print()
+
+
 def main():
     parser = argparse.ArgumentParser(description="Benchmark torchvision transforms")
     parser.add_argument("--num-exp", type=int, default=100, help="Number of experiments we average over")
@@ -297,10 +363,11 @@ def main():
         "--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to benchmark"
     )
     parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
+    parser.add_argument("--device", type=str, default="cpu", help="Device to use: cpu, cuda, or gpu (default: cpu)")

     args = parser.parse_args()

-    print(f"Averaging over {args.num_exp} runs, {args.warmup} warmup runs")
+    print_benchmark_info(args)

     backends_to_run = all_backends if args.backend.lower() == "all" else [args.backend]
     results = []
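
A minimal usage sketch for the new --device flag, assuming the script is saved as benchmark.py (the filename is not shown in this diff); cpu, cuda, and gpu are accepted, with gpu treated as an alias for cuda:

    python benchmark.py --backend tv --device cuda --num-exp 100 -v
    python benchmark.py --backend all --device cpu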