diff --git a/test/smoke_test.py b/test/smoke_test.py index 6f1ec06da30..e2a3b5068ab 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -96,7 +96,7 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None: model.eval() # Step 2: Initialize the inference transforms - preprocess = weights.transforms(antialias=(device != "mps")) # antialias not supported on MPS + preprocess = weights.transforms(antialias=True) # Step 3: Apply inference preprocessing transforms batch = preprocess(img).unsqueeze(0)