 
 try:
     import cv2
+
     HAS_OPENCV = True
 except ImportError:
     HAS_OPENCV = False
 
 try:
     import albumentations as A
+
     HAS_ALBUMENTATIONS = True
 except ImportError:
     HAS_ALBUMENTATIONS = False
 
+try:
+    import kornia as K
+    import kornia.augmentation as KA
+
+    HAS_KORNIA = True
+except ImportError:
+    HAS_KORNIA = False
+
 from PIL import Image
 from tabulate import tabulate
 
@@ -82,16 +92,18 @@ def report_stats(times: torch.Tensor, unit: str, verbose: bool = True) -> Dict[s
8292 "min" : times .min ().item (),
8393 "max" : times .max ().item (),
8494 }
85-
95+
8696 if verbose :
8797 print (f" Median: { stats ['median' ]:.2f} { unit } ± { stats ['std' ]:.2f} { unit } " )
8898 print (f" Mean: { stats ['mean' ]:.2f} { unit } , Min: { stats ['min' ]:.2f} { unit } , Max: { stats ['max' ]:.2f} { unit } " )
89-
99+
90100 return stats
91101
92102
93103def torchvision_pipeline (images : torch .Tensor , target_size : int ) -> torch .Tensor :
94- images = F .resize (images , size = (target_size , target_size ), interpolation = F .InterpolationMode .BILINEAR , antialias = True )
104+ images = F .resize (
105+ images , size = (target_size , target_size ), interpolation = F .InterpolationMode .BILINEAR , antialias = True
106+ )
95107 images = F .to_dtype (images , dtype = torch .float32 , scale = True )
96108 images = F .normalize (images , mean = NORM_MEAN , std = NORM_STD )
97109 return images
@@ -114,35 +126,52 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
 
 
 def albumentations_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor:
-    transform = A.Compose([
-        A.Resize(target_size, target_size, interpolation=cv2.INTER_LINEAR),
-        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0)
-    ])
+    transform = A.Compose(
+        [
+            A.Resize(target_size, target_size, interpolation=cv2.INTER_LINEAR),
+            A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0),
+        ]
+    )
     img = transform(image=image)["image"]
     img = torch.from_numpy(img).permute(2, 0, 1)
     return img
 
 
+def kornia_pipeline(image: torch.Tensor, target_size: int) -> torch.Tensor:
+    # Kornia expects float tensors in [0, 1] range
+    # TODO check that this is needed?
+    img = image.float() / 255.0
+    img = img.unsqueeze(0)  # Add batch dimension for kornia
+
+    img = K.geometry.transform.resize(img, (target_size, target_size), interpolation="bilinear")
+
+    img = K.enhance.normalize(img, mean=torch.tensor(NORM_MEAN), std=torch.tensor(NORM_STD))
+
+    return img.squeeze(0)  # Remove batch dimension
+
+
 # TODO double check that this works as expected: no graph break, and no issues with dynamic shapes
 compiled_torchvision_pipeline = torch.compile(torchvision_pipeline, mode="default", fullgraph=True, dynamic=True)
 
 
 def run_benchmark(args) -> Dict[str, Any]:
     backend = args.backend.lower()
-
+
     if backend == "opencv" and not HAS_OPENCV:
         raise RuntimeError("OpenCV not available. Install with: pip install opencv-python")
     if backend == "albumentations" and not HAS_ALBUMENTATIONS:
         raise RuntimeError("Albumentations not available. Install with: pip install albumentations")
-
+    if backend == "kornia" and not HAS_KORNIA:
+        raise RuntimeError("Kornia not available. Install with: pip install kornia")
+
     if args.verbose:
         backend_display = args.backend.upper()
         print(f"\n=== {backend_display} ===")
         print(f"Threads: {args.num_threads}, Batch size: {args.batch_size}")
 
     memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
     print(f"Memory format: {'channels_last' if memory_format == torch.channels_last else 'channels_first'}")
-
+
     if backend == "tv":
         torch.set_num_threads(args.num_threads)
         pipeline = torchvision_pipeline
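
Editor's aside between hunks, not part of the commit: the TODO above the torch.compile call asks whether the compiled pipeline really avoids graph breaks and copes with dynamic input shapes. A minimal, hedged sketch of one way to check, assuming a recent PyTorch 2.x and that torchvision_pipeline is importable from this script (the sample shape is arbitrary):

    import torch
    import torch._dynamo as dynamo

    sample = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)
    explanation = dynamo.explain(torchvision_pipeline)(sample, 224)
    print(explanation.graph_break_count)  # 0 would support the fullgraph=True assumption
    print(explanation.break_reasons)      # expected to be empty when there are no breaks
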
@@ -158,19 +187,22 @@ def run_benchmark(args) -> Dict[str, Any]:
     elif backend == "albumentations":
         cv2.setNumThreads(args.num_threads)
         pipeline = albumentations_pipeline
-
+    elif backend == "kornia":
+        torch.set_num_threads(args.num_threads)
+        pipeline = kornia_pipeline
+
     def generate_test_images():
         height = random.randint(args.min_size, args.max_size)
         width = random.randint(args.min_size, args.max_size)
         images = torch.randint(0, 256, (args.batch_size, 3, height, width), dtype=torch.uint8)
-
+
         memory_format = torch.channels_last if args.contiguity == "CL" else torch.contiguous_format
         if memory_format == torch.channels_last:
             images = images.to(memory_format=torch.channels_last)
-
+
         if args.batch_size == 1:
             images = images[0]
-
+
         if backend == "opencv":
             if args.batch_size > 1:
                 raise ValueError("Batches not supported in OpenCV pipeline")
@@ -187,70 +219,91 @@ def generate_test_images():
                 # TODO is that true????
                 raise ValueError("Batches not supported in Albumentations pipeline")
             images = images.numpy().transpose(1, 2, 0)
+        elif backend == "kornia":
+            if args.batch_size > 1:
+                # TODO is that true????
+                raise ValueError("Batches not supported in Kornia pipeline")
 
         return images
-
+
     times = bench(
         lambda images: pipeline(images, args.target_size),
         data_generator=generate_test_images,
         num_exp=args.num_exp,
         warmup=args.warmup,
     )
-
+
     stats = report_stats(times, "ms", args.verbose)
     return {"backend": args.backend, "stats": stats}
 
 
 def print_comparison_table(results: List[Dict[str, Any]]) -> None:
     torchvision_median = next((r["stats"]["median"] for r in results if r["backend"].lower() == "tv"), None)
-
+
     table_data = []
     for result in results:
         stats = result["stats"]
         relative = f"{stats['median'] / torchvision_median:.2f}x" if torchvision_median else "N/A"
-
-        table_data.append({
-            "Backend": result["backend"],
-            "Median (ms)": f"{stats['median']:.2f}",
-            "Std (ms)": f"{stats['std']:.2f}",
-            "Mean (ms)": f"{stats['mean']:.2f}",
-            "Min (ms)": f"{stats['min']:.2f}",
-            "Max (ms)": f"{stats['max']:.2f}",
-            "Relative": relative
-        })
-
+
+        table_data.append(
+            {
+                "Backend": result["backend"],
+                "Median (ms)": f"{stats['median']:.2f}",
+                "Std (ms)": f"{stats['std']:.2f}",
+                "Mean (ms)": f"{stats['mean']:.2f}",
+                "Min (ms)": f"{stats['min']:.2f}",
+                "Max (ms)": f"{stats['max']:.2f}",
+                "Relative": relative,
+            }
+        )
+
     print(tabulate(table_data, headers="keys", tablefmt="grid"))
 
 
 def main():
     parser = argparse.ArgumentParser(description="Benchmark torchvision transforms")
     parser.add_argument("--num-exp", type=int, default=100, help="Number of experiments we average over")
-    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup runs before running the num-exp experiments")
-    parser.add_argument("--target-size", type=int, default=224, help="Resize target size")
+    parser.add_argument(
+        "--warmup", type=int, default=10, help="Number of warmup runs before running the num-exp experiments"
+    )
+    parser.add_argument(
+        "--target-size", type=int, default=224, help="size parameter of the Resize step, for both H and W."
+    )
     parser.add_argument("--min-size", type=int, default=128, help="Minimum input image size for random generation")
     parser.add_argument("--max-size", type=int, default=512, help="Maximum input image size for random generation")
-    parser.add_argument("--num-threads", type=int, default=1, help="Number of intra-op threads as set with torch.set_num_threads()")
-    parser.add_argument("--batch-size", type=int, default=1, help="Batch size. 1 means single image processing without a batch dimension")
-    parser.add_argument("--contiguity", choices=["CL", "CF"], default="CF", help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)")
-    all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations"]
-    parser.add_argument("--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to use for transforms")
+    parser.add_argument(
+        "--num-threads", type=int, default=1, help="Number of intra-op threads as set with torch.set_num_threads() & Co"
+    )
+    parser.add_argument(
+        "--batch-size", type=int, default=1, help="Batch size. 1 means single 3D image without a batch dimension"
+    )
+    parser.add_argument(
+        "--contiguity",
+        choices=["CL", "CF"],
+        default="CF",
+        help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)",
+    )
+    all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations", "kornia"]
+    parser.add_argument(
+        "--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to benchmark"
+    )
     parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
 
     args = parser.parse_args()
-
+
     print(f"Averaging over {args.num_exp} runs, {args.warmup} warmup runs")
 
     backends_to_run = all_backends if args.backend.lower() == "all" else [args.backend]
     results = []
-
+
     for backend in backends_to_run:
         args.backend = backend
         try:
             result = run_benchmark(args)
             results.append(result)
         except Exception as e:
             print(f"ERROR with {backend}: {e}")
-
+
     if len(results) > 1:
         print_comparison_table(results)
 
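
Editor's aside, not part of the commit: the kornia_pipeline added above carries a TODO about the uint8-to-float conversion, and the kornia backend is meant to mirror the torchvision pipeline. A quick, hedged sanity check, assuming both functions are defined in this module; exact agreement is not expected, since the torchvision resize uses antialias=True while the kornia resize call leaves antialiasing at its default:

    import torch

    img = torch.randint(0, 256, (3, 333, 481), dtype=torch.uint8)
    out_tv = torchvision_pipeline(img, target_size=224)
    out_k = kornia_pipeline(img, target_size=224)
    print(out_tv.shape, out_k.shape)            # both should be (3, 224, 224)
    print((out_tv - out_k).abs().max().item())  # small but non-zero gap expected when downscaling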