Skip to content

Commit 9478b4d

Browse files
committed
Allow multiple backends, invert 'relative' column to show speed-up against first row
1 parent b88b906 commit 9478b4d

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

benchmarks/benchmark_transforms.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def main():
231231
)
232232
all_backends = ["tv", "tv-compiled", "opencv", "pil", "albumentations", "kornia"]
233233
parser.add_argument(
234-
"--backend", type=str.lower, choices=all_backends + ["all"], default="all", help="Backend to benchmark"
234+
"--backends", type=str, default="all", help="Backends to benchmark (comma-separated list or 'all'). First backend is used as reference for comparison."
235235
)
236236
parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
237237
parser.add_argument("--device", type=str, default="cpu", help="Device to use: cpu or cuda (default: cpu)")
@@ -240,7 +240,18 @@ def main():
240240

241241
print_benchmark_info(args)
242242

243-
backends_to_run = all_backends if args.backend.lower() == "all" else [args.backend]
243+
# Parse backends parameter
244+
if args.backends.lower() == "all":
245+
backends_to_run = all_backends
246+
else:
247+
backends_to_run = [backend.strip().lower() for backend in args.backends.split(",")]
248+
# Validate backends
249+
invalid_backends = [b for b in backends_to_run if b not in all_backends]
250+
if invalid_backends:
251+
print(f"ERROR: Invalid backends: {', '.join(invalid_backends)}")
252+
print(f"Available backends: {', '.join(all_backends)}")
253+
return
254+
244255
results = []
245256

246257
for backend in backends_to_run:

benchmarks/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,13 @@ def report_stats(times: torch.Tensor, unit: str, verbose: bool = True) -> Dict[s
8686

8787

8888
def print_comparison_table(results: List[Dict[str, Any]]) -> None:
89-
torchvision_median = next((r["stats"]["median"] for r in results if r["backend"].lower() == "tv"), None)
89+
# Use first backend as reference for relative comparison
90+
reference_median = results[0]["stats"]["median"] if results else None
9091

9192
table_data = []
9293
for result in results:
9394
stats = result["stats"]
94-
relative = f"{stats['median'] / torchvision_median:.2f}x" if torchvision_median else "N/A"
95+
speed_up = f"{reference_median / stats['median']:.2f}x" if reference_median else "N/A"
9596

9697
table_data.append(
9798
{
@@ -101,7 +102,7 @@ def print_comparison_table(results: List[Dict[str, Any]]) -> None:
101102
"Mean (ms)": f"{stats['mean']:.2f}",
102103
"Min (ms)": f"{stats['min']:.2f}",
103104
"Max (ms)": f"{stats['max']:.2f}",
104-
"Relative": relative,
105+
"Speed-up": speed_up,
105106
}
106107
)
107108

0 commit comments

Comments
 (0)