|
3 | 3 | import io |
4 | 4 | import json |
5 | 5 | import os |
| 6 | +import sys |
6 | 7 | import random |
7 | 8 | import urllib |
8 | 9 | from pathlib import Path |
|
11 | 12 | import matplotlib.pyplot as plt |
12 | 13 | import numpy as np |
13 | 14 | import requests |
| 15 | +import wget |
14 | 16 | from PIL import Image |
15 | 17 |
|
16 | | -from roboflow.config import OBJECT_DETECTION_MODEL |
| 18 | +from roboflow.config import ( |
| 19 | + OBJECT_DETECTION_MODEL, |
| 20 | + API_URL, |
| 21 | +) |
17 | 22 | from roboflow.util.image_utils import check_image_url |
18 | 23 | from roboflow.util.prediction import PredictionGroup |
19 | 24 | from roboflow.util.versions import ( |
20 | 25 | print_warn_for_wrong_dependencies_versions, |
21 | 26 | warn_for_wrong_dependencies_versions, |
22 | 27 | ) |
23 | 28 |
|
24 | | - |
25 | 29 | class ObjectDetectionModel: |
26 | 30 | def __init__( |
27 | 31 | self, |
@@ -459,6 +463,36 @@ def view(button): |
459 | 463 | else: |
460 | 464 | view(stopButton) |
461 | 465 |
|
| 466 | + def download(self, location=".", format="pt"): |
| 467 | + supported_formats = ["pt"] |
| 468 | + if format not in supported_formats: |
| 469 | + raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}") |
| 470 | + |
| 471 | + workspace, project, version = self.id.rsplit("/") |
| 472 | + |
| 473 | + # get pt url |
| 474 | + pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile" |
| 475 | + |
| 476 | + r = requests.get(pt_api_url, params={"api_key": self.__api_key}) |
| 477 | + |
| 478 | + r.raise_for_status() |
| 479 | + |
| 480 | + pt_weights_url = r.json()["weightsUrl"] |
| 481 | + |
| 482 | + def bar_progress(current, total, width=80): |
| 483 | + progress_message = ( |
| 484 | + "Downloading weights to " |
| 485 | + + location |
| 486 | + + "/weights.pt" |
| 487 | + + ": %d%% [%d / %d] bytes" % (current / total * 100, current, total) |
| 488 | + ) |
| 489 | + sys.stdout.write("\r" + progress_message) |
| 490 | + sys.stdout.flush() |
| 491 | + |
| 492 | + wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress) |
| 493 | + |
| 494 | + return |
| 495 | + |
462 | 496 | def __exception_check(self, image_path_check=None): |
463 | 497 | # Check if Image path exists exception check (for both hosted URL and local image) |
464 | 498 | if image_path_check is not None: |
|
0 commit comments