Fix hardcoded CUDA device in api.py to support MPS and CPU fallback#1516

Open
Mr-Neutr0n wants to merge 1 commit into zai-org:main from Mr-Neutr0n:fix/device-handling-api

Conversation

@Mr-Neutr0n

Problem

The API server (api.py) crashes immediately on non-CUDA systems because:

  1. DEVICE is hardcoded to "cuda", causing failures on macOS (Apple Silicon/MPS) and CPU-only machines
  2. Model loading uses .cuda() directly, which raises RuntimeError when CUDA is unavailable
  3. torch_gc() only handles CUDA cleanup, missing MPS cache management
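The failure in (2) is easy to reproduce on any non-CUDA build of PyTorch; a minimal sketch (not the actual api.py code):

```python
import torch

# On a CPU-only or MPS-only build, calling .cuda() fails immediately --
# typically "Torch not compiled with CUDA enabled" (an AssertionError)
# or a RuntimeError, depending on how PyTorch was built.
if not torch.cuda.is_available():
    try:
        torch.zeros(1).cuda()
    except (AssertionError, RuntimeError) as err:
        print(f"reproduced the startup crash: {err}")
```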

Fix

  • Auto-detect the best available device at startup: CUDA > MPS > CPU
  • Replace .cuda() with .to(DEVICE) for portable device placement
  • Update torch_gc() to clear MPS cache when on Apple Silicon and skip CUDA-specific cleanup when not on a CUDA device
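The three bullets above can be sketched as follows. This is an illustration of the approach described in the PR, not the exact api.py code; `torch.mps.empty_cache()` assumes PyTorch >= 2.0:

```python
import gc

import torch


def pick_device(cuda_ok: bool, mps_ok: bool) -> str:
    """Priority order from the PR: CUDA > MPS > CPU."""
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"


# Auto-detect once at startup instead of hardcoding "cuda".
DEVICE = pick_device(
    torch.cuda.is_available(),
    torch.backends.mps.is_available(),
)


def torch_gc() -> None:
    """Free cached device memory on whichever backend is active."""
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif DEVICE == "mps":
        torch.mps.empty_cache()  # MPS cache API, PyTorch >= 2.0
    # On plain CPU there is no device cache to clear.


# Portable placement: model.to(DEVICE) works on all three backends,
# where model.cuda() raises on non-CUDA machines.
```

On a CUDA machine `pick_device` still resolves to `"cuda"`, which is what makes the change backward compatible.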

Testing

  • Verified the logic is consistent with PyTorch's device detection APIs
  • Backward compatible: behavior is identical on CUDA systems since torch.cuda.is_available() returns True and DEVICE resolves to "cuda" as before

