@@ -591,6 +591,7 @@ def postprocess(
591591 img_dims = preproc_return_metadata ["img_dims" ]
592592
593593 processed_predictions = []
594+ processed_masks = []
594595
595596 for batch_idx in range (batch_size ):
596597 orig_h , orig_w = img_dims [batch_idx ]
@@ -623,37 +624,6 @@ def postprocess(
623624
624625 selected_boxes = bboxes [batch_idx , topk_boxes ]
625626 selected_masks = masks [batch_idx , topk_boxes ]
626- if selected_masks .size != 0 :
627- if kwargs .get ("mask_decode_mode" , "accurate" ) == "accurate" :
628- target_res = (orig_w , orig_h )
629- new_masks = []
630- for mask in selected_masks :
631- new_masks .append (
632- cv2 .resize (mask , target_res , interpolation = cv2 .INTER_LINEAR )
633- )
634- selected_masks = np .stack (new_masks , axis = 0 )
635- elif kwargs .get ("mask_decode_mode" , "accurate" ) == "tradeoff" :
636- tradeoff_factor = kwargs .get ("tradeoff_factor" , 0.0 )
637- mask_res = (selected_masks .shape [2 ], selected_masks .shape [1 ])
638- full_res = (orig_w , orig_h )
639- target_res = (
640- int (
641- mask_res [0 ] * (1 - tradeoff_factor )
642- + full_res [0 ] * tradeoff_factor
643- ),
644- int (
645- mask_res [1 ] * (1 - tradeoff_factor )
646- + full_res [1 ] * tradeoff_factor
647- ),
648- )
649- new_masks = []
650- for mask in selected_masks :
651- new_masks .append (
652- cv2 .resize (mask , target_res , interpolation = cv2 .INTER_LINEAR )
653- )
654- selected_masks = np .stack (new_masks , axis = 0 )
655-
656- selected_masks = selected_masks > 0
657627
658628 cxcy = selected_boxes [:, :2 ]
659629 wh = selected_boxes [:, 2 :]
@@ -700,58 +670,59 @@ def postprocess(
700670 topk_labels ,
701671 )
702672 )
703- batch_predictions = batch_predictions [
704- batch_predictions [:, 6 ] < len (self .class_names )
705- ]
706- selected_masks = selected_masks [
707- batch_predictions [:, 6 ] < len (self .class_names )
708- ]
709-
710- outputs = []
711- for pred , mask in zip (batch_predictions , selected_masks ):
712- outputs .append (list (pred ) + [mask ])
713-
714- processed_predictions .append (outputs )
715-
716- res = self .make_response (processed_predictions , img_dims , ** kwargs )
717- return res
718-
719- def make_response (
720- self ,
721- predictions : List [List [float ]],
722- img_dims : List [Tuple [int , int ]],
723- class_filter : Optional [List [str ]] = None ,
724- * args ,
725- ** kwargs ,
726- ) -> List [ObjectDetectionInferenceResponse ]:
727- """Constructs object detection response objects based on predictions.
728-
729- Args:
730- predictions (List[List[float]]): The list of predictions.
731- img_dims (List[Tuple[int, int]]): Dimensions of the images.
732- class_filter (Optional[List[str]]): A list of class names to filter, if provided.
733-
734- Returns:
735- List[ObjectDetectionInferenceResponse]: A list of response objects containing object detection predictions.
736- """
737-
738- if isinstance (img_dims , dict ) and "img_dims" in img_dims :
739- img_dims = img_dims ["img_dims" ]
740-
741- predictions = predictions [
742- : len (img_dims )
743- ] # If the batch size was fixed we have empty preds at the end
744-
745- batch_mask_preds = []
746- for image_ind in range (len (img_dims )):
747- masks = [pred [7 ] for pred in predictions [image_ind ]]
748- orig_h , orig_w = img_dims [image_ind ]
749-
750- mask_preds = []
751- for mask in masks :
752- points = mask2poly (mask .astype (np .uint8 ))
673+ valid_pred_mask = batch_predictions [:, 6 ] < len (self .class_names )
674+
675+ outputs_predictions = []
676+ outputs_polygons = []
677+ class_filter_local = kwargs .get ("class_filter" )
678+ for i , pred in enumerate (batch_predictions ):
679+ if not valid_pred_mask [i ]:
680+ continue
681+ # Early class filtering to avoid unnecessary mask processing
682+ if class_filter_local :
683+ try :
684+ pred_class_name = self .class_names [int (pred [6 ])]
685+ except Exception :
686+ continue
687+ if pred_class_name not in class_filter_local :
688+ continue
689+ mask = selected_masks [i ]
690+ # Per-mask optional upscaling for better polygon quality without retaining all high-res masks
691+ mask_decode_mode = kwargs .get ("mask_decode_mode" , "accurate" )
692+ if mask_decode_mode == "accurate" :
693+ target_res = (orig_w , orig_h )
694+ if mask .shape [1 ] != target_res [0 ] or mask .shape [0 ] != target_res [1 ]:
695+ mask = cv2 .resize (
696+ mask .astype (np .float32 ),
697+ target_res ,
698+ interpolation = cv2 .INTER_LINEAR ,
699+ )
700+ elif mask_decode_mode == "tradeoff" :
701+ tradeoff_factor = kwargs .get ("tradeoff_factor" , 0.0 )
702+ mask_res = (mask .shape [1 ], mask .shape [0 ]) # (w, h)
703+ full_res = (orig_w , orig_h ) # (w, h)
704+ target_res = (
705+ int (
706+ mask_res [0 ] * (1 - tradeoff_factor )
707+ + full_res [0 ] * tradeoff_factor
708+ ),
709+ int (
710+ mask_res [1 ] * (1 - tradeoff_factor )
711+ + full_res [1 ] * tradeoff_factor
712+ ),
713+ )
714+ if mask .shape [1 ] != target_res [0 ] or mask .shape [0 ] != target_res [1 ]:
715+ mask = cv2 .resize (
716+ mask .astype (np .float32 ),
717+ target_res ,
718+ interpolation = cv2 .INTER_LINEAR ,
719+ )
720+ # Ensure binary for polygonization
721+ mask_bin = (mask > 0 ).astype (np .uint8 )
722+ points = mask2poly (mask_bin )
723+ # Scale polygon points back to original image coordinates if needed
753724 new_points = []
754- prediction_h , prediction_w = mask .shape [0 ], mask .shape [1 ]
725+ prediction_h , prediction_w = mask_bin .shape [0 ], mask_bin .shape [1 ]
755726 for point in points :
756727 if self .resize_method == "Stretch to" :
757728 new_x = point [0 ] * (orig_w / prediction_w )
@@ -763,14 +734,42 @@ def make_response(
763734 new_x = point [0 ] * scale + pad_x
764735 new_y = point [1 ] * scale + pad_y
765736 new_points .append (np .array ([new_x , new_y ]))
766- mask_preds .append (new_points )
767- batch_mask_preds .append (mask_preds )
737+ outputs_polygons .append (new_points )
738+ outputs_predictions .append (list ( pred ) )
768739
769- responses = [
770- InstanceSegmentationInferenceResponse (
771- predictions = [
740+ processed_predictions .append (outputs_predictions )
741+ processed_masks .append (outputs_polygons )
742+
743+ res = self .make_response (
744+ processed_predictions , processed_masks , img_dims , ** kwargs
745+ )
746+ return res
747+
748+ def make_response (
749+ self ,
750+ predictions : List [List [List [float ]]],
751+ masks : List [List [List [np .ndarray ]]],
752+ img_dims : List [Tuple [int , int ]],
753+ class_filter : Optional [List [str ]] = None ,
754+ * args ,
755+ ** kwargs ,
756+ ) -> List [InstanceSegmentationInferenceResponse ]:
757+ """Constructs instance segmentation response objects from preprocessed predictions and polygons."""
758+ # Align to actual number of real images; predictions/masks may include padded slots
759+ if isinstance (img_dims , dict ) and "img_dims" in img_dims :
760+ img_dims = img_dims ["img_dims" ]
761+ effective_len = min (len (img_dims ), len (predictions ), len (masks ))
762+
763+ responses = []
764+ for ind in range (effective_len ):
765+ batch_predictions = predictions [ind ]
766+ batch_masks = masks [ind ]
767+ preds_out = []
768+ for pred , mask in zip (batch_predictions , batch_masks ):
769+ if class_filter and self .class_names [int (pred [6 ])] not in class_filter :
770+ continue
771+ preds_out .append (
772772 InstanceSegmentationPrediction (
773- # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
774773 ** {
775774 "x" : (pred [0 ] + pred [2 ]) / 2 ,
776775 "y" : (pred [1 ] + pred [3 ]) / 2 ,
@@ -782,14 +781,13 @@ def make_response(
782781 "points" : [Point (x = point [0 ], y = point [1 ]) for point in mask ],
783782 }
784783 )
785- for pred , mask in zip (batch_predictions , batch_mask_preds [ind ])
786- if not class_filter
787- or self .class_names [int (pred [6 ])] in class_filter
788- ],
789- image = InferenceResponseImage (
790- width = img_dims [ind ][1 ], height = img_dims [ind ][0 ]
791- ),
784+ )
785+ responses .append (
786+ InstanceSegmentationInferenceResponse (
787+ predictions = preds_out ,
788+ image = InferenceResponseImage (
789+ width = img_dims [ind ][1 ], height = img_dims [ind ][0 ]
790+ ),
791+ )
792792 )
793- for ind , batch_predictions in enumerate (predictions )
794- ]
795793 return responses
0 commit comments