
Commit 9a10b07

Add albumentation backend
1 parent 5d1b099 commit 9a10b07

File tree

1 file changed: +27 -1 lines changed


benchmark_transforms.py

Lines changed: 27 additions & 1 deletion
@@ -23,6 +23,12 @@
 except ImportError:
     HAS_OPENCV = False
 
+try:
+    import albumentations as A
+    HAS_ALBUMENTATIONS = True
+except ImportError:
+    HAS_ALBUMENTATIONS = False
+
 from PIL import Image
 from tabulate import tabulate
 

@@ -107,6 +113,16 @@ def pil_pipeline(image: Image.Image, target_size: int) -> torch.Tensor:
     return img
 
 
+def albumentations_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor:
+    transform = A.Compose([
+        A.Resize(target_size, target_size, interpolation=cv2.INTER_LINEAR),
+        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0)
+    ])
+    img = transform(image=image)["image"]
+    img = torch.from_numpy(img).permute(2, 0, 1)
+    return img
+
+
 # TODO double check that this works as expected: no graph break, and no issues with dynamic shapes
 compiled_torchvision_pipeline = torch.compile(torchvision_pipeline, mode="default", fullgraph=True, dynamic=True)
 
@@ -116,6 +132,8 @@ def run_benchmark(args) -> Dict[str, Any]:
 
     if backend == "opencv" and not HAS_OPENCV:
         raise RuntimeError("OpenCV not available. Install with: pip install opencv-python")
+    if backend == "albumentations" and not HAS_ALBUMENTATIONS:
+        raise RuntimeError("Albumentations not available. Install with: pip install albumentations")
 
     if args.verbose:
         backend_display = args.backend.upper()
@@ -137,6 +155,9 @@ def run_benchmark(args) -> Dict[str, Any]:
     elif backend == "pil":
         torch.set_num_threads(args.num_threads)
         pipeline = pil_pipeline
+    elif backend == "albumentations":
+        cv2.setNumThreads(args.num_threads)
+        pipeline = albumentations_pipeline
 
     def generate_test_images():
         height = random.randint(args.min_size, args.max_size)
@@ -161,6 +182,11 @@ def generate_test_images():
             # Convert to PIL Image (CHW -> HWC)
             images = images.numpy().transpose(1, 2, 0)
             images = Image.fromarray(images)
+        elif backend == "albumentations":
+            if args.batch_size > 1:
+                # TODO is that true????
+                raise ValueError("Batches not supported in Albumentations pipeline")
+            images = images.numpy().transpose(1, 2, 0)
 
         return images
 
@@ -206,7 +232,7 @@ def main():
     parser.add_argument("--num-threads", type=int, default=1, help="Number of intra-op threads as set with torch.set_num_threads()")
     parser.add_argument("--batch-size", type=int, default=1, help="Batch size. 1 means single image processing without a batch dimension")
     parser.add_argument("--contiguity", choices=["CL", "CF"], default="CF", help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)")
-    all_backends = ["tv", "tv-compiled", "opencv", "pil"]
+    all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations"]
     parser.add_argument("--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to use for transforms")
     parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
 
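For context, here is a minimal standalone sketch (not part of the commit) of how the new backend can be exercised outside the benchmark script. It assumes ImageNet-style normalization constants as stand-ins for the NORM_MEAN / NORM_STD defined in benchmark_transforms.py, and feeds a random HWC uint8 array in place of the images produced by generate_test_images():

# Minimal sketch, not part of the commit: runs the new albumentations backend
# on a synthetic image. NORM_MEAN / NORM_STD below are assumed ImageNet values
# standing in for the constants defined in benchmark_transforms.py.
import albumentations as A
import cv2
import numpy as np
import torch

NORM_MEAN = (0.485, 0.456, 0.406)  # assumption: ImageNet mean
NORM_STD = (0.229, 0.224, 0.225)   # assumption: ImageNet std

def albumentations_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor:
    # Resize and normalize the HWC uint8 array, then convert to a CHW float tensor.
    transform = A.Compose([
        A.Resize(target_size, target_size, interpolation=cv2.INTER_LINEAR),
        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0),
    ])
    img = transform(image=image)["image"]
    img = torch.from_numpy(img).permute(2, 0, 1)
    return img

# HWC uint8 input, mirroring the CHW -> HWC transpose done for this backend
# in generate_test_images().
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
out = albumentations_pipeline(image, target_size=224)
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32

With the new entry in all_backends, the backend should then be selectable on the command line, e.g. python benchmark_transforms.py --backend albumentations --num-threads 4, with --num-threads routed to cv2.setNumThreads() as in the hunk above; note that batch sizes above 1 are rejected for this backend.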
