diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index f7d00181b..3a7c85937 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -533,7 +533,7 @@ def arg_init(args): # Localized import to minimize expensive imports from torchchat.utils.build_utils import get_device_str - if args.device is None: + if args.device is None or args.device == "fast": args.device = get_device_str( args.quantize.get("executor", {}).get("accelerator", default_device) )