@@ -114,29 +114,56 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
114114
115115 print_warn_for_wrong_dependencies_versions ([("ultralytics" , ">=" , "8.3.0" )], ask_to_continue = True )
116116
117- model = torch .load (os .path .join (model_path , filename ), map_location = "cpu" )
117+ elif "yolov12" in model_type :
118+ try :
119+ import torch
120+ import ultralytics
121+
122+ except ImportError :
123+ raise RuntimeError (
124+ "The ultralytics python package is required to deploy yolov12"
125+ " models. Please install it from `https://github.com/sunsmarterjie/yolov12`"
126+ )
127+
128+ print (
129+ "\n !!! ATTENTION !!!\n "
130+ "Model must be trained and uploaded using ultralytics from https://github.com/sunsmarterjie/yolov12\n "
131+ "or through the Roboflow platform\n "
132+ "!!! ATTENTION !!!\n "
133+ )
134+
135+ print_warn_for_wrong_dependencies_versions ([("ultralytics" , "==" , "8.3.63" )], ask_to_continue = True )
118136
119- if isinstance (model ["model" ].names , list ):
120- class_names = model ["model" ].names
137+ model = torch .load (os .path .join (model_path , filename ), weights_only = False )
138+
139+ model_instance = model ["model" ] if "model" in model and model ["model" ] is not None else model ["ema" ]
140+
141+ if isinstance (model_instance .names , list ):
142+ class_names = model_instance .names
121143 else :
122144 class_names = []
123- for i , val in enumerate (model [ "model" ] .names ):
124- class_names .append ((val , model [ "model" ] .names [val ]))
145+ for i , val in enumerate (model_instance .names ):
146+ class_names .append ((val , model_instance .names [val ]))
125147 class_names .sort (key = lambda x : x [0 ])
126148 class_names = [x [1 ] for x in class_names ]
127149
128- if "yolov8" in model_type or "yolov10" in model_type or "yolov11" in model_type :
150+ if "yolov8" in model_type or "yolov10" in model_type or "yolov11" in model_type or "yolov12" in model_type :
129151 # try except for backwards compatibility with older versions of ultralytics
130- if "-cls" in model_type or model_type .startswith ("yolov10" ) or model_type .startswith ("yolov11" ):
131- nc = model ["model" ].yaml ["nc" ]
152+ if (
153+ "-cls" in model_type
154+ or model_type .startswith ("yolov10" )
155+ or model_type .startswith ("yolov11" )
156+ or model_type .startswith ("yolov12" )
157+ ):
158+ nc = model_instance .yaml ["nc" ]
132159 args = model ["train_args" ]
133160 else :
134- nc = model [ "model" ] .nc
135- args = model [ "model" ] .args
161+ nc = model_instance .nc
162+ args = model_instance .args
136163 try :
137164 model_artifacts = {
138165 "names" : class_names ,
139- "yaml" : model [ "model" ] .yaml ,
166+ "yaml" : model_instance .yaml ,
140167 "nc" : nc ,
141168 "args" : {k : val for k , val in args .items () if ((k == "model" ) or (k == "imgsz" ) or (k == "batch" ))},
142169 "ultralytics_version" : ultralytics .__version__ ,
@@ -145,7 +172,7 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
145172 except Exception :
146173 model_artifacts = {
147174 "names" : class_names ,
148- "yaml" : model [ "model" ] .yaml ,
175+ "yaml" : model_instance .yaml ,
149176 "nc" : nc ,
150177 "args" : {
151178 k : val for k , val in args .__dict__ .items () if ((k == "model" ) or (k == "imgsz" ) or (k == "batch" ))
@@ -161,20 +188,20 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
161188
162189 model_artifacts = {
163190 "names" : class_names ,
164- "nc" : model [ "model" ] .nc ,
191+ "nc" : model_instance .nc ,
165192 "args" : {
166193 "imgsz" : opts ["imgsz" ] if "imgsz" in opts else opts ["img_size" ],
167194 "batch" : opts ["batch_size" ],
168195 },
169196 "model_type" : model_type ,
170197 }
171- if hasattr (model [ "model" ] , "yaml" ):
172- model_artifacts ["yaml" ] = model [ "model" ] .yaml
198+ if hasattr (model_instance , "yaml" ):
199+ model_artifacts ["yaml" ] = model_instance .yaml
173200
174201 with open (os .path .join (model_path , "model_artifacts.json" ), "w" ) as fp :
175202 json .dump (model_artifacts , fp )
176203
177- torch .save (model [ "model" ] .state_dict (), os .path .join (model_path , "state_dict.pt" ))
204+ torch .save (model_instance .state_dict (), os .path .join (model_path , "state_dict.pt" ))
178205
179206 list_files = [
180207 "results.csv" ,
@@ -300,6 +327,7 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
300327 f"please create a class_names.txt file in the model path with the class names in new lines in the order of the classes in the model.\n "
301328 )
302329
330+
303331def maybe_prepend_dummy_class (class_name_file : str ):
304332 with open (class_name_file , "r" ) as f :
305333 class_names = f .readlines ()
0 commit comments