@@ -381,29 +381,32 @@ def _resize_mask_dispatch(
381381 return tv_tensors .wrap (output , like = inpt )
382382
383383
384- def _parallelogram_to_bounding_boxes (parallelogram : torch .Tensor , inplace : bool = False ) -> torch .Tensor :
384+ def _parallelogram_to_bounding_boxes (parallelogram : torch .Tensor ) -> torch .Tensor :
385385 """
386386 Convert a parallelogram to a rectangle while keeping the points (x1, y1) and (x3, y3) unchanged.
387387
388388 This function transforms a parallelogram represented by 8 coordinates (4 points) into a rectangle.
389389 The first point (x1, y1) and the third point (x3, y3) of the parallelogram remain fixed,
390390 while the second and fourth points are adjusted to form a proper rectangle.
391391
392+ Note:
393+ This function is not applied in-place and will return a copy of the input tensor.
394+
392395 Args:
393396 parallelogram (torch.Tensor): Tensor of shape (..., 8) containing coordinates of parallelograms.
394397 Format is [x1, y1, x2, y2, x3, y3, x4, y4].
395- inplace (bool, optional): If True, performs operation in-place. Default is False.
396398
397399 Returns:
398400 torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
399401 The output maintains the same dtype as the input.
400402 """
401- if not inplace :
402- parallelogram = parallelogram .clone ()
403-
404403 dtype = parallelogram .dtype
405- if not torch .is_floating_point (parallelogram ):
406- parallelogram = parallelogram .float ()
404+ int_dtype = dtype in (torch .uint8 ,
405+ torch .int8 ,
406+ torch .int16 ,
407+ torch .int32 ,
408+ torch .int64 ,
409+ )
407410
408411 # Calculate diagonal vector from first to third point
409412 dx = parallelogram [..., 4 ] - parallelogram [..., 0 ]
@@ -417,21 +420,28 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor, inplace: bool
417420 # Calculate width using the angle between diagonal and rotation
418421 w = diag * torch .abs (torch .sin (torch .atan2 (dx , dy ) - r_rad ))
419422
423+ delta_x = torch .round (w * cos ).to (dtype ) if int_dtype else w * cos
424+ detla_y = torch .round (w * sin ).to (dtype ) if int_dtype else w * sin
425+
420426 # Update coordinates to form a rectangle
421- parallelogram [..., 2 ] = parallelogram [..., 0 ] + w * cos
422- parallelogram [..., 3 ] = parallelogram [..., 1 ] - w * sin
423- parallelogram [..., 6 ] = parallelogram [..., 4 ] - w * cos
424- parallelogram [..., 7 ] = parallelogram [..., 5 ] + w * sin
425- return parallelogram . to ( dtype )
427+ parallelogram [..., 2 ] = parallelogram [..., 0 ] + delta_x
428+ parallelogram [..., 3 ] = parallelogram [..., 1 ] - detla_y
429+ parallelogram [..., 6 ] = parallelogram [..., 4 ] - delta_x
430+ parallelogram [..., 7 ] = parallelogram [..., 5 ] + detla_y
431+ return parallelogram
426432
427433
428434def resize_bounding_boxes (
429435 bounding_boxes : torch .Tensor ,
430- format : tv_tensors .BoundingBoxFormat ,
431436 canvas_size : tuple [int , int ],
432437 size : Optional [list [int ]],
433438 max_size : Optional [int ] = None ,
439+ format : tv_tensors .BoundingBoxFormat = tv_tensors .BoundingBoxFormat .XYXY ,
434440) -> tuple [torch .Tensor , tuple [int , int ]]:
441+ # The default format is `tv_tensors.BoundingBoxFormat.XYXY`
442+ # to ensure backward compatibility:
443+ # before the introduction of the rotated bounding box formats,
444+ # this function did not receive a `format` parameter as input.
435445 old_height , old_width = canvas_size
436446 new_height , new_width = _compute_resized_output_size (canvas_size , size = size , max_size = max_size )
437447
@@ -893,12 +903,9 @@ def _affine_bounding_boxes_with_expand(
893903 bounding_boxes = bounding_boxes .clone () if bounding_boxes .is_floating_point () else bounding_boxes .float ()
894904 dtype = bounding_boxes .dtype
895905 device = bounding_boxes .device
896- intermediate_format = (
897- tv_tensors .BoundingBoxFormat .XYXYXYXY
898- if tv_tensors .is_rotated_bounding_format (format )
899- else tv_tensors .BoundingBoxFormat .XYXY
900- )
901- intermediate_shape = 8 if tv_tensors .is_rotated_bounding_format (format ) else 4
906+ is_rotated = tv_tensors .is_rotated_bounding_format (format )
907+ intermediate_format = tv_tensors .BoundingBoxFormat .XYXYXYXY if is_rotated else tv_tensors .BoundingBoxFormat .XYXY
908+ intermediate_shape = 8 if is_rotated else 4
902909 bounding_boxes = (
903910 convert_bounding_box_format (bounding_boxes , old_format = format , new_format = intermediate_format , inplace = True )
904911 ).reshape (- 1 , intermediate_shape )
@@ -925,7 +932,7 @@ def _affine_bounding_boxes_with_expand(
925932 # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
926933 # Single point structure is similar to
927934 # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
928- if tv_tensors . is_rotated_bounding_format ( format ) :
935+ if is_rotated :
929936 points = bounding_boxes .reshape (- 1 , 2 )
930937 else :
931938 points = bounding_boxes [:, [[0 , 1 ], [2 , 1 ], [2 , 3 ], [0 , 3 ]]].reshape (- 1 , 2 )
@@ -934,7 +941,7 @@ def _affine_bounding_boxes_with_expand(
934941 transformed_points = torch .matmul (points , transposed_affine_matrix )
935942 # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
936943 # and compute bounding box from 4 transformed points:
937- if tv_tensors . is_rotated_bounding_format ( format ) :
944+ if is_rotated :
938945 transformed_points = transformed_points .reshape (- 1 , 8 )
939946 out_bboxes = _parallelogram_to_bounding_boxes (transformed_points )
940947 else :
@@ -1557,6 +1564,9 @@ def crop_bounding_boxes(
15571564 bounding_boxes = bounding_boxes - torch .tensor (sub , dtype = bounding_boxes .dtype , device = bounding_boxes .device )
15581565 canvas_size = (height , width )
15591566
1567+ if format == tv_tensors .BoundingBoxFormat .XYXYXYXY :
1568+ bounding_boxes = _parallelogram_to_bounding_boxes (bounding_boxes )
1569+
15601570 return clamp_bounding_boxes (bounding_boxes , format = format , canvas_size = canvas_size ), canvas_size
15611571
15621572
0 commit comments