Skip to content

Commit 6d767a2

Browse files
committed
add custom weights filename for yolonas
1 parent 0494f42 commit 6d767a2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

roboflow/core/version.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
544544

545545
self.upload_zip(model_type, model_path)
546546

547-
def deploy_yolonas(self, model_type: str, model_path: str) -> None:
547+
def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
548548
try:
549549
import torch
550550
except ImportError:
@@ -553,7 +553,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
553553
" Please install it with `pip install torch`"
554554
)
555555

556-
model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
556+
model = torch.load(os.path.join(model_path, filename), map_location="cpu")
557557
class_names = model["processing_params"]["class_names"]
558558

559559
opt_path = os.path.join(model_path, "opt.yaml")
@@ -586,7 +586,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
586586
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
587587
json.dump(model_artifacts, fp)
588588

589-
shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))
589+
shutil.copy(os.path.join(model_path, filename), os.path.join(model_path, "state_dict.pt"))
590590

591591
list_files = [
592592
"results.json",
@@ -604,7 +604,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
604604
compress_type=zipfile.ZIP_DEFLATED,
605605
)
606606
else:
607-
if file in ["model_artifacts.json", "best.pt"]:
607+
if file in ["model_artifacts.json", filename]:
608608
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
609609

610610
self.upload_zip(model_type, model_path)

0 commit comments

Comments
 (0)