Skip to content

Commit 1ab0692

Browse files
authored
Merge branch 'main' into classification_annotation_upload
2 parents d9789cd + c0b6bf0 commit 1ab0692

File tree

7 files changed

+65
-39
lines changed

7 files changed

+65
-39
lines changed

roboflow/core/project.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def versions(self):
9999
return version_array
100100

101101
def generate_version(self, settings):
102-
103102
"""
104103
Settings, a python dict with augmentation and preprocessing keys and specifications for generation.
105104
These settings mirror capabilities available via the Roboflow UI.
@@ -239,7 +238,6 @@ def version(self, version_number, local=None):
239238
version_info = self.get_version_information()
240239

241240
for version_object in version_info:
242-
243241
current_version_num = os.path.basename(version_object["id"])
244242
if current_version_num == str(version_number):
245243
vers = Version(
@@ -265,6 +263,7 @@ def __image_upload(
265263
split="train",
266264
batch_name=DEFAULT_BATCH_NAME,
267265
tag_names=[],
266+
**kwargs,
268267
):
269268
"""function to upload image to the specific project
270269
:param image_path: path to image you'd like to upload.
@@ -295,6 +294,8 @@ def __image_upload(
295294
batch_name,
296295
]
297296
)
297+
for key, value in kwargs.items():
298+
self.image_upload_url += "&" + str(key) + "=" + str(value)
298299

299300
for tag in tag_names:
300301
self.image_upload_url = self.image_upload_url + f"&tag={tag}"
@@ -410,6 +411,7 @@ def upload(
410411
num_retry_uploads: int = 0,
411412
batch_name: str = DEFAULT_BATCH_NAME,
412413
tag_names: list = [],
414+
**kwargs,
413415
):
414416
"""Upload image function based on the RESTful API
415417
@@ -460,6 +462,7 @@ def upload(
460462
num_retry_uploads=num_retry_uploads,
461463
batch_name=batch_name,
462464
tag_names=tag_names,
465+
**kwargs,
463466
)
464467
else:
465468
images = os.listdir(image_path)
@@ -475,6 +478,7 @@ def upload(
475478
num_retry_uploads=num_retry_uploads,
476479
batch_name=batch_name,
477480
tag_names=tag_names,
481+
**kwargs,
478482
)
479483
print("[ " + path + " ] was uploaded succesfully.")
480484
else:
@@ -491,8 +495,8 @@ def single_upload(
491495
num_retry_uploads=0,
492496
batch_name=DEFAULT_BATCH_NAME,
493497
tag_names=[],
498+
**kwargs,
494499
):
495-
496500
success = False
497501
annotation_success = False
498502
# User gives image path
@@ -504,6 +508,7 @@ def single_upload(
504508
split=split,
505509
batch_name=batch_name,
506510
tag_names=tag_names,
511+
**kwargs,
507512
)
508513
# Get JSON response values
509514
try:
@@ -543,6 +548,7 @@ def single_upload(
543548
image_id=image_id,
544549
split=split,
545550
num_retry_uploads=num_retry_uploads - 1,
551+
**kwargs,
546552
)
547553
return
548554
else:

roboflow/core/version.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
from roboflow.models.instance_segmentation import InstanceSegmentationModel
2727
from roboflow.models.object_detection import ObjectDetectionModel
2828
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
29+
from roboflow.util.annotations import amend_data_yaml
2930
from roboflow.util.versions import (
31+
get_wrong_dependencies_versions,
3032
print_warn_for_wrong_dependencies_versions,
31-
warn_for_wrong_dependencies_versions,
3233
)
3334

3435
load_dotenv()
@@ -311,7 +312,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
311312
model_path (str): File path to model weights to be uploaded
312313
"""
313314

314-
supported_models = ["yolov8", "yolov5"]
315+
supported_models = ["yolov8", "yolov5", "yolov7-seg"]
315316

316317
if model_type not in supported_models:
317318
raise (
@@ -334,7 +335,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
334335
[("ultralytics", "<=", "8.0.20")]
335336
)
336337

337-
elif model_type == "yolov5":
338+
elif model_type in ["yolov5", "yolov7-seg"]:
338339
try:
339340
import torch
340341
except ImportError as e:
@@ -378,7 +379,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
378379
"ultralytics_version": ultralytics.__version__,
379380
"model_type": model_type,
380381
}
381-
elif model_type == "yolov5":
382+
elif model_type in ["yolov5", "yolov7-seg"]:
382383
# parse from yaml for yolov5
383384

384385
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
@@ -406,11 +407,19 @@ def deploy(self, model_type: str, model_path: str) -> None:
406407

407408
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
408409
for file in lista_files:
409-
zipMe.write(
410-
model_path + file,
411-
arcname=file,
412-
compress_type=zipfile.ZIP_DEFLATED,
413-
)
410+
if os.path.exists(model_path + file):
411+
zipMe.write(
412+
model_path + file,
413+
arcname=file,
414+
compress_type=zipfile.ZIP_DEFLATED,
415+
)
416+
else:
417+
if file in ["model_artifacts.json", "state_dict.pt"]:
418+
raise (
419+
ValueError(
420+
f"File {file} not found. Please make sure to provide a valid model path."
421+
)
422+
)
414423

