
Commit f07ae8e
Added kornia backend
1 parent 9a10b07 commit f07ae8e


benchmark_transforms.py
91 additions, 38 deletions (1 file changed)
@@ -19,16 +19,26 @@
 
 try:
     import cv2
+
     HAS_OPENCV = True
 except ImportError:
     HAS_OPENCV = False
 
 try:
     import albumentations as A
+
     HAS_ALBUMENTATIONS = True
 except ImportError:
     HAS_ALBUMENTATIONS = False
 
+try:
+    import kornia as K
+    import kornia.augmentation as KA
+
+    HAS_KORNIA = True
+except ImportError:
+    HAS_KORNIA = False
+
 from PIL import Image
 from tabulate import tabulate
 
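Note on the hunk above: each backend import sits in its own try/except, so the script still runs when a backend is missing, and the new HAS_KORNIA flag is checked later in run_benchmark. (As an aside, the kornia.augmentation import aliased as KA does not appear to be used anywhere in this commit.) For anyone reproducing the setup, a minimal sketch of the same availability probe; only the module and flag names are taken from this diff, the rest is illustrative:

# Sketch: report which optional backends are importable, mirroring the
# HAS_OPENCV / HAS_ALBUMENTATIONS / HAS_KORNIA flags used in the script.
import importlib.util

for module_name, flag in [
    ("cv2", "HAS_OPENCV"),
    ("albumentations", "HAS_ALBUMENTATIONS"),
    ("kornia", "HAS_KORNIA"),
]:
    available = importlib.util.find_spec(module_name) is not None
    print(f"{flag} = {available}")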
@@ -82,16 +92,18 @@ def report_stats(times: torch.Tensor, unit: str, verbose: bool = True) -> Dict[s
         "min": times.min().item(),
         "max": times.max().item(),
     }
-
+
     if verbose:
         print(f" Median: {stats['median']:.2f}{unit} ± {stats['std']:.2f}{unit}")
         print(f" Mean: {stats['mean']:.2f}{unit}, Min: {stats['min']:.2f}{unit}, Max: {stats['max']:.2f}{unit}")
-
+
     return stats
 
 
 def torchvision_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor:
-    images = F.resize(images, size=(target_size, target_size), interpolation=F.InterpolationMode.BILINEAR, antialias=True)
+    images = F.resize(
+        images, size=(target_size, target_size), interpolation=F.InterpolationMode.BILINEAR, antialias=True
+    )
     images = F.to_dtype(images, dtype=torch.float32, scale=True)
     images = F.normalize(images, mean=NORM_MEAN, std=NORM_STD)
     return images
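The reformatted torchvision_pipeline is unchanged in behavior: resize with antialiasing, uint8-to-float32 scaling, then normalization. A standalone sketch of those three steps, assuming F is torchvision.transforms.v2.functional and that NORM_MEAN/NORM_STD are the usual ImageNet statistics (their actual values are outside this diff):

# Standalone sketch of the pipeline steps above; NORM_MEAN / NORM_STD are
# assumed ImageNet stats here, the real values live elsewhere in the script.
import torch
from torchvision.transforms.v2 import functional as F

NORM_MEAN = [0.485, 0.456, 0.406]  # assumption, not shown in this diff
NORM_STD = [0.229, 0.224, 0.225]  # assumption, not shown in this diff

images = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)
out = F.resize(images, size=(224, 224), interpolation=F.InterpolationMode.BILINEAR, antialias=True)
out = F.to_dtype(out, dtype=torch.float32, scale=True)  # uint8 -> float32 in [0, 1]
out = F.normalize(out, mean=NORM_MEAN, std=NORM_STD)
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32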
@@ -114,35 +126,52 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
 
 
 def albumentations_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor:
-    transform = A.Compose([
-        A.Resize(target_size, target_size, interpolation=cv2.INTER_LINEAR),
-        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0)
-    ])
+    transform = A.Compose(
+        [
+            A.Resize(target_size, target_size, interpolation=cv2.INTER_LINEAR),
+            A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0),
+        ]
+    )
     img = transform(image=image)["image"]
     img = torch.from_numpy(img).permute(2, 0, 1)
     return img
 
 
+def kornia_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
+    # Kornia expects float tensors in [0, 1] range
+    # TODO check that this is needed?
+    img = image.float() / 255.0
+    img = img.unsqueeze(0)  # Add batch dimension for kornia
+
+    img = K.geometry.transform.resize(img, (target_size, target_size), interpolation="bilinear")
+
+    img = K.enhance.normalize(img, mean=torch.tensor(NORM_MEAN), std=torch.tensor(NORM_STD))
+
+    return img.squeeze(0)  # Remove batch dimension
+
+
 # TODO double check that this works as expected: no graph break, and no issues with dynamic shapes
 compiled_torchvision_pipeline = torch.compile(torchvision_pipeline, mode="default", fullgraph=True, dynamic=True)
 
 
 def run_benchmark(args) -> Dict[str, Any]:
     backend = args.backend.lower()
-
+
     if backend == "opencv" and not HAS_OPENCV:
         raise RuntimeError("OpenCV not available. Install with: pip install opencv-python")
     if backend == "albumentations" and not HAS_ALBUMENTATIONS:
         raise RuntimeError("Albumentations not available. Install with: pip install albumentations")
-
+    if backend == "kornia" and not HAS_KORNIA:
+        raise RuntimeError("Kornia not available. Install with: pip install kornia")
+
     if args.verbose:
         backend_display = args.backend.upper()
         print(f"\n=== {backend_display} ===")
         print(f"Threads: {args.num_threads}, Batch size: {args.batch_size}")
 
         memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
         print(f"Memory format: {'channels_last' if memory_format == torch.channels_last else 'channels_first'}")
-
+
     if backend == "tv":
         torch.set_num_threads(args.num_threads)
         pipeline = torchvision_pipeline
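On the "TODO check that this is needed?" in kornia_pipeline: the float conversion does appear necessary, since kornia's resize is built on torch.nn.functional.interpolate, which rejects uint8 input in bilinear mode, and the /255 scaling is what makes ImageNet-style mean/std normalization applicable. A minimal sketch of the same steps in isolation (normalization constants assumed, as above):

# Sketch of kornia_pipeline on a dummy image: uint8 -> float in [0, 1],
# add a batch dim (kornia ops are batched), resize, then normalize.
import torch
import kornia as K

NORM_MEAN = [0.485, 0.456, 0.406]  # assumption, not shown in this diff
NORM_STD = [0.229, 0.224, 0.225]  # assumption, not shown in this diff

image = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)
img = (image.float() / 255.0).unsqueeze(0)
img = K.geometry.transform.resize(img, (224, 224), interpolation="bilinear")
img = K.enhance.normalize(img, mean=torch.tensor(NORM_MEAN), std=torch.tensor(NORM_STD))
print(img.squeeze(0).shape)  # torch.Size([3, 224, 224])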
@@ -158,19 +187,22 @@ def run_benchmark(args) -> Dict[str, Any]:
     elif backend == "albumentations":
         cv2.setNumThreads(args.num_threads)
         pipeline = albumentations_pipeline
-
+    elif backend == "kornia":
+        torch.set_num_threads(args.num_threads)
+        pipeline = kornia_pipeline
+
     def generate_test_images():
         height = random.randint(args.min_size, args.max_size)
         width = random.randint(args.min_size, args.max_size)
         images = torch.randint(0, 256, (args.batch_size, 3, height, width), dtype=torch.uint8)
-
+
         memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
         if memory_format == torch.channels_last:
             images = images.to(memory_format=torch.channels_last)
-
+
         if args.batch_size == 1:
             images = images[0]
-
+
         if backend == "opencv":
             if args.batch_size > 1:
                 raise ValueError("Batches not supported in OpenCV pipeline")
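One subtlety in generate_test_images above: channels_last is only defined for 4D tensors, so the memory-format conversion has to happen before the batch dimension is stripped for batch_size == 1, which is exactly the order used. A quick sketch:

# Sketch: channels_last applies to the 4D tensor; indexing out the batch
# dim afterwards yields a plain 3D CHW tensor for the single-image case.
import torch

images = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
images = images.to(memory_format=torch.channels_last)
print(images.is_contiguous(memory_format=torch.channels_last))  # True

single = images[0]  # batch_size == 1: drop the batch dimension
print(single.shape)  # torch.Size([3, 256, 256])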
@@ -187,70 +219,91 @@ def generate_test_images():
                 # TODO is that true????
                 raise ValueError("Batches not supported in Albumentations pipeline")
             images = images.numpy().transpose(1, 2, 0)
+        elif backend == "kornia":
+            if args.batch_size > 1:
+                # TODO is that true????
+                raise ValueError("Batches not supported in Kornia pipeline")
 
         return images
-
+
     times = bench(
         lambda images: pipeline(images, args.target_size),
         data_generator=generate_test_images,
         num_exp=args.num_exp,
         warmup=args.warmup,
     )
-
+
     stats = report_stats(times, "ms", args.verbose)
     return {"backend": args.backend, "stats": stats}
 
 
 def print_comparison_table(results: List[Dict[str, Any]]) -> None:
     torchvision_median = next((r["stats"]["median"] for r in results if r["backend"].lower() == "tv"), None)
-
+
     table_data = []
     for result in results:
         stats = result["stats"]
         relative = f"{stats['median'] / torchvision_median:.2f}x" if torchvision_median else "N/A"
-
-        table_data.append({
-            "Backend": result["backend"],
-            "Median (ms)": f"{stats['median']:.2f}",
-            "Std (ms)": f"{stats['std']:.2f}",
-            "Mean (ms)": f"{stats['mean']:.2f}",
-            "Min (ms)": f"{stats['min']:.2f}",
-            "Max (ms)": f"{stats['max']:.2f}",
-            "Relative": relative
-        })
-
+
+        table_data.append(
+            {
+                "Backend": result["backend"],
+                "Median (ms)": f"{stats['median']:.2f}",
+                "Std (ms)": f"{stats['std']:.2f}",
+                "Mean (ms)": f"{stats['mean']:.2f}",
+                "Min (ms)": f"{stats['min']:.2f}",
+                "Max (ms)": f"{stats['max']:.2f}",
+                "Relative": relative,
+            }
+        )
+
     print(tabulate(table_data, headers="keys", tablefmt="grid"))
 
 
 def main():
     parser = argparse.ArgumentParser(description="Benchmark torchvision transforms")
     parser.add_argument("--num-exp", type=int, default=100, help="Number of experiments we average over")
-    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup runs before running the num-exp experiments")
-    parser.add_argument("--target-size", type=int, default=224, help="Resize target size")
+    parser.add_argument(
+        "--warmup", type=int, default=10, help="Number of warmup runs before running the num-exp experiments"
+    )
+    parser.add_argument(
+        "--target-size", type=int, default=224, help="size parameter of the Resize step, for both H and W."
+    )
     parser.add_argument("--min-size", type=int, default=128, help="Minimum input image size for random generation")
     parser.add_argument("--max-size", type=int, default=512, help="Maximum input image size for random generation")
-    parser.add_argument("--num-threads", type=int, default=1, help="Number of intra-op threads as set with torch.set_num_threads()")
-    parser.add_argument("--batch-size", type=int, default=1, help="Batch size. 1 means single image processing without a batch dimension")
-    parser.add_argument("--contiguity", choices=["CL", "CF"], default="CF", help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)")
-    all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations"]
-    parser.add_argument("--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to use for transforms")
+    parser.add_argument(
+        "--num-threads", type=int, default=1, help="Number of intra-op threads as set with torch.set_num_threads() & Co"
+    )
+    parser.add_argument(
+        "--batch-size", type=int, default=1, help="Batch size. 1 means single 3D image without a batch dimension"
+    )
+    parser.add_argument(
+        "--contiguity",
+        choices=["CL", "CF"],
+        default="CF",
+        help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)",
+    )
+    all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations", "kornia"]
+    parser.add_argument(
+        "--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to benchmark"
+    )
     parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
 
     args = parser.parse_args()
-
+
     print(f"Averaging over {args.num_exp} runs, {args.warmup} warmup runs")
 
     backends_to_run = all_backends if args.backend.lower() == "all" else [args.backend]
     results = []
-
+
     for backend in backends_to_run:
         args.backend = backend
        try:
             result = run_benchmark(args)
             results.append(result)
         except Exception as e:
             print(f"ERROR with {backend}: {e}")
-
+
     if len(results) > 1:
         print_comparison_table(results)
 
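A small detail in the argparse changes above: --backend uses type=str.lower, so the new "kornia" choice is matched case-insensitively (argparse applies the type conversion before validating against choices), and since kornia is now in all_backends it also runs under the default --backend all. A hypothetical direct invocation, using only flags from this diff: python benchmark_transforms.py --backend kornia --num-exp 100 -v. A minimal sketch of the case-insensitivity:

# Sketch: argparse converts with type=str.lower before checking choices,
# so "KORNIA" is accepted and normalized to "kornia".
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--backend", type=str.lower, choices=["tv", "kornia", "all"], default="all")
print(parser.parse_args(["--backend", "KORNIA"]).backend)  # kornia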