Skip to content

Commit 2276e84

Browse files
authored
Merge branch 'main' into upload-with-tag
2 parents 2d93321 + 57b7aa7 commit 2276e84

File tree

5 files changed

+69
-33
lines changed

5 files changed

+69
-33
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from roboflow.core.project import Project
99
from roboflow.core.workspace import Workspace
1010

11-
__version__ = "0.2.29"
11+
__version__ = "0.2.30"
1212

1313

1414
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/project.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ def __image_upload(
263263
split="train",
264264
batch_name=DEFAULT_BATCH_NAME,
265265
tag_names=[],
266+
**kwargs
267+
266268
):
267269
"""function to upload image to the specific project
268270
:param image_path: path to image you'd like to upload.
@@ -293,6 +295,8 @@ def __image_upload(
293295
batch_name,
294296
]
295297
)
298+
for key, value in kwargs.items():
299+
self.image_upload_url += "&" + str(key) + "=" + str(value)
296300

297301
for tag in tag_names:
298302
self.image_upload_url = self.image_upload_url + f"&tag={tag}"
@@ -377,6 +381,7 @@ def check_valid_image(self, image_path):
377381

378382
def upload(
379383
self,
384+
380385
image_path: str = None,
381386
annotation_path: str = None,
382387
hosted_image: bool = False,
@@ -385,6 +390,7 @@ def upload(
385390
num_retry_uploads: int = 0,
386391
batch_name: str = DEFAULT_BATCH_NAME,
387392
tag_names: list = [],
393+
**kwargs
388394
):
389395
"""Upload image function based on the RESTful API
390396
@@ -435,6 +441,8 @@ def upload(
435441
num_retry_uploads=num_retry_uploads,
436442
batch_name=batch_name,
437443
tag_names=tag_names,
444+
**kwargs
445+
438446
)
439447
else:
440448
images = os.listdir(image_path)
@@ -450,6 +458,8 @@ def upload(
450458
num_retry_uploads=num_retry_uploads,
451459
batch_name=batch_name,
452460
tag_names=tag_names,
461+
**kwargs
462+
453463
)
454464
print("[ " + path + " ] was uploaded succesfully.")
455465
else:
@@ -466,6 +476,7 @@ def single_upload(
466476
num_retry_uploads=0,
467477
batch_name=DEFAULT_BATCH_NAME,
468478
tag_names=[],
479+
**kwargs
469480
):
470481
success = False
471482
annotation_success = False
@@ -478,6 +489,7 @@ def single_upload(
478489
split=split,
479490
batch_name=batch_name,
480491
tag_names=tag_names,
492+
**kwargs,
481493
)
482494
# Get JSON response values
483495
try:
@@ -517,6 +529,7 @@ def single_upload(
517529
image_id=image_id,
518530
split=split,
519531
num_retry_uploads=num_retry_uploads - 1,
532+
**kwargs,
520533
)
521534
return
522535
else:

roboflow/core/version.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
from roboflow.models.instance_segmentation import InstanceSegmentationModel
2727
from roboflow.models.object_detection import ObjectDetectionModel
2828
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
29+
from roboflow.util.annotations import amend_data_yaml
2930
from roboflow.util.versions import (
31+
get_wrong_dependencies_versions,
3032
print_warn_for_wrong_dependencies_versions,
31-
warn_for_wrong_dependencies_versions,
3233
)
3334

