Skip to content

Commit 8882898

Browse files
committed
Add tv-v1 pipeline
1 parent 9478b4d commit 8882898

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

benchmarks/benchmark_transforms.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import warnings
1616
from typing import Dict, Any
1717
import torchvision.transforms.v2.functional as F
18+
import torchvision.transforms.functional as Fv1
1819
import numpy as np
1920
from utils import bench, report_stats, print_comparison_table, print_benchmark_info
2021

@@ -63,6 +64,14 @@ def torchvision_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor
6364
images = F.normalize(images, mean=NORM_MEAN, std=NORM_STD)
6465
return images
6566

67+
def torchvision_v1_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor:
68+
images = images.float() / 255. # rough equivalent of to_tensor()
69+
images = Fv1.resize(
70+
images, size=(target_size, target_size), interpolation=Fv1.InterpolationMode.BILINEAR, antialias=True
71+
)
72+
images = Fv1.normalize(images, mean=NORM_MEAN, std=NORM_STD)
73+
return images
74+
6675

6776
def opencv_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor:
6877
img = cv2.resize(image, (target_size, target_size), interpolation=cv2.INTER_LINEAR) # no antialias in OpenCV
@@ -140,6 +149,9 @@ def run_benchmark(args) -> Dict[str, Any]:
140149
if backend == "tv":
141150
torch.set_num_threads(args.num_threads)
142151
pipeline = torchvision_pipeline
152+
if backend == "tv-v1":
153+
torch.set_num_threads(args.num_threads)
154+
pipeline = torchvision_v1_pipeline
143155
elif backend == "tv-compiled":
144156
torch.set_num_threads(args.num_threads)
145157
pipeline = compiled_torchvision_pipeline
@@ -229,7 +241,7 @@ def main():
229241
default="CF",
230242
help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)",
231243
)
232-
all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations", "kornia"]
244+
all_backends = ["tv", "tv-v1", "tv-compiled", "opencv", "pil", "albumentations", "kornia"]
233245
parser.add_argument(
234246
"--backends", type=str, default="all", help="Backends to benchmark (comma-separated list or 'all'). First backend is used as reference for comparison."
235247
)

0 commit comments

Comments
 (0)