This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit f0b2df7

CLI: Fix unsafe arg access of unused args (#987)
* CLI: Remove unsafe access of unused args
* Annotate the args conditional on subcommands in functions
* Typo in generate.py
Parent: cd0307a · Commit: f0b2df7
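The root cause: torchchat registers CLI arguments per subcommand, so the argparse.Namespace produced by one verb may simply lack attributes that shared code paths read; `args.dso_path` on such a namespace raises AttributeError instead of returning a default. A minimal, self-contained sketch of the failure mode and the `getattr` fix applied throughout this commit (toy parser, not the actual torchchat CLI wiring):

import argparse

# Toy CLI: the "export" subcommand never registers --dso-path, so the
# parsed namespace has no such attribute at all.
parser = argparse.ArgumentParser(prog="toy")
subparsers = parser.add_subparsers(dest="verb", required=True)
subparsers.add_parser("export")

args = parser.parse_args(["export"])

try:
    _ = args.dso_path  # unsafe: the attribute was never created
except AttributeError as exc:
    print(f"unsafe access fails: {exc}")

# Safe: fall back to None when the subcommand didn't register the arg.
dso_path = getattr(args, "dso_path", None)
print(f"safe access returns: {dso_path!r}")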

File tree: build/builder.py, cli.py, generate.py

3 files changed: +91 −69 lines changed

(Paired − and + lines below that look identical differ only in trailing whitespace.)

build/builder.py
Lines changed: 20 additions & 14 deletions

@@ -103,15 +103,18 @@ def from_args(cls, args): # -> BuilderArgs:
             model_config.transformer_params_key or model_config.name.split("/")[-1]
         )
 
+        dso_path = getattr(args, "dso_path", None)
+        pte_path = getattr(args, "pte_path", None)
+
         is_chat_model = False
         if args.is_chat_model:
             is_chat_model = True
         else:
             for path in [
                 checkpoint_path,
                 checkpoint_dir,
-                args.dso_path,
-                args.pte_path,
+                dso_path,
+                pte_path,
                 args.gguf_path,
             ]:
                 if path is not None:
@@ -125,7 +128,10 @@ def from_args(cls, args): # -> BuilderArgs:
                     if "chat" in path_basename or "instruct" in path_basename:
                         is_chat_model = True
 
-        if args.output_pte_path and args.dtype.startswith("fast"):
+
+        output_pte_path = getattr(args, "output_pte_path", None)
+        output_dso_path = getattr(args, "output_dso_path", None)
+        if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
                 # (because fp16 is implemented as upcast to fp32 for several
@@ -144,11 +150,11 @@ def from_args(cls, args): # -> BuilderArgs:
             params_table=params_table,
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
-            dso_path=args.dso_path,
-            pte_path=args.pte_path,
+            dso_path=dso_path,
+            pte_path=pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(args.output_dso_path or args.output_pte_path),
+            setup_caches=(output_dso_path or output_pte_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
         )
@@ -355,27 +361,27 @@ def _maybe_init_distributed(
     builder_args: BuilderArgs,
 ) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
     """
-    Initialize distributed related setups if the user specified
+    Initialize distributed related setups if the user specified
    using distributed inference. If not, this is a no-op.
 
     Args:
         builder_args (:class:`BuilderArgs`):
             Command args for model building.
     Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
+        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
+            - The first element is an optional DeviceMesh object,
               which which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
+            - The second element is an optional ParallelDims object,
              which represents the parallel dimensions configuration.
     """
     if not builder_args.use_distributed:
         return None, None
     dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
+
+    world_mesh, parallel_dims = launch_distributed(dist_config)
+
     assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}"
-
+
     return world_mesh, parallel_dims
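In builder.py the fix hoists every optional path argument into a local with a None default at the top of `from_args`, so the rest of the constructor only touches plain locals. A condensed sketch of that shape (hypothetical `Builderish` stand-in, not the real BuilderArgs):

from argparse import Namespace
from dataclasses import dataclass
from typing import Optional

@dataclass
class Builderish:
    dso_path: Optional[str]
    pte_path: Optional[str]
    setup_caches: bool

    @classmethod
    def from_args(cls, args: Namespace) -> "Builderish":
        # Read optional attributes once, defaulting to None when the
        # subcommand never registered the corresponding flag.
        dso_path = getattr(args, "dso_path", None)
        pte_path = getattr(args, "pte_path", None)
        output_pte_path = getattr(args, "output_pte_path", None)
        output_dso_path = getattr(args, "output_dso_path", None)
        # Everything below this point uses plain locals, never args.<attr>.
        return cls(
            dso_path=dso_path,
            pte_path=pte_path,
            setup_caches=bool(output_dso_path or output_pte_path),
        )

# An entirely empty namespace no longer raises AttributeError.
print(Builderish.from_args(Namespace()))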
cli.py
Lines changed: 68 additions & 54 deletions

@@ -29,7 +29,7 @@
 INVENTORY_VERBS = ["download", "list", "remove", "where"]
 
 # Subcommands related to generating inference output based on user prompts
-GENERATION_VERBS = ["browser", "chat", "generate", "server"]
+GENERATION_VERBS = ["browser", "chat", "generate", "server"]
 
 # List of all supported subcommands in torchchat
 KNOWN_VERBS = GENERATION_VERBS + ["eval", "export"] + INVENTORY_VERBS
@@ -49,9 +49,6 @@ def check_args(args, verb: str) -> None:
 
 # Given a arg parser and a subcommand (verb), add the appropriate arguments
 # for that subcommand.
-#
-# Note the use of argparse.SUPPRESS to hide arguments from --help due to
-# legacy CLI arg parsing. See https://github.com/pytorch/torchchat/issues/932
 def add_arguments_for_verb(parser, verb: str) -> None:
     # Argument closure for inventory related subcommands
     if verb in INVENTORY_VERBS:
@@ -62,17 +59,17 @@ def add_arguments_for_verb(parser, verb: str) -> None:
     # Add argument groups for model specification (what base model to use)
     _add_model_specification_args(parser)
 
-    # Add argument groups for exported model path IO
-    _add_exported_input_path_args(parser, verb)
-    _add_export_output_path_args(parser, verb)
-
     # Add argument groups for model configuration (compilation, quant, etc)
     _add_model_config_args(parser, verb)
 
     # Add thematic argument groups based on the subcommand
-    if verb in ["browser", "chat", "generate", "server"]:
+    if verb in GENERATION_VERBS:
+        _add_exported_input_path_args(parser)
         _add_generation_args(parser, verb)
+    if verb == "export":
+        _add_export_output_path_args(parser)
     if verb == "eval":
+        _add_exported_input_path_args(parser)
         _add_evaluation_args(parser)
 
 # Add CLI Args related to downloading of model artifacts (if not already downloaded)
@@ -89,8 +86,13 @@ def add_arguments_for_verb(parser, verb: str) -> None:
 
 # Add CLI Args related to model specification (what base model to use)
 def _add_model_specification_args(parser) -> None:
-    model_specification_parser = parser.add_argument_group("Model Specification", "(REQUIRED) Specify the base model. Args are mutually exclusive.")
-    exclusive_parser = model_specification_parser.add_mutually_exclusive_group(required=True)
+    model_specification_parser = parser.add_argument_group(
+        "Model Specification",
+        "(REQUIRED) Specify the base model. Args are mutually exclusive.",
+    )
+    exclusive_parser = model_specification_parser.add_mutually_exclusive_group(
+        required=True
+    )
     exclusive_parser.add_argument(
         "model",
         type=str,
@@ -120,20 +122,26 @@ def _add_model_specification_args(parser) -> None:
         help=argparse.SUPPRESS,
     )
 
+
 # Add CLI Args related to model configuration (compilation, quant, etc)
+# Excludes compile args if subcommand is export
 def _add_model_config_args(parser, verb: str) -> None:
-    is_not_export = verb != "export"
-    model_config_parser = parser.add_argument_group("Model Configuration", "Specify model configurations")
-    model_config_parser.add_argument(
-        "--compile",
-        action="store_true",
-        help="Whether to compile the model with torch.compile" if is_not_export else argparse.SUPPRESS,
-    )
-    model_config_parser.add_argument(
-        "--compile-prefill",
-        action="store_true",
-        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times." if is_not_export else argparse.SUPPRESS,
+    model_config_parser = parser.add_argument_group(
+        "Model Configuration", "Specify model configurations"
     )
+
+    if verb != "export":
+        model_config_parser.add_argument(
+            "--compile",
+            action="store_true",
+            help="Whether to compile the model with torch.compile",
+        )
+        model_config_parser.add_argument(
+            "--compile-prefill",
+            action="store_true",
+            help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
+        )
+
     model_config_parser.add_argument(
         "--dtype",
         default="fast",
@@ -157,54 +165,55 @@ def _add_model_config_args(parser, verb: str) -> None:
         help="Hardware device to use. Options: cpu, cuda, mps",
     )
 
-# Add CLI Args representing output paths of exported model files
-def _add_export_output_path_args(parser, verb: str) -> None:
-    is_export = verb == "export"
 
+# Add CLI Args representing output paths of exported model files
+def _add_export_output_path_args(parser) -> None:
     output_path_parser = parser.add_argument_group(
-        "Export Output Path" if is_export else None,
-        "Specify the output path for the exported model files" if is_export else None,
+        "Export Output Path",
+        "Specify the output path for the exported model files",
     )
     exclusive_parser = output_path_parser.add_mutually_exclusive_group()
     exclusive_parser.add_argument(
         "--output-pte-path",
         type=str,
         default=None,
-        help="Output to the specified ExecuTorch .pte model file" if is_export else argparse.SUPPRESS,
+        help="Output to the specified ExecuTorch .pte model file",
     )
     exclusive_parser.add_argument(
         "--output-dso-path",
         type=str,
         default=None,
-        help="Output to the specified AOT Inductor .dso model file" if is_export else argparse.SUPPRESS,
+        help="Output to the specified AOT Inductor .dso model file",
     )
 
 
 # Add CLI Args representing user provided exported model files
-def _add_exported_input_path_args(parser, verb: str) -> None:
-    is_generation_verb = verb in GENERATION_VERBS
-
+def _add_exported_input_path_args(parser) -> None:
     exported_model_path_parser = parser.add_argument_group(
-        "Exported Model Path" if is_generation_verb else None,
-        "Specify the path of the exported model files to ingest" if is_generation_verb else None,
+        "Exported Model Path",
+        "Specify the path of the exported model files to ingest",
     )
     exclusive_parser = exported_model_path_parser.add_mutually_exclusive_group()
     exclusive_parser.add_argument(
         "--dso-path",
         type=Path,
         default=None,
-        help="Use the specified AOT Inductor .dso model file" if is_generation_verb else argparse.SUPPRESS,
+        help="Use the specified AOT Inductor .dso model file",
     )
     exclusive_parser.add_argument(
         "--pte-path",
         type=Path,
         default=None,
-        help="Use the specified ExecuTorch .pte model file" if is_generation_verb else argparse.SUPPRESS,
+        help="Use the specified ExecuTorch .pte model file",
     )
 
+
 # Add CLI Args related to JIT downloading of model artifacts
 def _add_jit_downloading_args(parser) -> None:
-    jit_downloading_parser = parser.add_argument_group("Model Downloading", "Specify args for model downloading (if model is not downloaded)",)
+    jit_downloading_parser = parser.add_argument_group(
+        "Model Downloading",
+        "Specify args for model downloading (if model is not downloaded)",
+    )
     jit_downloading_parser.add_argument(
         "--hf-token",
         type=str,
@@ -217,7 +226,8 @@ def _add_jit_downloading_args(parser) -> None:
         default=default_model_dir,
         help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
     )
-
+
+
 # Add CLI Args that are general to subcommand cli execution
 def _add_cli_metadata_args(parser) -> None:
     parser.add_argument(
@@ -270,16 +280,26 @@ def _configure_artifact_inventory_args(parser, verb: str) -> None:
 
 
 # Add CLI Args specific to user prompted generation
+# Include prompt and num_sample args when the subcommand is generate
 def _add_generation_args(parser, verb: str) -> None:
     generator_parser = parser.add_argument_group(
         "Generation", "Configs for generating output based on provided prompt"
     )
-    generator_parser.add_argument(
-        "--prompt",
-        type=str,
-        default="Hello, my name is",
-        help="Input prompt for manual output generation" if verb == "generate" else argparse.SUPPRESS,
-    )
+
+    if verb == "generate":
+        generator_parser.add_argument(
+            "--prompt",
+            type=str,
+            default="Hello, my name is",
+            help="Input prompt for manual output generation",
+        )
+        generator_parser.add_argument(
+            "--num-samples",
+            type=int,
+            default=1,
+            help="Number of samples",
+        )
+
     generator_parser.add_argument(
         "--chat",
         action="store_true",
@@ -292,12 +312,6 @@ def _add_generation_args(parser, verb: str) -> None:
         # help="Whether to use a web UI for an interactive chat session",
         help=argparse.SUPPRESS,
     )
-    generator_parser.add_argument(
-        "--num-samples",
-        type=int,
-        default=1,
-        help="Number of samples" if verb == "generate" else argparse.SUPPRESS,
-    )
     generator_parser.add_argument(
         "--max-new-tokens",
         type=int,
@@ -441,7 +455,7 @@ def arg_init(args):
     # if we specify dtype in quantization recipe, replicate it as args.dtype
     args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
 
-    if args.output_pte_path:
+    if getattr(args, "output_pte_path", None):
         if args.device not in ["cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
@@ -451,12 +465,12 @@ def arg_init(args):
     )
 
     if "mps" in args.device:
-        if args.compile or args.compile_prefill:
+        if hasattr(args, "compile") and hasattr(args, "compile_prefill"):
             print(
                 "Warning: compilation is not available with device MPS, ignoring option to engage compilation"
             )
-            args.compile = False
-            args.compile_prefill = False
+            vars(args)["compile"] = False
+            vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
         torch.manual_seed(args.seed)
generate.py
Lines changed: 3 additions & 1 deletion

@@ -103,8 +103,10 @@ def validate_build(
 
     @classmethod
     def from_args(cls, args):
+        dso_path = getattr(args, "dso_path", None)
+        pte_path = getattr(args, "pte_path", None)
         sequential_prefill = (
-            args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
+            args.sequential_prefill or bool(dso_path) or bool(pte_path)
         )
 
         return cls(
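The generate.py change is the same pattern in miniature: `bool(None)` is False, so absent paths fall through cleanly. A quick check with an illustrative namespace:

from argparse import Namespace

# Namespace from a verb that registered neither --dso-path nor --pte-path.
args = Namespace(sequential_prefill=False)

dso_path = getattr(args, "dso_path", None)
pte_path = getattr(args, "pte_path", None)
sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)
print(sequential_prefill)  # False, with no AttributeError raised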