Skip to content

Commit 6baf816

Browse files
committed
fix labelmap parsing for yaml dict
1 parent 53c41ba commit 6baf816

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

roboflow/util/image_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ def load_labelmap(f):
103103
with open(f) as file:
104104
data = yaml.safe_load(file)
105105
if "names" in data:
106-
return {i: name for i, name in enumerate(data["names"])}
106+
names = data["names"]
107+
if isinstance(names, dict):
108+
return {int(i): name for i, name in names.items()}
109+
return {i: name for i, name in enumerate(names)}
107110
else:
108111
with open(f) as file:
109112
lines = [line for line in file.readlines() if line.strip()]

tests/util/test_image_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import responses
44

55
from roboflow.util.image_utils import check_image_path, check_image_url
6+
from roboflow.util.image_utils import load_labelmap
7+
import tempfile
8+
import os
69

710

811
class TestCheckImagePath(unittest.TestCase):
@@ -33,3 +36,17 @@ def test_url_not_found(self):
3336
url = "https://example.com/notfound.png"
3437
responses.add(responses.HEAD, url, status=404)
3538
self.assertFalse(check_image_url(url))
39+
40+
41+
class TestLoadLabelmap(unittest.TestCase):
42+
def test_yaml_dict_names(self):
43+
with tempfile.NamedTemporaryFile("w+", suffix=".yaml", delete=False) as tmp:
44+
tmp.write("names:\n 0: abc\n 1: def\n")
45+
tmp.flush()
46+
result = load_labelmap(tmp.name)
47+
os.unlink(tmp.name)
48+
self.assertEqual(result, {0: "abc", 1: "def"})
49+
50+
def test_yaml_list_names(self):
51+
result = load_labelmap("tests/datasets/sharks-tiny-yolov9/data.yaml")
52+
self.assertEqual(result, {0: "fish", 1: "primary", 2: "shark"})

0 commit comments

Comments
 (0)