@@ -6877,18 +6877,39 @@ def test_no_valid_input(self, query):
68776877 query (["blah" ])
68786878
68796879 @pytest .mark .parametrize (
6880- "boxes" , [tv_tensors .BoundingBoxes (torch .tensor ([[1 , 1 , 2 , 2 ]]), format = "XYXY" , canvas_size = (4 , 4 ))]
6880+ "boxes" , [
6881+ tv_tensors .BoundingBoxes (torch .tensor ([[1. , 1. , 2. , 2. ]]), format = "XYXY" , canvas_size = (4 , 4 )), # [boxes0]
6882+ tv_tensors .BoundingBoxes (torch .tensor ([[1. , 1. , 1. , 1. ]]), format = "XYWH" , canvas_size = (4 , 4 )), # [boxes1]
6883+ tv_tensors .BoundingBoxes (torch .tensor ([[1.5 , 1.5 , 1. , 1. ]]), format = "CXCYWH" , canvas_size = (4 , 4 )), # [boxes2]
6884+ tv_tensors .BoundingBoxes (torch .tensor ([[1.5 , 1.5 , 1. , 1. , 45 ]]), format = "CXCYWHR" , canvas_size = (4 , 4 )), # [boxes3]
6885+ tv_tensors .BoundingBoxes (torch .tensor ([[1. , 1. , 1. , 1. , 45. ]]), format = "XYWHR" , canvas_size = (4 , 4 )), # [boxes4]
6886+ tv_tensors .BoundingBoxes (torch .tensor ([[1. , 1. , 1. , 2. , 2. , 2. , 2. , 1. ]]), format = "XY" * 4 , canvas_size = (4 , 4 )), # [boxes5]
6887+ ]
68816888 )
68826889 def test_convert_bounding_boxes_to_points (self , boxes : tv_tensors .BoundingBoxes ):
6883- # TODO: this test can't handle rotated boxes yet
68846890 kp = F .convert_bounding_boxes_to_points (boxes )
6885- assert kp .shape == boxes .shape + ( 2 , )
6891+ assert kp .shape == ( boxes .shape [ 0 ], 4 , 2 )
68866892 assert kp .dtype == boxes .dtype
68876893 # kp is a list of A, B, C, D polygons.
6888- # If we use A | C, we should get back the XYXY format of bounding box
6889- reconverted = torch .cat ([kp [..., 0 , :], kp [..., 2 , :]], dim = - 1 )
6890- reconverted_bbox = F .convert_bounding_box_format (
6891- tv_tensors .BoundingBoxes (reconverted , format = tv_tensors .BoundingBoxFormat .XYXY , canvas_size = kp .canvas_size ),
6892- new_format = boxes .format ,
6893- )
6894- assert (reconverted_bbox == boxes ).all (), f"Invalid reconversion : { reconverted_bbox } "
6894+
6895+ if F ._meta .is_rotated_bounding_box_format (boxes .format ):
6896+ # In the rotated case
6897+ # If we convert to XYXYXYXY format, we should get what we want.
6898+ reconverted = kp .reshape (- 1 , 8 )
6899+ reconverted_bbox = F .convert_bounding_box_format (
6900+ tv_tensors .BoundingBoxes (reconverted , format = tv_tensors .BoundingBoxFormat .XYXYXYXY , canvas_size = kp .canvas_size ),
6901+ new_format = boxes .format
6902+ )
6903+ assert ((reconverted_bbox - boxes ).abs () < 1e-5 ).all (), ( # Rotational computations mean that we can't ensure exactitude.
6904+ f"Invalid reconversion :\n \t Got: { reconverted_bbox } \n \t From: { boxes } \n \t "
6905+ f"Diff: { reconverted_bbox - boxes } "
6906+ )
6907+ else :
6908+ # In the unrotated case
6909+ # If we use A | C, we should get back the XYXY format of bounding box
6910+ reconverted = torch .cat ([kp [..., 0 , :], kp [..., 2 , :]], dim = - 1 )
6911+ reconverted_bbox = F .convert_bounding_box_format (
6912+ tv_tensors .BoundingBoxes (reconverted , format = tv_tensors .BoundingBoxFormat .XYXY , canvas_size = kp .canvas_size ),
6913+ new_format = boxes .format ,
6914+ )
6915+ assert (reconverted_bbox == boxes ).all (), f"Invalid reconversion :\n \t Got: { reconverted_bbox } \n \t From: { boxes } "
0 commit comments