2929INVENTORY_VERBS = ["download" , "list" , "remove" , "where" ]
3030
3131# Subcommands related to generating inference output based on user prompts
32- GENERATION_VERBS = ["browser" , "chat" , "generate" , "server" ]
32+ GENERATION_VERBS = ["browser" , "chat" , "generate" , "server" ]
3333
3434# List of all supported subcommands in torchchat
3535KNOWN_VERBS = GENERATION_VERBS + ["eval" , "export" ] + INVENTORY_VERBS
@@ -49,9 +49,6 @@ def check_args(args, verb: str) -> None:
4949
5050# Given a arg parser and a subcommand (verb), add the appropriate arguments
5151# for that subcommand.
52- #
53- # Note the use of argparse.SUPPRESS to hide arguments from --help due to
54- # legacy CLI arg parsing. See https://github.com/pytorch/torchchat/issues/932
5552def add_arguments_for_verb (parser , verb : str ) -> None :
5653 # Argument closure for inventory related subcommands
5754 if verb in INVENTORY_VERBS :
@@ -62,17 +59,17 @@ def add_arguments_for_verb(parser, verb: str) -> None:
6259 # Add argument groups for model specification (what base model to use)
6360 _add_model_specification_args (parser )
6461
65- # Add argument groups for exported model path IO
66- _add_exported_input_path_args (parser , verb )
67- _add_export_output_path_args (parser , verb )
68-
6962 # Add argument groups for model configuration (compilation, quant, etc)
7063 _add_model_config_args (parser , verb )
7164
7265 # Add thematic argument groups based on the subcommand
73- if verb in ["browser" , "chat" , "generate" , "server" ]:
66+ if verb in GENERATION_VERBS :
67+ _add_exported_input_path_args (parser )
7468 _add_generation_args (parser , verb )
69+ if verb == "export" :
70+ _add_export_output_path_args (parser )
7571 if verb == "eval" :
72+ _add_exported_input_path_args (parser )
7673 _add_evaluation_args (parser )
7774
7875 # Add CLI Args related to downloading of model artifacts (if not already downloaded)
@@ -89,8 +86,13 @@ def add_arguments_for_verb(parser, verb: str) -> None:
8986
9087# Add CLI Args related to model specification (what base model to use)
9188def _add_model_specification_args (parser ) -> None :
92- model_specification_parser = parser .add_argument_group ("Model Specification" , "(REQUIRED) Specify the base model. Args are mutually exclusive." )
93- exclusive_parser = model_specification_parser .add_mutually_exclusive_group (required = True )
89+ model_specification_parser = parser .add_argument_group (
90+ "Model Specification" ,
91+ "(REQUIRED) Specify the base model. Args are mutually exclusive." ,
92+ )
93+ exclusive_parser = model_specification_parser .add_mutually_exclusive_group (
94+ required = True
95+ )
9496 exclusive_parser .add_argument (
9597 "model" ,
9698 type = str ,
@@ -120,20 +122,26 @@ def _add_model_specification_args(parser) -> None:
120122 help = argparse .SUPPRESS ,
121123 )
122124
125+
123126# Add CLI Args related to model configuration (compilation, quant, etc)
127+ # Excludes compile args if subcommand is export
124128def _add_model_config_args (parser , verb : str ) -> None :
125- is_not_export = verb != "export"
126- model_config_parser = parser .add_argument_group ("Model Configuration" , "Specify model configurations" )
127- model_config_parser .add_argument (
128- "--compile" ,
129- action = "store_true" ,
130- help = "Whether to compile the model with torch.compile" if is_not_export else argparse .SUPPRESS ,
131- )
132- model_config_parser .add_argument (
133- "--compile-prefill" ,
134- action = "store_true" ,
135- help = "Whether to compile the prefill. Improves prefill perf, but has higher compile times." if is_not_export else argparse .SUPPRESS ,
129+ model_config_parser = parser .add_argument_group (
130+ "Model Configuration" , "Specify model configurations"
136131 )
132+
133+ if verb != "export" :
134+ model_config_parser .add_argument (
135+ "--compile" ,
136+ action = "store_true" ,
137+ help = "Whether to compile the model with torch.compile" ,
138+ )
139+ model_config_parser .add_argument (
140+ "--compile-prefill" ,
141+ action = "store_true" ,
142+ help = "Whether to compile the prefill. Improves prefill perf, but has higher compile times." ,
143+ )
144+
137145 model_config_parser .add_argument (
138146 "--dtype" ,
139147 default = "fast" ,
@@ -157,54 +165,55 @@ def _add_model_config_args(parser, verb: str) -> None:
157165 help = "Hardware device to use. Options: cpu, cuda, mps" ,
158166 )
159167
160- # Add CLI Args representing output paths of exported model files
161- def _add_export_output_path_args (parser , verb : str ) -> None :
162- is_export = verb == "export"
163168
169+ # Add CLI Args representing output paths of exported model files
170+ def _add_export_output_path_args (parser ) -> None :
164171 output_path_parser = parser .add_argument_group (
165- "Export Output Path" if is_export else None ,
166- "Specify the output path for the exported model files" if is_export else None ,
172+ "Export Output Path" ,
173+ "Specify the output path for the exported model files" ,
167174 )
168175 exclusive_parser = output_path_parser .add_mutually_exclusive_group ()
169176 exclusive_parser .add_argument (
170177 "--output-pte-path" ,
171178 type = str ,
172179 default = None ,
173- help = "Output to the specified ExecuTorch .pte model file" if is_export else argparse . SUPPRESS ,
180+ help = "Output to the specified ExecuTorch .pte model file" ,
174181 )
175182 exclusive_parser .add_argument (
176183 "--output-dso-path" ,
177184 type = str ,
178185 default = None ,
179- help = "Output to the specified AOT Inductor .dso model file" if is_export else argparse . SUPPRESS ,
186+ help = "Output to the specified AOT Inductor .dso model file" ,
180187 )
181188
182189
183190# Add CLI Args representing user provided exported model files
184- def _add_exported_input_path_args (parser , verb : str ) -> None :
185- is_generation_verb = verb in GENERATION_VERBS
186-
191+ def _add_exported_input_path_args (parser ) -> None :
187192 exported_model_path_parser = parser .add_argument_group (
188- "Exported Model Path" if is_generation_verb else None ,
189- "Specify the path of the exported model files to ingest" if is_generation_verb else None ,
193+ "Exported Model Path" ,
194+ "Specify the path of the exported model files to ingest" ,
190195 )
191196 exclusive_parser = exported_model_path_parser .add_mutually_exclusive_group ()
192197 exclusive_parser .add_argument (
193198 "--dso-path" ,
194199 type = Path ,
195200 default = None ,
196- help = "Use the specified AOT Inductor .dso model file" if is_generation_verb else argparse . SUPPRESS ,
201+ help = "Use the specified AOT Inductor .dso model file" ,
197202 )
198203 exclusive_parser .add_argument (
199204 "--pte-path" ,
200205 type = Path ,
201206 default = None ,
202- help = "Use the specified ExecuTorch .pte model file" if is_generation_verb else argparse . SUPPRESS ,
207+ help = "Use the specified ExecuTorch .pte model file" ,
203208 )
204209
210+
205211# Add CLI Args related to JIT downloading of model artifacts
206212def _add_jit_downloading_args (parser ) -> None :
207- jit_downloading_parser = parser .add_argument_group ("Model Downloading" , "Specify args for model downloading (if model is not downloaded)" ,)
213+ jit_downloading_parser = parser .add_argument_group (
214+ "Model Downloading" ,
215+ "Specify args for model downloading (if model is not downloaded)" ,
216+ )
208217 jit_downloading_parser .add_argument (
209218 "--hf-token" ,
210219 type = str ,
@@ -217,7 +226,8 @@ def _add_jit_downloading_args(parser) -> None:
217226 default = default_model_dir ,
218227 help = f"The directory to store downloaded model artifacts. Default: { default_model_dir } " ,
219228 )
220-
229+
230+
221231# Add CLI Args that are general to subcommand cli execution
222232def _add_cli_metadata_args (parser ) -> None :
223233 parser .add_argument (
@@ -270,16 +280,26 @@ def _configure_artifact_inventory_args(parser, verb: str) -> None:
270280
271281
272282# Add CLI Args specific to user prompted generation
283+ # Include prompt and num_sample args when the subcommand is generate
273284def _add_generation_args (parser , verb : str ) -> None :
274285 generator_parser = parser .add_argument_group (
275286 "Generation" , "Configs for generating output based on provided prompt"
276287 )
277- generator_parser .add_argument (
278- "--prompt" ,
279- type = str ,
280- default = "Hello, my name is" ,
281- help = "Input prompt for manual output generation" if verb == "generate" else argparse .SUPPRESS ,
282- )
288+
289+ if verb == "generate" :
290+ generator_parser .add_argument (
291+ "--prompt" ,
292+ type = str ,
293+ default = "Hello, my name is" ,
294+ help = "Input prompt for manual output generation" ,
295+ )
296+ generator_parser .add_argument (
297+ "--num-samples" ,
298+ type = int ,
299+ default = 1 ,
300+ help = "Number of samples" ,
301+ )
302+
283303 generator_parser .add_argument (
284304 "--chat" ,
285305 action = "store_true" ,
@@ -292,12 +312,6 @@ def _add_generation_args(parser, verb: str) -> None:
292312 # help="Whether to use a web UI for an interactive chat session",
293313 help = argparse .SUPPRESS ,
294314 )
295- generator_parser .add_argument (
296- "--num-samples" ,
297- type = int ,
298- default = 1 ,
299- help = "Number of samples" if verb == "generate" else argparse .SUPPRESS ,
300- )
301315 generator_parser .add_argument (
302316 "--max-new-tokens" ,
303317 type = int ,
@@ -441,7 +455,7 @@ def arg_init(args):
441455 # if we specify dtype in quantization recipe, replicate it as args.dtype
442456 args .dtype = args .quantize .get ("precision" , {}).get ("dtype" , args .dtype )
443457
444- if args . output_pte_path :
458+ if getattr ( args , " output_pte_path" , None ) :
445459 if args .device not in ["cpu" , "fast" ]:
446460 raise RuntimeError ("Device not supported by ExecuTorch" )
447461 args .device = "cpu"
@@ -451,12 +465,12 @@ def arg_init(args):
451465 )
452466
453467 if "mps" in args .device :
454- if args . compile or args . compile_prefill :
468+ if hasattr ( args , " compile" ) and hasattr ( args , " compile_prefill" ) :
455469 print (
456470 "Warning: compilation is not available with device MPS, ignoring option to engage compilation"
457471 )
458- args . compile = False
459- args . compile_prefill = False
472+ vars ( args )[ " compile" ] = False
473+ vars ( args )[ " compile_prefill" ] = False
460474
461475 if hasattr (args , "seed" ) and args .seed :
462476 torch .manual_seed (args .seed )
0 commit comments