
Commit fd12d6a

move into benchmarks/ folder
1 parent f13f6f1 commit fd12d6a

File tree

2 files changed: 170 additions & 132 deletions

benchmark_transforms.py renamed to benchmarks/benchmark_transforms.py

Lines changed: 2 additions & 132 deletions
@@ -13,10 +13,10 @@
 import torch
 import random
 import warnings
-from time import perf_counter_ns
-from typing import Callable, List, Tuple, Dict, Any
+from typing import Dict, Any
 import torchvision.transforms.v2.functional as F
 import numpy as np
+from utils import bench, report_stats, print_comparison_table, print_benchmark_info

 # Filter out the specific TF32 warning
 warnings.filterwarnings(
@@ -28,86 +28,30 @@

 try:
     import cv2
-
     HAS_OPENCV = True
 except ImportError:
     HAS_OPENCV = False

 try:
     import albumentations as A
-
     HAS_ALBUMENTATIONS = True
 except ImportError:
     HAS_ALBUMENTATIONS = False

 try:
     import kornia as K
     import kornia.augmentation as KA
-
     HAS_KORNIA = True
 except ImportError:
     HAS_KORNIA = False

 from PIL import Image
-from tabulate import tabulate
-import torchvision

 # ImageNet normalization constants
 NORM_MEAN = [0.485, 0.456, 0.406]
 NORM_STD = [0.229, 0.224, 0.225]


-def bench(f: Callable, data_generator: Callable, num_exp: int, warmup: int) -> torch.Tensor:
-    """
-    Benchmark function execution time with fresh data for each experiment.
-
-    Args:
-        f: Function to benchmark
-        data_generator: Callable that returns fresh data for each experiment
-        num_exp: Number of experiments to run
-        warmup: Number of warmup runs
-
-    Returns:
-        Tensor of execution times in nanoseconds
-    """
-    for _ in range(warmup):
-        data = data_generator()
-        f(data)
-
-    times = []
-    for _ in range(num_exp):
-        data = data_generator()
-        start = perf_counter_ns()
-        result = f(data)
-        end = perf_counter_ns()
-        times.append(end - start)
-        del result
-
-    return torch.tensor(times, dtype=torch.float32)
-
-
-def report_stats(times: torch.Tensor, unit: str, verbose: bool = True) -> Dict[str, float]:
-    mul = {
-        "ns": 1,
-        "µs": 1e-3,
-        "ms": 1e-6,
-        "s": 1e-9,
-    }[unit]
-
-    times = times * mul
-    stats = {
-        "std": times.std().item(),
-        "median": times.median().item(),
-        "mean": times.mean().item(),
-        "min": times.min().item(),
-        "max": times.max().item(),
-    }
-
-    if verbose:
-        print(f" Median: {stats['median']:.2f}{unit} ± {stats['std']:.2f}{unit}")
-        print(f" Mean: {stats['mean']:.2f}{unit}, Min: {stats['min']:.2f}{unit}, Max: {stats['max']:.2f}{unit}")
-
-    return stats


 def torchvision_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor:
@@ -259,80 +203,6 @@ def generate_test_images():
     return {"backend": args.backend, "stats": stats}


-def print_comparison_table(results: List[Dict[str, Any]]) -> None:
-    torchvision_median = next((r["stats"]["median"] for r in results if r["backend"].lower() == "tv"), None)
-
-    table_data = []
-    for result in results:
-        stats = result["stats"]
-        relative = f"{stats['median'] / torchvision_median:.2f}x" if torchvision_median else "N/A"
-
-        table_data.append(
-            {
-                "Backend": result["backend"],
-                "Median (ms)": f"{stats['median']:.2f}",
-                "Std (ms)": f"{stats['std']:.2f}",
-                "Mean (ms)": f"{stats['mean']:.2f}",
-                "Min (ms)": f"{stats['min']:.2f}",
-                "Max (ms)": f"{stats['max']:.2f}",
-                "Relative": relative,
-            }
-        )
-
-    print(tabulate(table_data, headers="keys", tablefmt="grid"))
-
-
-def print_benchmark_info(args):
-    """Print benchmark configuration and library versions."""
-    device = args.device.lower()
-    if device in ['gpu', 'cuda']:
-        device = 'cuda'
-    else:
-        device = 'cpu'
-
-    memory_format = 'channels_last' if args.contiguity == 'CL' else 'channels_first'
-
-    config = [
-        ["Device", device],
-        ["Threads", args.num_threads],
-        ["Batch size", args.batch_size],
-        ["Memory format", memory_format],
-        ["Experiments", f"{args.num_exp} (+ {args.warmup} warmup)"],
-        ["Input → output size", f"{args.min_size}-{args.max_size}{args.target_size}×{args.target_size}"],
-    ]
-
-    print(tabulate(config, headers=["Parameter", "Value"], tablefmt="simple"))
-    print()
-
-    versions = [
-        ["PyTorch", torch.__version__],
-        ["TorchVision", torchvision.__version__],
-    ]
-
-    if HAS_OPENCV:
-        versions.append(["OpenCV", cv2.__version__])
-    else:
-        versions.append(["OpenCV", "Not available"])
-
-    try:
-        versions.append(["PIL/Pillow", Image.__version__])
-    except AttributeError:
-        versions.append(["PIL/Pillow", "Version unavailable"])
-
-    if HAS_ALBUMENTATIONS:
-        versions.append(["Albumentations", A.__version__])
-    else:
-        versions.append(["Albumentations", "Not available"])
-
-    if HAS_KORNIA:
-        versions.append(["Kornia", K.__version__])
-    else:
-        versions.append(["Kornia", "Not available"])
-
-    print(tabulate(versions, headers=["Library", "Version"], tablefmt="simple"))
-
-    print("=" * 80)
-    print()


 def main():

benchmarks/utils.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
"""
Utility functions for benchmarking transforms.
"""

import torch
import torchvision
from time import perf_counter_ns
from typing import Callable, List, Dict, Any
from tabulate import tabulate

try:
    import cv2
    HAS_OPENCV = True
except ImportError:
    HAS_OPENCV = False

try:
    import albumentations as A
    HAS_ALBUMENTATIONS = True
except ImportError:
    HAS_ALBUMENTATIONS = False

try:
    import kornia as K
    HAS_KORNIA = True
except ImportError:
    HAS_KORNIA = False

from PIL import Image


def bench(f: Callable, data_generator: Callable, num_exp: int, warmup: int) -> torch.Tensor:
    """
    Benchmark function execution time with fresh data for each experiment.

    Args:
        f: Function to benchmark
        data_generator: Callable that returns fresh data for each experiment
        num_exp: Number of experiments to run
        warmup: Number of warmup runs

    Returns:
        Tensor of execution times in nanoseconds
    """
    for _ in range(warmup):
        data = data_generator()
        f(data)

    times = []
    for _ in range(num_exp):
        data = data_generator()
        start = perf_counter_ns()
        result = f(data)
        end = perf_counter_ns()
        times.append(end - start)
        del result

    return torch.tensor(times, dtype=torch.float32)


def report_stats(times: torch.Tensor, unit: str, verbose: bool = True) -> Dict[str, float]:
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]

    times = times * mul
    stats = {
        "std": times.std().item(),
        "median": times.median().item(),
        "mean": times.mean().item(),
        "min": times.min().item(),
        "max": times.max().item(),
    }

    if verbose:
        print(f" Median: {stats['median']:.2f}{unit} ± {stats['std']:.2f}{unit}")
        print(f" Mean: {stats['mean']:.2f}{unit}, Min: {stats['min']:.2f}{unit}, Max: {stats['max']:.2f}{unit}")

    return stats


def print_comparison_table(results: List[Dict[str, Any]]) -> None:
    torchvision_median = next((r["stats"]["median"] for r in results if r["backend"].lower() == "tv"), None)

    table_data = []
    for result in results:
        stats = result["stats"]
        relative = f"{stats['median'] / torchvision_median:.2f}x" if torchvision_median else "N/A"

        table_data.append(
            {
                "Backend": result["backend"],
                "Median (ms)": f"{stats['median']:.2f}",
                "Std (ms)": f"{stats['std']:.2f}",
                "Mean (ms)": f"{stats['mean']:.2f}",
                "Min (ms)": f"{stats['min']:.2f}",
                "Max (ms)": f"{stats['max']:.2f}",
                "Relative": relative,
            }
        )

    print(tabulate(table_data, headers="keys", tablefmt="grid"))


def print_benchmark_info(args):
    """Print benchmark configuration and library versions."""
    device = args.device.lower()
    if device in ['gpu', 'cuda']:
        device = 'cuda'
    else:
        device = 'cpu'

    memory_format = 'channels_last' if args.contiguity == 'CL' else 'channels_first'

    print("=" * 80)
    print("BENCHMARK CONFIGURATION")
    print("=" * 80)

    # Collect configuration info
    config = [
        ["Device", device],
        ["Threads", args.num_threads],
        ["Batch size", args.batch_size],
        ["Memory format", memory_format],
        ["Experiments", f"{args.num_exp} (+ {args.warmup} warmup)"],
        ["Input → output size", f"{args.min_size}-{args.max_size}{args.target_size}×{args.target_size}"],
    ]

    print(tabulate(config, headers=["Parameter", "Value"], tablefmt="simple"))
    print()

    print("=" * 80)
    print("LIBRARY VERSIONS")
    print("=" * 80)

    # Collect library versions
    versions = [
        ["PyTorch", torch.__version__],
        ["TorchVision", torchvision.__version__],
    ]

    if HAS_OPENCV:
        versions.append(["OpenCV", cv2.__version__])
    else:
        versions.append(["OpenCV", "Not available"])

    # PIL version
    try:
        versions.append(["PIL/Pillow", Image.__version__])
    except AttributeError:
        versions.append(["PIL/Pillow", "Version unavailable"])

    if HAS_ALBUMENTATIONS:
        versions.append(["Albumentations", A.__version__])
    else:
        versions.append(["Albumentations", "Not available"])

    if HAS_KORNIA:
        versions.append(["Kornia", K.__version__])
    else:
        versions.append(["Kornia", "Not available"])

    print(tabulate(versions, headers=["Library", "Version"], tablefmt="simple"))
    print("=" * 80)
    print()
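For reference, a minimal sketch of how the helpers moved into benchmarks/utils.py can be driven from a script in the benchmarks/ folder (so that `from utils import ...` resolves). The pipeline, batch shape, and experiment counts below are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch for benchmarks/utils.py; sizes and counts are made up.
import torch
import torchvision.transforms.v2.functional as F

from utils import bench, report_stats

def make_batch():
    # Fresh random uint8 image batch for every experiment
    return torch.randint(0, 256, (8, 3, 512, 512), dtype=torch.uint8)

def resize_pipeline(images):
    # Toy pipeline: resize the batch to 224x224
    return F.resize(images, [224, 224])

times = bench(resize_pipeline, make_batch, num_exp=50, warmup=5)  # times in nanoseconds
stats = report_stats(times, "ms")  # converts to milliseconds and prints median/mean/min/max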
