Skip to content

Commit 1e9452a

Browse files
Merge pull request #1747 from roboflow/fix-oom-rfdetr-mask-posprocessing
improve rfdetr segmentation post processing
2 parents ee847e6 + 7ff836c commit 1e9452a

File tree

1 file changed

+95
-97
lines changed

1 file changed

+95
-97
lines changed

inference/models/rfdetr/rfdetr.py

Lines changed: 95 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def postprocess(
591591
img_dims = preproc_return_metadata["img_dims"]
592592

593593
processed_predictions = []
594+
processed_masks = []
594595

595596
for batch_idx in range(batch_size):
596597
orig_h, orig_w = img_dims[batch_idx]
@@ -623,37 +624,6 @@ def postprocess(
623624

624625
selected_boxes = bboxes[batch_idx, topk_boxes]
625626
selected_masks = masks[batch_idx, topk_boxes]
626-
if selected_masks.size != 0:
627-
if kwargs.get("mask_decode_mode", "accurate") == "accurate":
628-
target_res = (orig_w, orig_h)
629-
new_masks = []
630-
for mask in selected_masks:
631-
new_masks.append(
632-
cv2.resize(mask, target_res, interpolation=cv2.INTER_LINEAR)
633-
)
634-
selected_masks = np.stack(new_masks, axis=0)
635-
elif kwargs.get("mask_decode_mode", "accurate") == "tradeoff":
636-
tradeoff_factor = kwargs.get("tradeoff_factor", 0.0)
637-
mask_res = (selected_masks.shape[2], selected_masks.shape[1])
638-
full_res = (orig_w, orig_h)
639-
target_res = (
640-
int(
641-
mask_res[0] * (1 - tradeoff_factor)
642-
+ full_res[0] * tradeoff_factor
643-
),
644-
int(
645-
mask_res[1] * (1 - tradeoff_factor)
646-
+ full_res[1] * tradeoff_factor
647-
),
648-
)
649-
new_masks = []
650-
for mask in selected_masks:
651-
new_masks.append(
652-
cv2.resize(mask, target_res, interpolation=cv2.INTER_LINEAR)
653-
)
654-
selected_masks = np.stack(new_masks, axis=0)
655-
656-
selected_masks = selected_masks > 0
657627

658628
cxcy = selected_boxes[:, :2]
659629
wh = selected_boxes[:, 2:]
@@ -700,58 +670,59 @@ def postprocess(
700670
topk_labels,
701671
)
702672
)
703-
batch_predictions = batch_predictions[
704-
batch_predictions[:, 6] < len(self.class_names)
705-
]
706-
selected_masks = selected_masks[
707-
batch_predictions[:, 6] < len(self.class_names)
708-
]
709-
710-
outputs = []
711-
for pred, mask in zip(batch_predictions, selected_masks):
712-
outputs.append(list(pred) + [mask])
713-
714-
processed_predictions.append(outputs)
715-
716-
res = self.make_response(processed_predictions, img_dims, **kwargs)
717-
return res
718-
719-
def make_response(
720-
self,
721-
predictions: List[List[float]],
722-
img_dims: List[Tuple[int, int]],
723-
class_filter: Optional[List[str]] = None,
724-
*args,
725-
**kwargs,
726-
) -> List[ObjectDetectionInferenceResponse]:
727-
"""Constructs object detection response objects based on predictions.
728-
729-
Args:
730-
predictions (List[List[float]]): The list of predictions.
731-
img_dims (List[Tuple[int, int]]): Dimensions of the images.
732-
class_filter (Optional[List[str]]): A list of class names to filter, if provided.
733-
734-
Returns:
735-
List[ObjectDetectionInferenceResponse]: A list of response objects containing object detection predictions.
736-
"""
737-
738-
if isinstance(img_dims, dict) and "img_dims" in img_dims:
739-
img_dims = img_dims["img_dims"]
740-
741-
predictions = predictions[
742-
: len(img_dims)
743-
] # If the batch size was fixed we have empty preds at the end
744-
745-
batch_mask_preds = []
746-
for image_ind in range(len(img_dims)):
747-
masks = [pred[7] for pred in predictions[image_ind]]
748-
orig_h, orig_w = img_dims[image_ind]
749-
750-
mask_preds = []
751-
for mask in masks:
752-
points = mask2poly(mask.astype(np.uint8))
673+
valid_pred_mask = batch_predictions[:, 6] < len(self.class_names)
674+
675+
outputs_predictions = []
676+
outputs_polygons = []
677+
class_filter_local = kwargs.get("class_filter")
678+
for i, pred in enumerate(batch_predictions):
679+
if not valid_pred_mask[i]:
680+
continue
681+
# Early class filtering to avoid unnecessary mask processing
682+
if class_filter_local:
683+
try:
684+
pred_class_name = self.class_names[int(pred[6])]
685+
except Exception:
686+
continue
687+
if pred_class_name not in class_filter_local:
688+
continue
689+
mask = selected_masks[i]
690+
# Per-mask optional upscaling for better polygon quality without retaining all high-res masks
691+
mask_decode_mode = kwargs.get("mask_decode_mode", "accurate")
692+
if mask_decode_mode == "accurate":
693+
target_res = (orig_w, orig_h)
694+
if mask.shape[1] != target_res[0] or mask.shape[0] != target_res[1]:
695+
mask = cv2.resize(
696+
mask.astype(np.float32),
697+
target_res,
698+
interpolation=cv2.INTER_LINEAR,
699+
)
700+
elif mask_decode_mode == "tradeoff":
701+
tradeoff_factor = kwargs.get("tradeoff_factor", 0.0)
702+
mask_res = (mask.shape[1], mask.shape[0]) # (w, h)
703+
full_res = (orig_w, orig_h) # (w, h)
704+
target_res = (
705+
int(
706+
mask_res[0] * (1 - tradeoff_factor)
707+
+ full_res[0] * tradeoff_factor
708+
),
709+
int(
710+
mask_res[1] * (1 - tradeoff_factor)
711+
+ full_res[1] * tradeoff_factor
712+
),
713+
)
714+
if mask.shape[1] != target_res[0] or mask.shape[0] != target_res[1]:
715+
mask = cv2.resize(
716+
mask.astype(np.float32),
717+
target_res,
718+
interpolation=cv2.INTER_LINEAR,
719+
)
720+
# Ensure binary for polygonization
721+
mask_bin = (mask > 0).astype(np.uint8)
722+
points = mask2poly(mask_bin)
723+
# Scale polygon points back to original image coordinates if needed
753724
new_points = []
754-
prediction_h, prediction_w = mask.shape[0], mask.shape[1]
725+
prediction_h, prediction_w = mask_bin.shape[0], mask_bin.shape[1]
755726
for point in points:
756727
if self.resize_method == "Stretch to":
757728
new_x = point[0] * (orig_w / prediction_w)
@@ -763,14 +734,42 @@ def make_response(
763734
new_x = point[0] * scale + pad_x
764735
new_y = point[1] * scale + pad_y
765736
new_points.append(np.array([new_x, new_y]))
766-
mask_preds.append(new_points)
767-
batch_mask_preds.append(mask_preds)
737+
outputs_polygons.append(new_points)
738+
outputs_predictions.append(list(pred))
768739

769-
responses = [
770-
InstanceSegmentationInferenceResponse(
771-
predictions=[
740+
processed_predictions.append(outputs_predictions)
741+
processed_masks.append(outputs_polygons)
742+
743+
res = self.make_response(
744+
processed_predictions, processed_masks, img_dims, **kwargs
745+
)
746+
return res
747+
748+
def make_response(
749+
self,
750+
predictions: List[List[List[float]]],
751+
masks: List[List[List[np.ndarray]]],
752+
img_dims: List[Tuple[int, int]],
753+
class_filter: Optional[List[str]] = None,
754+
*args,
755+
**kwargs,
756+
) -> List[InstanceSegmentationInferenceResponse]:
757+
"""Constructs instance segmentation response objects from preprocessed predictions and polygons."""
758+
# Align to actual number of real images; predictions/masks may include padded slots
759+
if isinstance(img_dims, dict) and "img_dims" in img_dims:
760+
img_dims = img_dims["img_dims"]
761+
effective_len = min(len(img_dims), len(predictions), len(masks))
762+
763+
responses = []
764+
for ind in range(effective_len):
765+
batch_predictions = predictions[ind]
766+
batch_masks = masks[ind]
767+
preds_out = []
768+
for pred, mask in zip(batch_predictions, batch_masks):
769+
if class_filter and self.class_names[int(pred[6])] not in class_filter:
770+
continue
771+
preds_out.append(
772772
InstanceSegmentationPrediction(
773-
# Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
774773
**{
775774
"x": (pred[0] + pred[2]) / 2,
776775
"y": (pred[1] + pred[3]) / 2,
@@ -782,14 +781,13 @@ def make_response(
782781
"points": [Point(x=point[0], y=point[1]) for point in mask],
783782
}
784783
)
785-
for pred, mask in zip(batch_predictions, batch_mask_preds[ind])
786-
if not class_filter
787-
or self.class_names[int(pred[6])] in class_filter
788-
],
789-
image=InferenceResponseImage(
790-
width=img_dims[ind][1], height=img_dims[ind][0]
791-
),
784+
)
785+
responses.append(
786+
InstanceSegmentationInferenceResponse(
787+
predictions=preds_out,
788+
image=InferenceResponseImage(
789+
width=img_dims[ind][1], height=img_dims[ind][0]
790+
),
791+
)
792792
)
793-
for ind, batch_predictions in enumerate(predictions)
794-
]
795793
return responses

0 commit comments

Comments
 (0)