Skip to content

Commit 1f0cfd5

Browse files
committed
Fix Issue #345: KeypointDetectionModel returns correct prediction type
1 parent 1dd0500 commit 1f0cfd5

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

roboflow/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,11 @@ def get_conditional_configuration_variable(key, default):
7171
TYPE_OBJECT_DETECTION = "object-detection"
7272
TYPE_INSTANCE_SEGMENTATION = "instance-segmentation"
7373
TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation"
74+
TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation"
7475
TYPE_KEYPOINT_DETECTION = "keypoint-detection"
7576

77+
KEYPOINT_DETECTION_MODEL = "KeypointDetectionModel"
78+
7679
DEFAULT_BATCH_NAME = "Pip Package Upload"
7780
DEFAULT_JOB_NAME = "Annotated via API"
7881

roboflow/models/keypoint_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import requests
99
from PIL import Image
1010

11-
from roboflow.config import CLASSIFICATION_MODEL
11+
from roboflow.config import KEYPOINT_DETECTION_MODEL
1212
from roboflow.models.inference import InferenceModel
1313
from roboflow.util.image_utils import check_image_url
1414
from roboflow.util.prediction import PredictionGroup
@@ -119,7 +119,7 @@ def predict(self, image_path, hosted=False, confidence=None): # type: ignore[ov
119119
resp.json(),
120120
image_dims=img_dims,
121121
image_path=image_path,
122-
prediction_type=CLASSIFICATION_MODEL,
122+
prediction_type=KEYPOINT_DETECTION_MODEL,
123123
colors=self.colors,
124124
)
125125

roboflow/util/image_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import urllib
66

77
# Third-party imports
8-
import pi_heif # type: ignore[import-untyped]
9-
import pillow_avif # type: ignore[import-untyped]
8+
import pi_heif # type: ignore[import-untyped, import-not-found]
9+
import pillow_avif # type: ignore[import-untyped, import-not-found]
1010
import requests
1111
import yaml
1212
from PIL import Image

roboflow/util/prediction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from roboflow.config import (
1111
CLASSIFICATION_MODEL,
1212
INSTANCE_SEGMENTATION_MODEL,
13+
KEYPOINT_DETECTION_MODEL,
1314
OBJECT_DETECTION_MODEL,
1415
PREDICTION_OBJECT,
1516
SEMANTIC_SEGMENTATION_MODEL,
@@ -509,7 +510,7 @@ def create_prediction_group(json_response, image_path, prediction_type, image_di
509510
colors = {} if colors is None else colors
510511
prediction_list = []
511512

512-
if prediction_type in [OBJECT_DETECTION_MODEL, INSTANCE_SEGMENTATION_MODEL]:
513+
if prediction_type in [OBJECT_DETECTION_MODEL, INSTANCE_SEGMENTATION_MODEL, KEYPOINT_DETECTION_MODEL]:
513514
for prediction in json_response["predictions"]:
514515
prediction = Prediction(
515516
prediction,

0 commit comments

Comments
 (0)