Skip to content

Commit 17bf481

Browse files
committed
extract the yolov12 state_dict before upload
1 parent f40a653 commit 17bf481

File tree

1 file changed

+22
-56
lines changed

1 file changed

+22
-56
lines changed

roboflow/util/model_processor.py

Lines changed: 22 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ def _get_processor_function(model_type: str) -> Callable:
5757
if "yolonas" in model_type:
5858
return _process_yolonas
5959

60-
if "yolov12" in model_type:
61-
return _process_yolov12
62-
6360
return _process_yolo
6461

6562

@@ -110,29 +107,38 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
110107

111108
print_warn_for_wrong_dependencies_versions([("ultralytics", ">=", "8.3.0")], ask_to_continue=True)
112109

110+
111+
elif "yolov12" in model_type:
112+
print(
113+
"Note: Model must be trained using ultralytics from https://github.com/sunsmarterjie/yolov12 "
114+
"or through the Roboflow platform"
115+
)
116+
113117
model = torch.load(os.path.join(model_path, filename))
114118

115-
if isinstance(model["model"].names, list):
116-
class_names = model["model"].names
119+
model_instance = model["model"] if "model" in model and model["model"] is not None else model["ema"]
120+
121+
if isinstance(model_instance.names, list):
122+
class_names = model_instance.names
117123
else:
118124
class_names = []
119-
for i, val in enumerate(model["model"].names):
120-
class_names.append((val, model["model"].names[val]))
125+
for i, val in enumerate(model_instance.names):
126+
class_names.append((val, model_instance.names[val]))
121127
class_names.sort(key=lambda x: x[0])
122128
class_names = [x[1] for x in class_names]
123129

124130
if "yolov8" in model_type or "yolov10" in model_type or "yolov11" in model_type:
125131
# try except for backwards compatibility with older versions of ultralytics
126132
if "-cls" in model_type or model_type.startswith("yolov10") or model_type.startswith("yolov11"):
127-
nc = model["model"].yaml["nc"]
133+
nc = model_instance.yaml["nc"]
128134
args = model["train_args"]
129135
else:
130-
nc = model["model"].nc
131-
args = model["model"].args
136+
nc = model_instance.nc
137+
args = model_instance.args
132138
try:
133139
model_artifacts = {
134140
"names": class_names,
135-
"yaml": model["model"].yaml,
141+
"yaml": model_instance.yaml,
136142
"nc": nc,
137143
"args": {k: val for k, val in args.items() if ((k == "model") or (k == "imgsz") or (k == "batch"))},
138144
"ultralytics_version": ultralytics.__version__,
@@ -141,7 +147,7 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
141147
except Exception:
142148
model_artifacts = {
143149
"names": class_names,
144-
"yaml": model["model"].yaml,
150+
"yaml": model_instance.yaml,
145151
"nc": nc,
146152
"args": {
147153
k: val for k, val in args.__dict__.items() if ((k == "model") or (k == "imgsz") or (k == "batch"))
@@ -157,20 +163,20 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
157163

158164
model_artifacts = {
159165
"names": class_names,
160-
"nc": model["model"].nc,
166+
"nc": model_instance.nc,
161167
"args": {
162168
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
163169
"batch": opts["batch_size"],
164170
},
165171
"model_type": model_type,
166172
}
167-
if hasattr(model["model"], "yaml"):
168-
model_artifacts["yaml"] = model["model"].yaml
173+
if hasattr(model_instance, "yaml"):
174+
model_artifacts["yaml"] = model_instance.yaml
169175

170176
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
171177
json.dump(model_artifacts, fp)
172178

173-
torch.save(model["model"].state_dict(), os.path.join(model_path, "state_dict.pt"))
179+
torch.save(model_instance.state_dict(), os.path.join(model_path, "state_dict.pt"))
174180

175181
list_files = [
176182
"results.csv",
@@ -196,46 +202,6 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
196202
return zip_file_name
197203

198204

199-
def _process_yolov12(model_type: str, model_path: str, filename: str) -> str:
200-
# For YOLOv12, since it uses a special Ultralytics version,
201-
# state dict extraction and model artifacts are handled during model conversion
202-
203-
print(
204-
"Note: Model must be trained using ultralytics from https://github.com/sunsmarterjie/yolov12 "
205-
"or through the Roboflow platform"
206-
)
207-
208-
# Check if model_path exists
209-
if not os.path.exists(model_path):
210-
raise FileNotFoundError(f"Model path {model_path} does not exist.")
211-
212-
# Find any .pt file in model path
213-
model_files = os.listdir(model_path)
214-
pt_file = next((f for f in model_files if f.endswith(".pt")), None)
215-
216-
if pt_file is None:
217-
raise RuntimeError("No .pt model file found in the provided path")
218-
219-
# Copy the .pt file to weights.pt if not already named weights.pt
220-
if pt_file != "weights.pt":
221-
shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt"))
222-
223-
required_files = ["weights.pt"]
224-
225-
optional_files = ["results.csv", "results.png", "model_artifacts.json"]
226-
227-
zip_file_name = "roboflow_deploy.zip"
228-
with zipfile.ZipFile(os.path.join(model_path, zip_file_name), "w") as zipMe:
229-
for file in required_files:
230-
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
231-
232-
for file in optional_files:
233-
if os.path.exists(os.path.join(model_path, file)):
234-
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
235-
236-
return zip_file_name
237-
238-
239205
def _process_huggingface(
240206
model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
241207
) -> str:

0 commit comments

Comments
 (0)