@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264 return new_cfg
265265
266266
267- def _preprocess_model_id (model_id , subfolder ):
267+ def _preprocess_model_id (
268+ model_id : str , subfolder : str , same_as_pretrained : bool , use_pretrained : bool
269+ ) -> Tuple [str , str , bool , bool ]:
268270 if subfolder or "//" not in model_id :
269- return model_id , subfolder
271+ return model_id , subfolder , same_as_pretrained , use_pretrained
270272 spl = model_id .split ("//" )
273+ if spl [- 1 ] == "pretrained" :
274+ return _preprocess_model_id ("//" .join (spl [:- 1 ]), "" , True , True )
271275 if spl [- 1 ] in {"transformer" , "vae" }:
272276 # known subfolder
273- return "//" .join (spl [:- 1 ]), spl [- 1 ]
274- return model_id , subfolder
277+ return "//" .join (spl [:- 1 ]), spl [- 1 ], same_as_pretrained , use_pretrained
278+ return model_id , subfolder , same_as_pretrained , use_pretrained
275279
276280
277281def validate_model (
@@ -384,7 +388,12 @@ def validate_model(
384388 if ``runtime == 'ref'``,
385389 ``orteval10`` increases the verbosity.
386390 """
387- model_id , subfolder = _preprocess_model_id (model_id , subfolder )
391+ model_id , subfolder , same_as_pretrained , use_pretrained = _preprocess_model_id (
392+ model_id ,
393+ subfolder ,
394+ same_as_pretrained = same_as_pretrained ,
395+ use_pretrained = use_pretrained ,
396+ )
388397 if isinstance (patch , bool ):
389398 patch_kwargs = (
390399 dict (patch_transformers = True , patch_diffusers = True , patch = True )
0 commit comments