diff --git a/demo_page.py b/demo_page.py index 537d5a7..88bfd92 100644 --- a/demo_page.py +++ b/demo_page.py @@ -28,10 +28,12 @@ def __init__(self, model_id_or_path): self.model.eval() # Set device and precision - self.device = "cuda" if torch.cuda.is_available() else "cpu" + # Cuda -> MPS -> CPU + self.device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu") + self.model.to(self.device) - # Use float16 on CUDA, float32 on CPU - if self.device == "cuda": + # Use float16 on CUDA and MPS, float32 on CPU + if self.device == "cuda" or self.device == "mps": self.model = self.model.half() else: self.model = self.model.float()