@@ -306,7 +306,7 @@ def __call__(self, parser, namespace, values, option_string=None):
306306 value = split_items [1 ]
307307
308308 if value in ("True" , "true" , "False" , "false" ):
309- d [key ] = bool ( value )
309+ d [key ] = value in ( "True" , "true" )
310310 continue
311311 try :
312312 d [key ] = int (value )
@@ -323,6 +323,54 @@ def __call__(self, parser, namespace, values, option_string=None):
323323 setattr (namespace , self .dest , d )
324324
325325
326+ class _BoolOrParseDictPatch (argparse .Action ):
327+ def __call__ (self , parser , namespace , values , option_string = None ):
328+
329+ if not values :
330+ return
331+ if len (values ) == 1 and values [0 ] in (
332+ "True" ,
333+ "False" ,
334+ "true" ,
335+ "false" ,
336+ "0" ,
337+ "1" ,
338+ 0 ,
339+ 1 ,
340+ ):
341+ setattr (namespace , self .dest , values [0 ] in ("True" , "true" , 1 , "1" ))
342+ return
343+ d = getattr (namespace , self .dest ) or {}
344+ if not isinstance (d , dict ):
345+ d = {
346+ "patch_sympy" : d ,
347+ "patch_torch" : d ,
348+ "patch_transformers" : d ,
349+ "patch_diffusers" : d ,
350+ }
351+ for item in values :
352+ split_items = item .split ("=" , 1 )
353+ key = split_items [0 ].strip () # we remove blanks around keys, as is logical
354+ value = split_items [1 ]
355+
356+ if value in ("True" , "true" , "False" , "false" ):
357+ d [key ] = value in ("True" , "true" )
358+ continue
359+ try :
360+ d [key ] = int (value )
361+ continue
362+ except (TypeError , ValueError ):
363+ pass
364+ try :
365+ d [key ] = float (value )
366+ continue
367+ except (TypeError , ValueError ):
368+ pass
369+ d [key ] = _parse_json (value )
370+
371+ setattr (namespace , self .dest , d )
372+
373+
326374def get_parser_validate () -> ArgumentParser :
327375 parser = ArgumentParser (
328376 prog = "validate" ,
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
383431 parser .add_argument (
384432 "--patch" ,
385433 default = True ,
386- action = BooleanOptionalAction ,
387- help = "Applies patches before exporting." ,
434+ action = _BoolOrParseDictPatch ,
435+ nargs = "*" ,
436+ help = "Applies patches before exporting, it can be a boolean "
437+ "to enable to disable the patches or be more finetuned. It is possible to "
438+ "disable patch for torch by adding "
439+ '--patch "patch_sympy=False" --patch "patch_torch=False", '
440+ "default is True." ,
388441 )
389442 parser .add_argument (
390443 "--rewrite" ,
0 commit comments