@@ -214,6 +214,22 @@ def get_parser_config() -> ArgumentParser:
214214 action = BooleanOptionalAction ,
215215 help = "displays the task as well" ,
216216 )
217+ parser .add_argument (
218+ "-c" ,
219+ "--cached" ,
220+ default = True ,
221+ action = BooleanOptionalAction ,
222+ help = "uses cached configuration, only available for some of them, "
223+ "mostly for unit test purposes" ,
224+ )
225+ parser .add_argument (
226+ "--mop" ,
227+ metavar = "KEY=VALUE" ,
228+ nargs = "*" ,
229+ help = "Additional model options, use to change some parameters of the model, "
230+ "example: --mop attn_implementation=eager" ,
231+ action = _ParseDict ,
232+ )
217233 return parser
218234
219235
@@ -222,7 +238,11 @@ def _cmd_config(argv: List[Any]):
222238
223239 parser = get_parser_config ()
224240 args = parser .parse_args (argv [1 :])
225- print (get_pretrained_config (args .mid ))
241+ conf = get_pretrained_config (args .mid , ** (args .mop or {}))
242+ print (conf )
243+ for k , v in sorted (conf .__dict__ .items ()):
244+ if "_implementation" in k :
245+ print (f"config.{ k } ={ v !r} " )
226246 if args .task :
227247 print ("------" )
228248 print (f"task: { task_from_id (args .mid )} " )
@@ -238,6 +258,19 @@ def __call__(self, parser, namespace, values, option_string=None):
238258 key = split_items [0 ].strip () # we remove blanks around keys, as is logical
239259 value = split_items [1 ]
240260
261+ if value in ("True" , "true" , "False" , "false" ):
262+ d [key ] = bool (value )
263+ continue
264+ try :
265+ d [key ] = int (value )
266+ continue
267+ except (TypeError , ValueError ):
268+ pass
269+ try :
270+ d [key ] = float (value )
271+ continue
272+ except (TypeError , ValueError ):
273+ pass
241274 d [key ] = value
242275
243276 setattr (namespace , self .dest , d )
@@ -321,6 +354,14 @@ def get_parser_validate() -> ArgumentParser:
321354 "inputs use to export, example: --iop cls_cache=SlidingWindowCache" ,
322355 action = _ParseDict ,
323356 )
357+ parser .add_argument (
358+ "--mop" ,
359+ metavar = "KEY=VALUE" ,
360+ nargs = "*" ,
361+ help = "Additional model options, use to change some parameters of the model, "
362+ "example: --mop attn_implementation=eager" ,
363+ action = _ParseDict ,
364+ )
324365 return parser
325366
326367
@@ -371,6 +412,7 @@ def _cmd_validate(argv: List[Any]):
371412 drop_inputs = None if not args .drop else args .drop .split ("," ),
372413 ortfusiontype = args .ortfusiontype ,
373414 input_options = args .iop ,
415+ model_options = args .mop ,
374416 )
375417 print ("" )
376418 print ("-- summary --" )
0 commit comments