415424
res = requests.get(
416425
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
@@ -549,7 +558,7 @@ def __get_format_identifier(self, format):
549558
friendly_formats = {"yolov5": "yolov5pytorch", "yolov7": "yolov7pytorch"}
550559
return friendly_formats.get(format, format)
551560

552-
def __reformat_yaml(self, location, format):
561+
def __reformat_yaml(self, location: str, format: str):
553562
"""
554563
Certain formats seem to require reformatting the downloaded YAML.
555564
It'd be nice if the API did this, but we're doing it in python for now.
@@ -559,28 +568,29 @@ def __reformat_yaml(self, location, format):
559568
560569
:return None:
561570
"""
562-
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
563-
with open(location + "/data.yaml") as file:
564-
new_yaml = yaml.safe_load(file)
565-
new_yaml["train"] = location + new_yaml["train"].lstrip("..")
566-
new_yaml["val"] = location + new_yaml["val"].lstrip("..")
567-
568-
os.remove(location + "/data.yaml")
569-
570-
with open(location + "/data.yaml", "w") as outfile:
571-
yaml.dump(new_yaml, outfile)
572-
573-
if format == "mt-yolov6":
574-
with open(location + "/data.yaml") as file:
575-
new_yaml = yaml.safe_load(file)
576-
new_yaml["train"] = location + new_yaml["train"].lstrip(".")
577-
new_yaml["val"] = location + new_yaml["val"].lstrip(".")
578-
new_yaml["test"] = location + new_yaml["test"].lstrip(".")
579-
580-
os.remove(location + "/data.yaml")
571+
data_path = os.path.join(location, "data.yaml")
572+
573+
def callback(content: dict) -> dict:
574+
if format == "mt-yolov6":
575+
content["train"] = location + content["train"].lstrip(".")
576+
content["val"] = location + content["val"].lstrip(".")
577+
content["test"] = location + content["test"].lstrip(".")
578+
if format in ["yolov5pytorch", "yolov7pytorch", "yolov8"]:
579+
content["train"] = location + content["train"].lstrip("..")
580+
content["val"] = location + content["val"].lstrip("..")
581+
try:
582+
# get_wrong_dependencies_versions raises exception if ultralytics is not installed at all
583+
if not get_wrong_dependencies_versions(
584+
dependencies_versions=[("ultralytics", ">=", "8.0.30")]
585+
):
586+
content["train"] = "train/images"
587+
content["val"] = "valid/images"
588+
content["test"] = "test/images"
589+
except ModuleNotFoundError:
590+
pass
591+
return content
581592

582-
with open(location + "/data.yaml", "w") as outfile:
583-
yaml.dump(new_yaml, outfile)
593+
amend_data_yaml(path=data_path, callback=callback)
584594

585595
def __str__(self):
586596
"""string representation of version object."""

roboflow/core/workspace.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,6 @@ def active_learning(
368368
and prediction["confidence"] * 100
369369
<= conditionals["confidence_interval"][1]
370370
):
371-
372371
# filter out non-target_class uploads if enabled
373372
if (
374373
len(conditionals["target_classes"]) > 0

roboflow/models/object_detection.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def __generate_url(
214214
labels=None,
215215
format=None,
216216
):
217-
218217
# Reassign parameters if any parameters are changed
219218
if local is not None:
220219
if not local:

roboflow/util/annotations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import os
2+
from typing import Callable
3+
4+
import yaml
5+
6+
7+
def amend_data_yaml(path: str, callback: Callable[[dict], dict]):
8+
with open(path) as source:
9+
content = yaml.safe_load(source)
10+
content = callback(content)
11+
os.remove(path)
12+
with open(path, "w") as target:
13+
yaml.dump(content, target)

roboflow/util/two_stage_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
def ocr_infer(image):
10-
1110
# Convert to JPEG Buffer
1211
buffered = io.BytesIO()
1312
image.save(buffered, quality=90, format="PNG")

roboflow/util/versions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def get_wrong_dependencies_versions(
88
dependencies_versions: List[Tuple[str, str, str]]
99
) -> List[Tuple[str, str, str, str]]:
1010
"""
11-
Get a list of missmatching dependencies with current version installed.
11+
Get a list of mismatching dependencies with current version installed.
1212
E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")]
1313
1414
We support `<=`, `==`, `>=`
@@ -45,7 +45,7 @@ def print_warn_for_wrong_dependencies_versions(
4545
dependencies_versions: List[Tuple[str, str, str]]
4646
):
4747
wrong_dependencies_versions = get_wrong_dependencies_versions(dependencies_versions)
48-
for (dependency, order, version, module_version) in wrong_dependencies_versions:
48+
for dependency, order, version, module_version in wrong_dependencies_versions:
4949
print(
5050
f"Dependency {dependency}{order}{version} is required but found version={module_version}, to fix: `pip install {dependency}{order}{version}`"
5151
)

0 commit comments

Comments
 (0)