Skip to content

Commit fc9caf7

Browse files
committed
pt download
1 parent 071bf45 commit fc9caf7

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

roboflow/core/version.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,34 @@ def live_plot(epochs, mAP, loss, title=""):
432432

433433
# return the model object
434434
return self.model
435+
436+
def get_pt_weights(self, location="."):
437+
438+
workspace, project, *_ = self.id.rsplit("/")
439+
440+
#get pt url
441+
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
442+
443+
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
444+
445+
r.raise_for_status()
446+
447+
pt_weights_url = r.json()["weightsUrl"]
448+
449+
def bar_progress(current, total, width=80):
450+
451+
progress_message = (
452+
"Downloading weights to "
453+
+ location
454+
+ "/weights.pt"
455+
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
456+
)
457+
sys.stdout.write("\r" + progress_message)
458+
sys.stdout.flush()
459+
460+
wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress)
461+
462+
return
435463

436464
# @warn_for_wrong_dependencies_versions([("ultralytics", "<=", "8.0.20")])
437465
def deploy(self, model_type: str, model_path: str) -> None:

0 commit comments

Comments
 (0)