|
15 | 15 | import warnings |
16 | 16 | from typing import Dict, Any |
17 | 17 | import torchvision.transforms.v2.functional as F |
| 18 | +import torchvision.transforms.functional as Fv1 |
18 | 19 | import numpy as np |
19 | 20 | from utils import bench, report_stats, print_comparison_table, print_benchmark_info |
20 | 21 |
|
@@ -63,6 +64,14 @@ def torchvision_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor |
63 | 64 | images = F.normalize(images, mean=NORM_MEAN, std=NORM_STD) |
64 | 65 | return images |
65 | 66 |
|
| 67 | +def torchvision_v1_pipeline(images: torch.Tensor, target_size: int) -> torch.Tensor: |
| 68 | + images = images.float() / 255. # rough equivalent of to_tensor() |
| 69 | + images = Fv1.resize( |
| 70 | + images, size=(target_size, target_size), interpolation=Fv1.InterpolationMode.BILINEAR, antialias=True |
| 71 | + ) |
| 72 | + images = Fv1.normalize(images, mean=NORM_MEAN, std=NORM_STD) |
| 73 | + return images |
| 74 | + |
66 | 75 |
|
67 | 76 | def opencv_pipeline(image: np.ndarray, target_size: int) -> torch.Tensor: |
68 | 77 | img = cv2.resize(image, (target_size, target_size), interpolation=cv2.INTER_LINEAR) # no antialias in OpenCV |
@@ -140,6 +149,9 @@ def run_benchmark(args) -> Dict[str, Any]: |
140 | 149 | if backend == "tv": |
141 | 150 | torch.set_num_threads(args.num_threads) |
142 | 151 | pipeline = torchvision_pipeline |
| 152 | + if backend == "tv-v1": |
| 153 | + torch.set_num_threads(args.num_threads) |
| 154 | + pipeline = torchvision_v1_pipeline |
143 | 155 | elif backend == "tv-compiled": |
144 | 156 | torch.set_num_threads(args.num_threads) |
145 | 157 | pipeline = compiled_torchvision_pipeline |
@@ -229,7 +241,7 @@ def main(): |
229 | 241 | default="CF", |
230 | 242 | help="Memory format: CL (channels_last) or CF (channels_first, i.e. contiguous)", |
231 | 243 | ) |
232 | | - all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations", "kornia"] |
| 244 | + all_backends = ["tv", "tv-v1", "tv-compiled", "opencv", "pil", "albumentations", "kornia"] |
233 | 245 | parser.add_argument( |
234 | 246 | "--backends", type=str, default="all", help="Backends to benchmark (comma-separated list or 'all'). First backend is used as reference for comparison." |
235 | 247 | ) |
|
0 commit comments