@@ -25,6 +25,20 @@ def _code_needing_rewriting(model: Any) -> Any:
2525 return code_needing_rewriting (model )
2626
2727
28+ def _preprocess_model_id (
29+ model_id : str , subfolder : Optional [str ], same_as_pretrained : bool , use_pretrained : bool
30+ ) -> Tuple [str , Optional [str ], bool , bool ]:
31+ if subfolder or "//" not in model_id :
32+ return model_id , subfolder , same_as_pretrained , use_pretrained
33+ spl = model_id .split ("//" )
34+ if spl [- 1 ] == "pretrained" :
35+ return _preprocess_model_id ("//" .join (spl [:- 1 ]), "" , True , True )
36+ if spl [- 1 ] in {"transformer" , "vae" }:
37+ # known subfolder
38+ return "//" .join (spl [:- 1 ]), spl [- 1 ], same_as_pretrained , use_pretrained
39+ return model_id , subfolder , same_as_pretrained , use_pretrained
40+
41+
2842def get_untrained_model_with_inputs (
2943 model_id : str ,
3044 config : Optional [Any ] = None ,
@@ -85,8 +99,16 @@ def get_untrained_model_with_inputs(
8599 f"model_id={ model_id !r} , preinstalled model is only available "
86100 f"if use_only_preinstalled is False."
87101 )
102+ model_id , subfolder , same_as_pretrained , use_pretrained = _preprocess_model_id (
103+ model_id ,
104+ subfolder ,
105+ same_as_pretrained = same_as_pretrained ,
106+ use_pretrained = use_pretrained ,
107+ )
88108 if verbose :
89- print (f"[get_untrained_model_with_inputs] model_id={ model_id !r} " )
109+ print (
110+ f"[get_untrained_model_with_inputs] model_id={ model_id !r} , subfolder={ subfolder !r} "
111+ )
90112 if use_preinstalled :
91113 print (f"[get_untrained_model_with_inputs] use preinstalled { model_id !r} " )
92114 if config is None :
0 commit comments