@@ -409,23 +409,17 @@ def _order_bounding_boxes_points(
409409 if indices is None :
410410 output_xyxyxyxy = bounding_boxes .reshape (- 1 , 8 )
411411 x , y = output_xyxyxyxy [..., 0 ::2 ], output_xyxyxyxy [..., 1 ::2 ]
412- y_max = torch .max (y , dim = 1 , keepdim = True )[0 ]
413- _ , x1 = (( y_max - y ) / y_max + (x + 1 ) * 100 ).min (dim = 1 )
412+ y_max = torch .max (y . abs () , dim = 1 , keepdim = True )[0 ]
413+ _ , x1 = (y / y_max + (x + 1 ) * 100 ).min (dim = 1 )
414414 indices = torch .ones_like (output_xyxyxyxy )
415415 indices [..., 0 ] = x1 .mul (2 )
416416 indices .cumsum_ (1 ).remainder_ (8 )
417417 return indices , bounding_boxes .gather (1 , indices .to (torch .int64 ))
418418
419419
420- def _area (box : torch .Tensor ) -> torch .Tensor :
421- x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 = box .reshape (- 1 , 8 ).unbind (- 1 )
422- w = torch .sqrt ((y2 - y1 ) ** 2 + (x2 - x1 ) ** 2 )
423- h = torch .sqrt ((y3 - y2 ) ** 2 + (x3 - x2 ) ** 2 )
424- return w * h
425-
426-
427420def _clamp_along_y_axis (
428421 bounding_boxes : torch .Tensor ,
422+ canvas_size : tuple [int , int ],
429423) -> torch .Tensor :
430424 """
431425 Adjusts bounding boxes along the y-axis based on specific conditions.
@@ -448,29 +442,33 @@ def _clamp_along_y_axis(
448442 b2 = y2 + x2 / a
449443 b3 = y3 - a * x3
450444 b4 = y4 + x4 / a
451- b23 = (b2 - b3 ) / 2 * a / (1 + a ** 2 )
452- z = torch .zeros_like (b1 )
453- case_a = torch .cat ([x .unsqueeze (1 ) for x in [z , b1 , x2 , y2 , x3 , y3 , x3 - x2 , y3 + b1 - y2 ]], dim = 1 )
454- case_b = torch .cat ([x .unsqueeze (1 ) for x in [z , b4 , x2 - x1 , y2 - y1 + b4 , x3 , y3 , x4 , y4 ]], dim = 1 )
455- case_c = torch .cat (
456- [x .unsqueeze (1 ) for x in [z , (b2 + b3 ) / 2 , b23 , - b23 / a + b2 , x3 , y3 , b23 , b23 * a + b3 ]], dim = 1
445+ c = a / (1 + a ** 2 )
446+ b1 = b2 .clamp (0 ).clamp (b1 , b3 )
447+ b4 = b3 .clamp (max = canvas_size [0 ]).clamp (b2 , b4 )
448+ case_a = torch .stack (
449+ (
450+ (b4 - b1 ) * c ,
451+ (b4 - b1 ) * c * a + b1 ,
452+ (b2 - b1 ) * c ,
453+ (b1 - b2 ) * c / a + b2 ,
454+ x3 ,
455+ y3 ,
456+ (b4 - b3 ) * c ,
457+ (b3 - b4 ) * c / a + b4 ,
458+ ),
459+ dim = - 1 ,
457460 )
458- case_d = torch .zeros_like (case_c )
459- case_e = torch .cat ([x .unsqueeze (1 ) for x in [x1 .clamp (0 ), y1 , x2 .clamp (0 ), y2 , x3 , y3 , x4 , y4 ]], dim = 1 )
460-
461- cond_a = (x1 < 0 ).logical_and (x2 >= 0 ).logical_and (x3 >= 0 ).logical_and (x4 >= 0 )
462- cond_a = cond_a .logical_and (_area (case_a ) > _area (case_b ))
463- cond_a = cond_a .logical_or ((x1 < 0 ).logical_and (x2 >= 0 ).logical_and (x3 >= 0 ).logical_and (x4 <= 0 ))
464- cond_b = (x1 < 0 ).logical_and (x2 >= 0 ).logical_and (x3 >= 0 ).logical_and (x4 >= 0 )
465- cond_b = cond_b .logical_and (_area (case_a ) <= _area (case_b ))
466- cond_b = cond_b .logical_or ((x1 < 0 ).logical_and (x2 <= 0 ).logical_and (x3 >= 0 ).logical_and (x4 >= 0 ))
467- cond_c = (x1 < 0 ).logical_and (x2 <= 0 ).logical_and (x3 >= 0 ).logical_and (x4 <= 0 )
468- cond_d = (x1 < 0 ).logical_and (x2 <= 0 ).logical_and (x3 <= 0 ).logical_and (x4 <= 0 )
469- cond_e = x1 .isclose (x2 )
470-
461+ case_b = bounding_boxes .clone ()
462+ case_b [..., 0 ].clamp_ (0 )
463+ case_b [..., 6 ].clamp_ (0 )
464+ case_c = torch .zeros_like (case_b )
465+
466+ cond_a = x1 < 0
467+ cond_b = y1 .isclose (y2 , rtol = 1e-05 , atol = 1e-05 )
468+ cond_c = (x1 <= 0 ).logical_and (x2 <= 0 ).logical_and (x3 <= 0 ).logical_and (x4 <= 0 )
471469 for cond , case in zip (
472- [cond_a , cond_b , cond_c , cond_d , cond_e ],
473- [case_a , case_b , case_c , case_d , case_e ],
470+ [cond_a , cond_b , cond_c ],
471+ [case_a , case_b , case_c ],
474472 ):
475473 bounding_boxes = torch .where (cond .unsqueeze (1 ).repeat (1 , 8 ), case .reshape (- 1 , 8 ), bounding_boxes )
476474 return bounding_boxes .to (original_dtype ).reshape (original_shape )
@@ -512,7 +510,7 @@ def _clamp_rotated_bounding_boxes(
512510
513511 for _ in range (4 ): # Iterate over the 4 vertices.
514512 indices , out_boxes = _order_bounding_boxes_points (out_boxes )
515- out_boxes = _clamp_along_y_axis (out_boxes )
513+ out_boxes = _clamp_along_y_axis (out_boxes , canvas_size )
516514 _ , out_boxes = _order_bounding_boxes_points (out_boxes , indices )
517515 # rotate 90 degrees counter clock wise
518516 out_boxes [:, ::2 ], out_boxes [:, 1 ::2 ] = (
0 commit comments