Skip to content

Commit 8e3c9a1

Browse files
authored
Merge pull request #370 from roboflow/lean/fix-yolov12-upload
Extract the YOLOv12 state_dict before upload
2 parents b16584c + 76c33e7 commit 8e3c9a1

File tree

2 files changed

+44
-60
lines changed

2 files changed

+44
-60
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.1.58"
18+
__version__ = "1.1.60"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/util/model_processor.py

Lines changed: 43 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ def _get_processor_function(model_type: str) -> Callable:
5757
if "yolonas" in model_type:
5858
return _process_yolonas
5959

60-
if "yolov12" in model_type:
61-
return _process_yolov12
62-
6360
return _process_yolo
6461

6562

@@ -110,29 +107,56 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
110107

111108
print_warn_for_wrong_dependencies_versions([("ultralytics", ">=", "8.3.0")], ask_to_continue=True)
112109

113-
model = torch.load(os.path.join(model_path, filename))
110+
elif "yolov12" in model_type:
111+
try:
112+
import torch
113+
import ultralytics
114114

115-
if isinstance(model["model"].names, list):
116-
class_names = model["model"].names
115+
except ImportError:
116+
raise RuntimeError(
117+
"The ultralytics python package is required to deploy yolov12"
118+
" models. Please install it from `https://github.com/sunsmarterjie/yolov12`"
119+
)
120+
121+
print(
122+
"\n!!! ATTENTION !!!\n"
123+
"Model must be trained and uploaded using ultralytics from https://github.com/sunsmarterjie/yolov12\n"
124+
"or through the Roboflow platform\n"
125+
"!!! ATTENTION !!!\n"
126+
)
127+
128+
print_warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.3.63")], ask_to_continue=True)
129+
130+
model = torch.load(os.path.join(model_path, filename), weights_only=False)
131+
132+
model_instance = model["model"] if "model" in model and model["model"] is not None else model["ema"]
133+
134+
if isinstance(model_instance.names, list):
135+
class_names = model_instance.names
117136
else:
118137
class_names = []
119-
for i, val in enumerate(model["model"].names):
120-
class_names.append((val, model["model"].names[val]))
138+
for i, val in enumerate(model_instance.names):
139+
class_names.append((val, model_instance.names[val]))
121140
class_names.sort(key=lambda x: x[0])
122141
class_names = [x[1] for x in class_names]
123142

124-
if "yolov8" in model_type or "yolov10" in model_type or "yolov11" in model_type:
143+
if "yolov8" in model_type or "yolov10" in model_type or "yolov11" in model_type or "yolov12" in model_type:
125144
# try except for backwards compatibility with older versions of ultralytics
126-
if "-cls" in model_type or model_type.startswith("yolov10") or model_type.startswith("yolov11"):
127-
nc = model["model"].yaml["nc"]
145+
if (
146+
"-cls" in model_type
147+
or model_type.startswith("yolov10")
148+
or model_type.startswith("yolov11")
149+
or model_type.startswith("yolov12")
150+
):
151+
nc = model_instance.yaml["nc"]
128152
args = model["train_args"]
129153
else:
130-
nc = model["model"].nc
131-
args = model["model"].args
154+
nc = model_instance.nc
155+
args = model_instance.args
132156
try:
133157
model_artifacts = {
134158
"names": class_names,
135-
"yaml": model["model"].yaml,
159+
"yaml": model_instance.yaml,
136160
"nc": nc,
137161
"args": {k: val for k, val in args.items() if ((k == "model") or (k == "imgsz") or (k == "batch"))},
138162
"ultralytics_version": ultralytics.__version__,
@@ -141,7 +165,7 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
141165
except Exception:
142166
model_artifacts = {
143167
"names": class_names,
144-
"yaml": model["model"].yaml,
168+
"yaml": model_instance.yaml,
145169
"nc": nc,
146170
"args": {
147171
k: val for k, val in args.__dict__.items() if ((k == "model") or (k == "imgsz") or (k == "batch"))
@@ -157,20 +181,20 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
157181

158182
model_artifacts = {
159183
"names": class_names,
160-
"nc": model["model"].nc,
184+
"nc": model_instance.nc,
161185
"args": {
162186
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
163187
"batch": opts["batch_size"],
164188
},
165189
"model_type": model_type,
166190
}
167-
if hasattr(model["model"], "yaml"):
168-
model_artifacts["yaml"] = model["model"].yaml
191+
if hasattr(model_instance, "yaml"):
192+
model_artifacts["yaml"] = model_instance.yaml
169193

170194
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
171195
json.dump(model_artifacts, fp)
172196

173-
torch.save(model["model"].state_dict(), os.path.join(model_path, "state_dict.pt"))
197+
torch.save(model_instance.state_dict(), os.path.join(model_path, "state_dict.pt"))
174198

175199
list_files = [
176200
"results.csv",
@@ -196,46 +220,6 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
196220
return zip_file_name
197221

198222

199-
def _process_yolov12(model_type: str, model_path: str, filename: str) -> str:
200-
# For YOLOv12, since it uses a special Ultralytics version,
201-
# state dict extraction and model artifacts are handled during model conversion
202-
203-
print(
204-
"Note: Model must be trained using ultralytics from https://github.com/sunsmarterjie/yolov12 "
205-
"or through the Roboflow platform"
206-
)
207-
208-
# Check if model_path exists
209-
if not os.path.exists(model_path):
210-
raise FileNotFoundError(f"Model path {model_path} does not exist.")
211-
212-
# Find any .pt file in model path
213-
model_files = os.listdir(model_path)
214-
pt_file = next((f for f in model_files if f.endswith(".pt")), None)
215-
216-
if pt_file is None:
217-
raise RuntimeError("No .pt model file found in the provided path")
218-
219-
# Copy the .pt file to weights.pt if not already named weights.pt
220-
if pt_file != "weights.pt":
221-
shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt"))
222-
223-
required_files = ["weights.pt"]
224-
225-
optional_files = ["results.csv", "results.png", "model_artifacts.json"]
226-
227-
zip_file_name = "roboflow_deploy.zip"
228-
with zipfile.ZipFile(os.path.join(model_path, zip_file_name), "w") as zipMe:
229-
for file in required_files:
230-
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
231-
232-
for file in optional_files:
233-
if os.path.exists(os.path.join(model_path, file)):
234-
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
235-
236-
return zip_file_name
237-
238-
239223
def _process_huggingface(
240224
model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
241225
) -> str:

0 commit comments

Comments
 (0)