@@ -352,14 +352,122 @@ def _clamp_bounding_boxes(
352352 return out_boxes .to (in_dtype )
353353
354354
355+ def _order_bounding_boxes_points (
356+ bounding_boxes : torch .Tensor , indices : torch .Tensor | None = None
357+ ) -> tuple [torch .Tensor , torch .Tensor ]:
358+ """Re-order points in bounding boxes based on specific criteria or provided indices.
359+
360+ This function reorders the points of bounding boxes either according to provided indices or
361+ by a default ordering strategy. In the default strategy, (x1, y1) corresponds to the point
362+ with the lowest x value. If multiple points have the same lowest x value, the point with the
363+ lowest y value is chosen.
364+
365+ Args:
366+ bounding_boxes (torch.Tensor): A tensor containing bounding box coordinates in format [x1, y1, x2, y2, x3, y3, x4, y4].
367+ indices (torch.Tensor | None): Optional tensor containing indices for reordering. If None, default ordering is applied.
368+
369+ Returns:
370+ tuple[torch.Tensor, torch.Tensor]: A tuple containing:
371+ - indices: The indices used for reordering
372+ - reordered_boxes: The bounding boxes with reordered points
373+ """
374+ if indices is None :
375+ output_xyxyxyxy = bounding_boxes .clone ().reshape (- 1 , 8 )
376+ x , y = output_xyxyxyxy [..., 0 ::2 ], output_xyxyxyxy [..., 1 ::2 ]
377+ y_max = torch .max (y , dim = 1 , keepdim = True )[0 ]
378+ _ , x1 = (y_max - y ).div (y_max ).add (x .add (1 ).mul (100 )).min (dim = 1 )
379+ indices = torch .ones_like (output_xyxyxyxy )
380+ indices [..., 0 ] = x1 .mul (2 )
381+ indices .cumsum_ (1 ).remainder_ (8 )
382+ return indices , bounding_boxes .gather (1 , indices .to (torch .int64 ))
383+
384+
def area(box: torch.Tensor) -> torch.Tensor:
    """Return the product of the squared lengths of the first two edges of each box.

    For a rectangle this is the square of its area, which is sufficient for
    comparing the sizes of two boxes without taking square roots.

    Args:
        box (torch.Tensor): Boxes whose elements can be viewed as rows of
            [x1, y1, x2, y2, x3, y3, x4, y4]. The fourth point is not used.

    Returns:
        torch.Tensor: One value per box, shape ``(N,)``.
    """
    corners = box.reshape(-1, 8)
    # Edge from point 1 to point 2.
    dx_first = corners[:, 2] - corners[:, 0]
    dy_first = corners[:, 3] - corners[:, 1]
    # Edge from point 2 to point 3.
    dx_second = corners[:, 4] - corners[:, 2]
    dy_second = corners[:, 5] - corners[:, 3]
    first_sq = dy_first**2 + dx_first**2
    second_sq = dy_second**2 + dx_second**2
    return first_sq * second_sq
390+
391+
def _clamp_along_y_axis(
    bounding_boxes: torch.Tensor,
) -> torch.Tensor:
    """
    Replace boxes whose first point lies left of the y-axis (``x1 < 0``).

    Four candidate replacements are computed for every box and one of them is
    selected per box from the signs of the x coordinates of its four points.
    Boxes that match none of the conditions are returned unchanged.

    Args:
        bounding_boxes (torch.Tensor): Boxes whose elements can be viewed as rows of
            [x1, y1, x2, y2, x3, y3, x4, y4].
            NOTE(review): the conditions only fire when x1 < 0, so callers presumably
            order the points first so that point 1 has the lowest x - confirm.

    Returns:
        torch.Tensor: The adjusted boxes, with the dtype and shape of the input.
    """
    in_shape = bounding_boxes.shape
    in_dtype = bounding_boxes.dtype
    x1, y1, x2, y2, x3, y3, x4, y4 = bounding_boxes.reshape(-1, 8).unbind(-1)

    def as_boxes(*columns: torch.Tensor) -> torch.Tensor:
        # Assemble eight per-box columns into an (N, 8) tensor.
        return torch.cat([col.unsqueeze(1) for col in columns], dim=1)

    # Slope of the edge from point 1 to point 2.
    slope = (y2 - y1) / (x2 - x1)
    # Values at x = 0 of the lines through each point: slope `slope` for points 1 and 3,
    # slope `-1 / slope` (the perpendicular direction) for points 2 and 4.
    icpt1 = y1 - slope * x1
    icpt2 = y2 + x2 / slope
    icpt3 = y3 - slope * x3
    icpt4 = y4 + x4 / slope
    x_c = (icpt2 - icpt3) / 2 * slope / (1 + slope**2)
    zero = torch.zeros_like(icpt1)

    # Candidate A: point 1 moves to (0, icpt1); points 2 and 3 stay; point 4 follows.
    case_a = as_boxes(zero, icpt1, x2, y2, x3, y3, x3 - x2, y3 + icpt1 - y2)
    # Candidate B: point 1 moves to (0, icpt4); points 3 and 4 stay; point 2 follows.
    case_b = as_boxes(zero, icpt4, x2 - x1, y2 - y1 + icpt4, x3, y3, x4, y4)
    # Candidate C: only point 3 stays; the other three points are rebuilt around it.
    case_c = as_boxes(
        zero, (icpt2 + icpt3) / 2, x_c, -x_c / slope + icpt2, x3, y3, x_c, x_c * slope + icpt3
    )
    # Candidate D: the box collapses to all zeros.
    case_d = torch.zeros_like(case_c)

    p1_left = x1.lt(0)
    only_p1_left = p1_left & x2.ge(0) & x3.ge(0) & x4.ge(0)
    size_a = area(case_a)
    size_b = area(case_b)

    # When only point 1 is left of the axis, keep whichever of A / B is larger.
    cond_a = only_p1_left & (size_a > size_b)
    cond_a = cond_a | (p1_left & x2.ge(0) & x3.ge(0) & x4.le(0))
    cond_b = only_p1_left & (size_a <= size_b)
    cond_b = cond_b | (p1_left & x2.le(0) & x3.ge(0) & x4.ge(0))
    cond_c = p1_left & x2.le(0) & x3.ge(0) & x4.le(0)
    cond_d = p1_left & x2.le(0) & x3.le(0) & x4.le(0)

    # Applied in this order on purpose: where conditions overlap, the later one wins.
    adjusted = bounding_boxes
    for mask, candidate in (
        (cond_a, case_a),
        (cond_b, case_b),
        (cond_c, case_c),
        (cond_d, case_d),
    ):
        adjusted = torch.where(mask[:, None].expand(-1, 8), candidate, adjusted)
    return adjusted.to(in_dtype).reshape(in_shape)
440+
441+
def _clamp_rotated_bounding_boxes(
    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: tuple[int, int]
) -> torch.Tensor:
    """Clamp rotated bounding boxes by processing one canvas side per quarter turn.

    The boxes are converted to XYXYXYXY, then four times in a row: the points are
    ordered, the boxes are adjusted by ``_clamp_along_y_axis``, the ordering indices
    are applied once more, and all coordinates are rotated by 90 degrees. After the
    fourth rotation the coordinates are back in the original frame and the boxes
    are converted back to ``format``.

    Args:
        bounding_boxes (torch.Tensor): Boxes in ``format``. The input tensor itself is
            not modified; all work happens on a floating-point copy.
        format (BoundingBoxFormat): Format of ``bounding_boxes`` and of the result.
        canvas_size (tuple[int, int]): Canvas extents; ``canvas_size[1]`` is used as the
            extent along x. NOTE(review): presumably (height, width) - confirm.

    Returns:
        torch.Tensor: The clamped boxes with the shape and dtype of the input.
    """
    in_dtype = bounding_boxes.dtype
    in_shape = bounding_boxes.shape
    xyxyxyxy = tv_tensors.BoundingBoxFormat.XYXYXYXY

    # Work on a private floating-point copy so the in-place conversions below
    # operate on it rather than on the caller's tensor.
    if bounding_boxes.is_floating_point():
        work = bounding_boxes.clone()
    else:
        work = bounding_boxes.float()
    corners = convert_bounding_box_format(work, old_format=format, new_format=xyxyxyxy, inplace=True)
    corners = corners.reshape(-1, 8)

    extents = canvas_size
    for _ in range(4):
        order, corners = _order_bounding_boxes_points(corners)
        corners = _clamp_along_y_axis(corners)
        _, corners = _order_bounding_boxes_points(corners, order)
        # Quarter turn: (x, y) -> (y, extents[1] - x); snapshot both before overwriting.
        old_x = corners[:, ::2].clone()
        old_y = corners[:, 1::2].clone()
        corners[:, ::2] = old_y
        corners[:, 1::2] = extents[1] - old_x
        # The canvas is rotated as well, so its two extents swap roles.
        extents = (extents[1], extents[0])

    restored = convert_bounding_box_format(corners, old_format=xyxyxyxy, new_format=format, inplace=True)
    return restored.reshape(in_shape).to(in_dtype)
363471
364472
365473def clamp_bounding_boxes (
0 commit comments