
Commit 93f713f

Authored by mikekgfb and Jack-Khuu
Update cli.py to make --device/--dtype pre-empt quantize dict-specified values (#1359)

* Update cli.py to make --device/--dtype pre-empt quantize dict-specified values

  Users may expect that CLI parameters override the JSON, as per #1278. Invert the logic with a case split:
  1. If no value is specified on the CLI, use the value from the quantize dict, if present.
  2. If a value is specified on the CLI, override the respective handler's value, if present.

* Fix typo in cli.py

---------

Co-authored-by: Jack-Khuu <[email protected]>
1 parent 0b385d3 commit 93f713f
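
The heart of the change is a precedence rule: an explicit CLI flag beats the --quantize JSON recipe, which beats the environment-derived default. Below is a minimal, self-contained sketch of that rule for dtype; resolve_dtype is a hypothetical helper written for illustration, not a torchchat function.

import json

def resolve_dtype(cli_dtype, quantize_json, env_default="fast"):
    # CLI flag wins; otherwise the recipe's precision.dtype; otherwise the default.
    quantize = json.loads(quantize_json) if quantize_json else {}
    if cli_dtype is None:
        return quantize.get("precision", {}).get("dtype", env_default)
    return cli_dtype

assert resolve_dtype(None, '{"precision": {"dtype": "bf16"}}') == "bf16"
assert resolve_dtype("fp16", '{"precision": {"dtype": "bf16"}}') == "fp16"
assert resolve_dtype(None, None) == "fast"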

File tree

1 file changed: +26 −10 lines changed

torchchat/cli/cli.py (26 additions, 10 deletions)
@@ -21,6 +21,8 @@
 logger = logging.getLogger(__name__)
 
 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
+default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")
+
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
 ).expanduser()
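
The new default_dtype mirrors the existing default_device pattern: an environment variable supplies the last-resort fallback, and "fast" is a sentinel meaning "pick the best option for this machine", resolved to a concrete value later. A small sketch of the pattern; the env-var names match the diff, the override below is purely illustrative.

import os

# Illustrative only: pretend the user exported TORCHCHAT_PRECISION=bf16.
os.environ["TORCHCHAT_PRECISION"] = "bf16"

default_device = os.getenv("TORCHCHAT_DEVICE", "fast")    # -> "fast" (sentinel)
default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")  # -> "bf16"
print(default_device, default_dtype)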
@@ -149,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None:
 
     model_config_parser.add_argument(
         "--dtype",
-        default="fast",
+        default=None,
         choices=allowable_dtype_names(),
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
+        help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast",
     )
     model_config_parser.add_argument(
         "--quantize",
@@ -165,9 +167,9 @@ def _add_model_config_args(parser, verb: str) -> None:
     model_config_parser.add_argument(
         "--device",
         type=str,
-        default=default_device,
+        default=None,
         choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: cpu, cuda, mps",
+        help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
 
 
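Switching both argparse defaults from concrete values to None is what makes the precedence enforceable: the parser can now distinguish "flag omitted" from "flag given". A short, self-contained demonstration of that distinction:

import argparse

parser = argparse.ArgumentParser()
# default=None lets downstream code tell "flag omitted" apart from
# "flag set", the hook for deciding whether the quantize recipe wins.
parser.add_argument("--dtype", default=None,
                    choices=["bf16", "fp16", "fp32", "fast16", "fast"])

args = parser.parse_args([])                   # flag omitted
assert args.dtype is None                      # defer to recipe / env default
args = parser.parse_args(["--dtype", "bf16"])  # flag given
assert args.dtype == "bf16"                    # pre-empts the recipe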

@@ -513,20 +515,34 @@ def arg_init(args):
     if isinstance(args.quantize, str):
         args.quantize = json.loads(args.quantize)
 
-    # if we specify dtype in quantization recipe, replicate it as args.dtype
-    args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    # if dtype is specified in the quantization recipe, let args.dtype override it when set
+    if args.dtype is None:
+        args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype)
+    else:
+        precision_handler = args.quantize.get("precision", None)
+        if precision_handler:
+            if precision_handler["dtype"] != args.dtype:
+                print(f'overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}')
+                precision_handler["dtype"] = args.dtype
 
     if getattr(args, "output_pte_path", None):
-        if args.device not in ["cpu", "fast"]:
+        if args.device not in [None, "cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
         # Localized import to minimize expensive imports
         from torchchat.utils.build_utils import get_device_str
 
-        args.device = get_device_str(
-            args.quantize.get("executor", {}).get("accelerator", args.device)
-        )
+        if args.device is None:
+            args.device = get_device_str(
+                args.quantize.get("executor", {}).get("accelerator", default_device)
+            )
+        else:
+            executor_handler = args.quantize.get("executor", None)
+            if executor_handler:
+                if executor_handler["accelerator"] != args.device:
+                    print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
+                    executor_handler["accelerator"] = args.device
 
     if "mps" in args.device:
         if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
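
Note that on the override path the recipe dict is mutated in place (precision_handler["dtype"] = args.dtype, executor_handler["accelerator"] = args.device), so anything that later reads args.quantize sees the CLI value too. A condensed restatement of the device branch, with get_device_str stubbed out since the real torchchat helper resolves "fast" to the best available backend:

def get_device_str(device: str) -> str:
    # Illustrative stub; torchchat's real helper probes the local hardware.
    return "cpu" if device == "fast" else device

def resolve_device(cli_device, quantize, env_default="fast"):
    if cli_device is None:
        # No CLI flag: the recipe's executor.accelerator wins, then the default.
        return get_device_str(quantize.get("executor", {}).get("accelerator", env_default))
    executor = quantize.get("executor")
    if executor and executor.get("accelerator") != cli_device:
        executor["accelerator"] = cli_device  # keep the recipe consistent
    return cli_device

recipe = {"executor": {"accelerator": "cuda"}}
assert resolve_device("cpu", recipe) == "cpu"
assert recipe["executor"]["accelerator"] == "cpu"  # recipe rewritten in place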
