|
13 | 13 | import torch |
14 | 14 | import random |
15 | 15 | import warnings |
16 | | -from time import perf_counter_ns |
17 | | -from typing import Callable, List, Tuple, Dict, Any |
| 16 | +from typing import Dict, Any |
18 | 17 | import torchvision.transforms.v2.functional as F |
19 | 18 | import numpy as np |
| 19 | +from utils import bench, report_stats, print_comparison_table, print_benchmark_info |
20 | 20 |
|
21 | 21 | # Filter out the specific TF32 warning |
22 | 22 | warnings.filterwarnings( |
|
28 | 28 |
|
29 | 29 | try: |
30 | 30 | import cv2 |
31 | | - |
32 | 31 | HAS_OPENCV = True |
33 | 32 | except ImportError: |
34 | 33 | HAS_OPENCV = False |
35 | 34 |
|
36 | 35 | try: |
37 | 36 | import albumentations as A |
38 | | - |
39 | 37 | HAS_ALBUMENTATIONS = True |
40 | 38 | except ImportError: |
41 | 39 | HAS_ALBUMENTATIONS = False |
42 | 40 |
|
43 | 41 | try: |
44 | 42 | import kornia as K |
45 | 43 | import kornia.augmentation as KA |
46 | | - |
47 | 44 | HAS_KORNIA = True |
48 | 45 | except ImportError: |
49 | 46 | HAS_KORNIA = False |
50 | 47 |
|
51 | 48 | from PIL import Image |
52 | | -from tabulate import tabulate |
53 | | -import torchvision |
54 | 49 |
|
55 | 50 | # ImageNet normalization constants |
56 | 51 | NORM_MEAN = [0.485, 0.456, 0.406] |
57 | 52 | NORM_STD = [0.229, 0.224, 0.225] |
58 | 53 |
|
59 | 54 |
|
60 | | -def bench(f: Callable, data_generator: Callable, num_exp: int, warmup: int) -> torch.Tensor: |
61 | | - """ |
62 | | - Benchmark function execution time with fresh data for each experiment. |
63 | | -
|
64 | | - Args: |
65 | | - f: Function to benchmark |
66 | | - data_generator: Callable that returns fresh data for each experiment |
67 | | - num_exp: Number of experiments to run |
68 | | - warmup: Number of warmup runs |
69 | | -
|
70 | | - Returns: |
71 | | - Tensor of execution times in nanoseconds |
72 | | - """ |
73 | | - for _ in range(warmup): |
74 | | - data = data_generator() |
75 | | - f(data) |
76 | | - |
77 | | - times = [] |
78 | | - for _ in range(num_exp): |
79 | | - data = data_generator() |
80 | | - start = perf_counter_ns() |
81 | | - result = f(data) |
82 | | - end = perf_counter_ns() |
83 | | - times.append(end - start) |
84 | | - del result |
85 | | - |
86 | | - return torch.tensor(times, dtype=torch.float32) |
87 | | - |
88 | | - |
89 | | -def report_stats(times: torch.Tensor, unit: str, verbose: bool = True) -> Dict[str, float]: |
90 | | - mul = { |
91 | | - "ns": 1, |
92 | | - "µs": 1e-3, |
93 | | - "ms": 1e-6, |
94 | | - "s": 1e-9, |
95 | | - }[unit] |
96 | | - |
97 | | - times = times * mul |
98 | | - stats = { |
99 | | - "std": times.std().item(), |
100 | | - "median": times.median().item(), |
101 | | - "mean": times.mean().item(), |
102 | | - "min": times.min().item(), |
103 | | - "max": times.max().item(), |
104 | | - } |
105 | | - |
106 | | - if verbose: |
107 | | - print(f" Median: {stats['median']:.2f}{unit} ± {stats['std']:.2f}{unit}") |
108 | | - print(f" Mean: {stats['mean']:.2f}{unit}, Min: {stats['min']:.2f}{unit}, Max: {stats['max']:.2f}{unit}") |
109 | | - |
110 | | - return stats |
111 | 55 |
|
112 | 56 |
|
113 | 57 | def torchvision_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor: |
@@ -259,80 +203,6 @@ def generate_test_images(): |
259 | 203 | return {"backend": args.backend, "stats": stats} |
260 | 204 |
|
261 | 205 |
|
262 | | -def print_comparison_table(results: List[Dict[str, Any]]) -> None: |
263 | | - torchvision_median = next((r["stats"]["median"] for r in results if r["backend"].lower() == "tv"), None) |
264 | | - |
265 | | - table_data = [] |
266 | | - for result in results: |
267 | | - stats = result["stats"] |
268 | | - relative = f"{stats['median'] / torchvision_median:.2f}x" if torchvision_median else "N/A" |
269 | | - |
270 | | - table_data.append( |
271 | | - { |
272 | | - "Backend": result["backend"], |
273 | | - "Median (ms)": f"{stats['median']:.2f}", |
274 | | - "Std (ms)": f"{stats['std']:.2f}", |
275 | | - "Mean (ms)": f"{stats['mean']:.2f}", |
276 | | - "Min (ms)": f"{stats['min']:.2f}", |
277 | | - "Max (ms)": f"{stats['max']:.2f}", |
278 | | - "Relative": relative, |
279 | | - } |
280 | | - ) |
281 | | - |
282 | | - print(tabulate(table_data, headers="keys", tablefmt="grid")) |
283 | | - |
284 | | - |
285 | | -def print_benchmark_info(args): |
286 | | - """Print benchmark configuration and library versions.""" |
287 | | - device = args.device.lower() |
288 | | - if device in ['gpu', 'cuda']: |
289 | | - device = 'cuda' |
290 | | - else: |
291 | | - device = 'cpu' |
292 | | - |
293 | | - memory_format = 'channels_last' if args.contiguity == 'CL' else 'channels_first' |
294 | | - |
295 | | - config = [ |
296 | | - ["Device", device], |
297 | | - ["Threads", args.num_threads], |
298 | | - ["Batch size", args.batch_size], |
299 | | - ["Memory format", memory_format], |
300 | | - ["Experiments", f"{args.num_exp} (+ {args.warmup} warmup)"], |
301 | | - ["Input → output size", f"{args.min_size}-{args.max_size} → {args.target_size}×{args.target_size}"], |
302 | | - ] |
303 | | - |
304 | | - print(tabulate(config, headers=["Parameter", "Value"], tablefmt="simple")) |
305 | | - print() |
306 | | - |
307 | | - versions = [ |
308 | | - ["PyTorch", torch.__version__], |
309 | | - ["TorchVision", torchvision.__version__], |
310 | | - ] |
311 | | - |
312 | | - if HAS_OPENCV: |
313 | | - versions.append(["OpenCV", cv2.__version__]) |
314 | | - else: |
315 | | - versions.append(["OpenCV", "Not available"]) |
316 | | - |
317 | | - try: |
318 | | - versions.append(["PIL/Pillow", Image.__version__]) |
319 | | - except AttributeError: |
320 | | - versions.append(["PIL/Pillow", "Version unavailable"]) |
321 | | - |
322 | | - if HAS_ALBUMENTATIONS: |
323 | | - versions.append(["Albumentations", A.__version__]) |
324 | | - else: |
325 | | - versions.append(["Albumentations", "Not available"]) |
326 | | - |
327 | | - if HAS_KORNIA: |
328 | | - versions.append(["Kornia", K.__version__]) |
329 | | - else: |
330 | | - versions.append(["Kornia", "Not available"]) |
331 | | - |
332 | | - print(tabulate(versions, headers=["Library", "Version"], tablefmt="simple")) |
333 | | - |
334 | | - print("=" * 80) |
335 | | - print() |
336 | 206 |
|
337 | 207 |
|
338 | 208 | def main(): |
|
0 commit comments