2323except ImportError :
2424 HAS_OPENCV = False
2525
26+ try :
27+ import albumentations as A
28+ HAS_ALBUMENTATIONS = True
29+ except ImportError :
30+ HAS_ALBUMENTATIONS = False
31+
2632from PIL import Image
2733from tabulate import tabulate
2834
@@ -107,6 +113,16 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
107113 return img
108114
109115
116+ def albumentations_pipeline (image : np .ndarray , target_size : int ) -> torch .Tensor :
117+ transform = A .Compose ([
118+ A .Resize (target_size , target_size , interpolation = cv2 .INTER_LINEAR ),
119+ A .Normalize (mean = NORM_MEAN , std = NORM_STD , max_pixel_value = 255.0 )
120+ ])
121+ img = transform (image = image )["image" ]
122+ img = torch .from_numpy (img ).permute (2 , 0 , 1 )
123+ return img
124+
125+
110126# TODO double check that this works as expected: no graph break, and no issues with dynamic shapes
111127compiled_torchvision_pipeline = torch .compile (torchvision_pipeline , mode = "default" , fullgraph = True , dynamic = True )
112128
@@ -116,6 +132,8 @@ def run_benchmark(args) -> Dict[str, Any]:
116132
117133 if backend == "opencv" and not HAS_OPENCV :
118134 raise RuntimeError ("OpenCV not available. Install with: pip install opencv-python" )
135+ if backend == "albumentations" and not HAS_ALBUMENTATIONS :
136+ raise RuntimeError ("Albumentations not available. Install with: pip install albumentations" )
119137
120138 if args .verbose :
121139 backend_display = args .backend .upper ()
@@ -137,6 +155,9 @@ def run_benchmark(args) -> Dict[str, Any]:
137155 elif backend == "pil" :
138156 torch .set_num_threads (args .num_threads )
139157 pipeline = pil_pipeline
158+ elif backend == "albumentations" :
159+ cv2 .setNumThreads (args .num_threads )
160+ pipeline = albumentations_pipeline
140161
141162 def generate_test_images ():
142163 height = random .randint (args .min_size , args .max_size )
@@ -161,6 +182,11 @@ def generate_test_images():
161182 # Convert to PIL Image (CHW -> HWC)
162183 images = images .numpy ().transpose (1 , 2 , 0 )
163184 images = Image .fromarray (images )
185+ elif backend == "albumentations" :
186+ if args .batch_size > 1 :
187+ # TODO is that true????
188+ raise ValueError ("Batches not supported in Albumentations pipeline" )
189+ images = images .numpy ().transpose (1 , 2 , 0 )
164190
165191 return images
166192
@@ -206,7 +232,7 @@ def main():
206232 parser .add_argument ("--num-threads" , type = int , default = 1 , help = "Number of intra-op threads as set with torch.set_num_threads()" )
207233 parser .add_argument ("--batch-size" , type = int , default = 1 , help = "Batch size. 1 means single image processing without a batch dimension" )
208234 parser .add_argument ("--contiguity" , choices = ["CL" , "CF" ], default = "CF" , help = "Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)" )
209- all_backends = ["tv" , "tv-compiled" , "opencv" , "pil" ]
235+ all_backends = ["tv" , "tv-compiled" , "opencv" , "pil" , "albumentations" ]
210236 parser .add_argument ("--backend" , type = str .lower , choices = all_backends + ["all" ], default = "all" , help = "Backend to use for transforms" )
211237 parser .add_argument ("-v" , "--verbose" , action = "store_true" , help = "Enable verbose output" )
212238
0 commit comments