Skip to content

Commit f13f6f1

Browse files
committed
Add CUDA, didn't test. Add benchmark info
1 parent 4d1ba4d commit f13f6f1

File tree

1 file changed

+69
-2
lines changed

1 file changed

+69
-2
lines changed

benchmark_transforms.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
from PIL import Image
5252
from tabulate import tabulate
53+
import torchvision
5354

5455
# ImageNet normalization constants
5556
NORM_MEAN = [0.485, 0.456, 0.406]
@@ -165,6 +166,14 @@ def kornia_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
165166

166167
def run_benchmark(args) -> Dict[str, Any]:
167168
backend = args.backend.lower()
169+
170+
device = args.device.lower()
171+
# Check device compatibility
172+
if device == 'cuda' and backend not in ['tv', 'tv-compiled']:
173+
raise RuntimeError(f"CUDA device not supported for {backend} backend. Only 'tv' and 'tv-compiled' support CUDA.")
174+
175+
if device == 'cuda' and not torch.cuda.is_available():
176+
raise RuntimeError("CUDA not available. Install cuda-enabled torch and torchvision, or use 'cpu' device.")
168177

169178
if backend == "opencv" and not HAS_OPENCV:
170179
raise RuntimeError("OpenCV not available. Install with: pip install opencv-python")
@@ -176,7 +185,7 @@ def run_benchmark(args) -> Dict[str, Any]:
176185
if args.verbose:
177186
backend_display = args.backend.upper()
178187
print(f"\n=== {backend_display} ===")
179-
print(f"Threads: {args.num_threads}, Batch size: {args.batch_size}")
188+
print(f"Device: {device}, Threads: {args.num_threads}, Batch size: {args.batch_size}")
180189

181190
memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
182191
print(f"Memory format: {'channels_last' if memory_format == torch.channels_last else 'channels_first'}")
@@ -208,6 +217,10 @@ def generate_test_images():
208217
memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
209218
if memory_format == torch.channels_last:
210219
images = images.to(memory_format=torch.channels_last)
220+
221+
# Move to device for torchvision backends
222+
if backend in ['tv', 'tv-compiled']:
223+
images = images.to(device)
211224

212225
if args.batch_size == 1:
213226
images = images[0]
@@ -269,6 +282,59 @@ def print_comparison_table(results: List[Dict[str, Any]]) -> None:
269282
print(tabulate(table_data, headers="keys", tablefmt="grid"))
270283

271284

285+
def print_benchmark_info(args):
286+
"""Print benchmark configuration and library versions."""
287+
device = args.device.lower()
288+
if device in ['gpu', 'cuda']:
289+
device = 'cuda'
290+
else:
291+
device = 'cpu'
292+
293+
memory_format = 'channels_last' if args.contiguity == 'CL' else 'channels_first'
294+
295+
config = [
296+
["Device", device],
297+
["Threads", args.num_threads],
298+
["Batch size", args.batch_size],
299+
["Memory format", memory_format],
300+
["Experiments", f"{args.num_exp} (+ {args.warmup} warmup)"],
301+
["Input → output size", f"{args.min_size}-{args.max_size}{args.target_size}×{args.target_size}"],
302+
]
303+
304+
print(tabulate(config, headers=["Parameter", "Value"], tablefmt="simple"))
305+
print()
306+
307+
versions = [
308+
["PyTorch", torch.__version__],
309+
["TorchVision", torchvision.__version__],
310+
]
311+
312+
if HAS_OPENCV:
313+
versions.append(["OpenCV", cv2.__version__])
314+
else:
315+
versions.append(["OpenCV", "Not available"])
316+
317+
try:
318+
versions.append(["PIL/Pillow", Image.__version__])
319+
except AttributeError:
320+
versions.append(["PIL/Pillow", "Version unavailable"])
321+
322+
if HAS_ALBUMENTATIONS:
323+
versions.append(["Albumentations", A.__version__])
324+
else:
325+
versions.append(["Albumentations", "Not available"])
326+
327+
if HAS_KORNIA:
328+
versions.append(["Kornia", K.__version__])
329+
else:
330+
versions.append(["Kornia", "Not available"])
331+
332+
print(tabulate(versions, headers=["Library", "Version"], tablefmt="simple"))
333+
334+
print("=" * 80)
335+
print()
336+
337+
272338
def main():
273339
parser = argparse.ArgumentParser(description="Benchmark torchvision transforms")
274340
parser.add_argument("--num-exp", type=int, default=100, help="Number of experiments we average over")
@@ -297,10 +363,11 @@ def main():
297363
"--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to benchmark"
298364
)
299365
parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
366+
parser.add_argument("--device", type=str, default="cpu", help="Device to use: cpu, cuda, or gpu (default: cpu)")
300367

301368
args = parser.parse_args()
302369

303-
print(f"Averaging over {args.num_exp} runs, {args.warmup} warmup runs")
370+
print_benchmark_info(args)
304371

305372
backends_to_run = all_backends if args.backend.lower() == "all" else [args.backend]
306373
results = []

0 commit comments

Comments
 (0)