Skip to content

Commit 0290064

Browse files
committed
bye version.get_pt_file() hello model.download()
1 parent 32a220b commit 0290064

File tree

2 files changed

+36
-28
lines changed

2 files changed

+36
-28
lines changed

roboflow/core/version.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -433,32 +433,6 @@ def live_plot(epochs, mAP, loss, title=""):
433433
# return the model object
434434
return self.model
435435

436-
def get_pt_weights(self, location="."):
437-
workspace, project, *_ = self.id.rsplit("/")
438-
439-
# get pt url
440-
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
441-
442-
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
443-
444-
r.raise_for_status()
445-
446-
pt_weights_url = r.json()["weightsUrl"]
447-
448-
def bar_progress(current, total, width=80):
449-
progress_message = (
450-
"Downloading weights to "
451-
+ location
452-
+ "/weights.pt"
453-
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
454-
)
455-
sys.stdout.write("\r" + progress_message)
456-
sys.stdout.flush()
457-
458-
wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress)
459-
460-
return
461-
462436
# @warn_for_wrong_dependencies_versions([("ultralytics", "<=", "8.0.20")])
463437
def deploy(self, model_type: str, model_path: str) -> None:
464438
"""Uploads provided weights file to Roboflow

roboflow/models/object_detection.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io
44
import json
55
import os
6+
import sys
67
import random
78
import urllib
89
from pathlib import Path
@@ -11,17 +12,20 @@
1112
import matplotlib.pyplot as plt
1213
import numpy as np
1314
import requests
15+
import wget
1416
from PIL import Image
1517

16-
from roboflow.config import OBJECT_DETECTION_MODEL
18+
from roboflow.config import (
19+
OBJECT_DETECTION_MODEL,
20+
API_URL,
21+
)
1722
from roboflow.util.image_utils import check_image_url
1823
from roboflow.util.prediction import PredictionGroup
1924
from roboflow.util.versions import (
2025
print_warn_for_wrong_dependencies_versions,
2126
warn_for_wrong_dependencies_versions,
2227
)
2328

24-
2529
class ObjectDetectionModel:
2630
def __init__(
2731
self,
@@ -459,6 +463,36 @@ def view(button):
459463
else:
460464
view(stopButton)
461465

466+
def download(self, location=".", format="pt"):
467+
supported_formats = ["pt"]
468+
if format not in supported_formats:
469+
raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}")
470+
471+
workspace, project, version = self.id.rsplit("/")
472+
473+
# get pt url
474+
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
475+
476+
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
477+
478+
r.raise_for_status()
479+
480+
pt_weights_url = r.json()["weightsUrl"]
481+
482+
def bar_progress(current, total, width=80):
483+
progress_message = (
484+
"Downloading weights to "
485+
+ location
486+
+ "/weights.pt"
487+
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
488+
)
489+
sys.stdout.write("\r" + progress_message)
490+
sys.stdout.flush()
491+
492+
wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress)
493+
494+
return
495+
462496
def __exception_check(self, image_path_check=None):
463497
# Check if Image path exists exception check (for both hosted URL and local image)
464498
if image_path_check is not None:

0 commit comments

Comments
 (0)