|
1 | 1 | import copy |
2 | 2 | import json |
3 | 3 | import os |
| 4 | +import shutil |
4 | 5 | import sys |
5 | 6 | import time |
6 | 7 | import zipfile |
@@ -425,10 +426,14 @@ def deploy(self, model_type: str, model_path: str) -> None: |
425 | 426 | model_path (str): File path to model weights to be uploaded |
426 | 427 | """ |
427 | 428 |
|
428 | | - supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9"] |
| 429 | + supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"] |
429 | 430 |
|
430 | 431 | if not any(supported_model in model_type for supported_model in supported_models): |
431 | 432 | raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}")) |
| 433 | + |
| 434 | + if "yolonas" in model_type: |
| 435 | + self.deploy_yolonas(model_type, model_path) |
| 436 | + return |
432 | 437 |
|
433 | 438 | if "yolov8" in model_type: |
434 | 439 | try: |
@@ -535,6 +540,61 @@ def deploy(self, model_type: str, model_path: str) -> None: |
535 | 540 | if file in ["model_artifacts.json", "state_dict.pt"]: |
536 | 541 | raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path.")) |
537 | 542 |
|
| 543 | + self.upload_zip(model_type, model_path) |
| 544 | + |
| 545 | + def deploy_yolonas(self, model_type: str, model_path: str) -> None: |
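| | + """Deploy a trained YOLO-NAS model by packaging its weights and metadata into a zip for upload. |
| | + |
| | + Args: |
| | + model_type (str): The model type, e.g. "yolonas" |
| | + model_path (str): Path to the YOLO-NAS training output directory containing weights/best.pt |
| | + """ |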
| 546 | + try: |
| 547 | + import torch |
| 548 | + except ImportError: |
| 549 | + raise ImportError( |
| 550 | + "The torch python package is required to deploy yolonas models." |
| 551 | + " Please install it with `pip install torch`" |
| 552 | + ) |
| 553 | + |
| 554 | + model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu") |
| 555 | + class_names = model["processing_params"]["class_names"] |
| 556 | + |
| 557 | + with open(os.path.join(model_path, "opt.yaml"), "r") as stream: |
| 558 | + opts = yaml.safe_load(stream) |
| 559 | + |
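| | + # model_artifacts.json records the class names and key training settings included in the deploy zip |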
| 560 | + model_artifacts = { |
| 561 | + "names": class_names, |
| 562 | + "nc": len(class_names), |
| 563 | + "args": { |
| 564 | + "imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"], |
| 565 | + "batch": opts["batch_size"], |
| 566 | + }, |
| 567 | + "model_type": model_type, |
| 568 | + } |
| 569 | + |
| 570 | + with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp: |
| 571 | + json.dump(model_artifacts, fp) |
| 572 | + |
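| | + # Copy the trained weights to state_dict.pt, the filename expected in the deploy zip |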
| 573 | + shutil.copy(os.path.join(model_path, "weights/best.pt"), |
| 574 | + os.path.join(model_path, "state_dict.pt")) |
| 575 | + |
| 576 | + list_files = [ |
| 577 | + "results.json", |
| 578 | + "results.png", |
| 579 | + "model_artifacts.json", |
| 580 | + "state_dict.pt", |
| 581 | + ] |
| 582 | + |
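| | + # Bundle the artifacts into roboflow_deploy.zip; the results files are optional, the rest are required |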
| 583 | + with zipfile.ZipFile(os.path.join(model_path, "roboflow_deploy.zip"), "w") as zipMe: |
| 584 | + for file in list_files: |
| 585 | + if os.path.exists(os.path.join(model_path, file)): |
| 586 | + zipMe.write( |
| 587 | + os.path.join(model_path, file), |
| 588 | + arcname=file, |
| 589 | + compress_type=zipfile.ZIP_DEFLATED, |
| 590 | + ) |
| 591 | + else: |
| 592 | + if file in ["model_artifacts.json", "state_dict.pt"]: |
| 593 | + raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path.")) |
| 594 | + |
| 595 | + self.upload_zip(model_type, model_path) |
| 596 | + |
| 597 | + def upload_zip(self, model_type: str, model_path: str) -> None: |
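| | + """Upload the prepared roboflow_deploy.zip in model_path to the Roboflow API for this version.""" |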
538 | 598 | res = requests.get( |
539 | 599 | f"{API_URL}/{self.workspace}/{self.project}/{self.version}" |
540 | 600 | f"/uploadModel?api_key={self.__api_key}&modelType={model_type}&nocache=true" |
|