3435
load_dotenv()
@@ -311,7 +312,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
311312
model_path (str): File path to model weights to be uploaded
312313
"""
313314

314-
supported_models = ["yolov8", "yolov5"]
315+
supported_models = ["yolov8", "yolov5", "yolov7-seg"]
315316

316317
if model_type not in supported_models:
317318
raise (
@@ -334,7 +335,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
334335
[("ultralytics", "<=", "8.0.20")]
335336
)
336337

337-
elif model_type == "yolov5":
338+
elif model_type in ["yolov5", "yolov7-seg"]:
338339
try:
339340
import torch
340341
except ImportError as e:
@@ -378,7 +379,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
378379
"ultralytics_version": ultralytics.__version__,
379380
"model_type": model_type,
380381
}
381-
elif model_type == "yolov5":
382+
elif model_type in ["yolov5", "yolov7-seg"]:
382383
# parse from yaml for yolov5
383384

384385
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
@@ -406,11 +407,19 @@ def deploy(self, model_type: str, model_path: str) -> None:
406407

407408
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
408409
for file in lista_files:
409-
zipMe.write(
410-
model_path + file,
411-
arcname=file,
412-
compress_type=zipfile.ZIP_DEFLATED,
413-
)
410+
if os.path.exists(model_path + file):
411+
zipMe.write(
412+
model_path + file,
413+
arcname=file,
414+
compress_type=zipfile.ZIP_DEFLATED,
415+
)
416+
else:
417+
if file in ["model_artifacts.json", "state_dict.pt"]:
418+
raise (
419+
ValueError(
420+
f"File {file} not found. Please make sure to provide a valid model path."
421+
)
422+
)
414423

415424
res = requests.get(
416425
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
@@ -549,7 +558,7 @@ def __get_format_identifier(self, format):
549558
friendly_formats = {"yolov5": "yolov5pytorch", "yolov7": "yolov7pytorch"}
550559
return friendly_formats.get(format, format)
551560

552-
def __reformat_yaml(self, location, format):
561+
def __reformat_yaml(self, location: str, format: str):
553562
"""
554563
Certain formats seem to require reformatting the downloaded YAML.
555564
It'd be nice if the API did this, but we're doing it in python for now.
@@ -559,28 +568,29 @@ def __reformat_yaml(self, location, format):
559568
560569
:return None:
561570
"""
562-
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
563-
with open(location + "/data.yaml") as file:
564-
new_yaml = yaml.safe_load(file)
565-
new_yaml["train"] = location + new_yaml["train"].lstrip("..")
566-
new_yaml["val"] = location + new_yaml["val"].lstrip("..")
567-
568-
os.remove(location + "/data.yaml")
569-
570-
with open(location + "/data.yaml", "w") as outfile:
571-
yaml.dump(new_yaml, outfile)
572-
573-
if format == "mt-yolov6":
574-
with open(location + "/data.yaml") as file:
575-
new_yaml = yaml.safe_load(file)
576-
new_yaml["train"] = location + new_yaml["train"].lstrip(".")
577-
new_yaml["val"] = location + new_yaml["val"].lstrip(".")
578-
new_yaml["test"] = location + new_yaml["test"].lstrip(".")
579-
580-
os.remove(location + "/data.yaml")
571+
data_path = os.path.join(location, "data.yaml")
572+
573+
def callback(content: dict) -> dict:
574+
if format == "mt-yolov6":
575+
content["train"] = location + content["train"].lstrip(".")
576+
content["val"] = location + content["val"].lstrip(".")
577+
content["test"] = location + content["test"].lstrip(".")
578+
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
579+
content["train"] = location + content["train"].lstrip("..")
580+
content["val"] = location + content["val"].lstrip("..")
581+
try:
582+
# get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
583+
if not get_wrong_dependencies_versions(
584+
dependencies_versions=[("ultralytics", ">=", "8.0.30")]
585+
):
586+
content["train"] = "train/images"
587+
content["val"] = "valid/images"
588+
content["test"] = "test/images"
589+
except ModuleNotFoundError:
590+
pass
591+
return content
581592

582-
with open(location + "/data.yaml", "w") as outfile:
583-
yaml.dump(new_yaml, outfile)
593+
amend_data_yaml(path=data_path, callback=callback)
584594

585595
def __str__(self):
586596
"""string representation of version object."""

roboflow/util/annotations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import os
2+
from typing import Callable
3+
4+
import yaml
5+
6+
7+
def amend_data_yaml(path: str, callback: Callable[[dict], dict]):
8+
with open(path) as source:
9+
content = yaml.safe_load(source)
10+
content = callback(content)
11+
os.remove(path)
12+
with open(path, "w") as target:
13+
yaml.dump(content, target)

roboflow/util/versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def get_wrong_dependencies_versions(
88
dependencies_versions: List[Tuple[str, str, str]]
99
) -> List[Tuple[str, str, str, str]]:
1010
"""
11-
Get a list of missmatching dependencies with current version installed.
11+
Get a list of mismatching dependencies with current version installed.
1212
E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")]
1313
1414
We support `<=`, `==`, `>=`

0 commit comments

Comments
 (0)