@@ -43,6 +43,72 @@ def test_bbox_instance(data, format):
4343 assert bboxes .format == format
4444
4545
46+ @pytest .mark .parametrize (
47+ "format" ,
48+ [
49+ "XYXY" ,
50+ "XYWH" ,
51+ "CXCYWH" ,
52+ "XYXYXYXY" ,
53+ "XYWHR" ,
54+ "CXCYWHR" ,
55+ tv_tensors .BoundingBoxFormat .XYXY ,
56+ tv_tensors .BoundingBoxFormat .XYWH ,
57+ tv_tensors .BoundingBoxFormat .CXCYWH ,
58+ tv_tensors .BoundingBoxFormat .XYXYXYXY ,
59+ tv_tensors .BoundingBoxFormat .XYWHR ,
60+ tv_tensors .BoundingBoxFormat .CXCYWHR ,
61+ ],
62+ )
63+ def test_bbox_format (format ):
64+ if isinstance (format , str ):
65+ format = tv_tensors .BoundingBoxFormat [(format .upper ())]
66+ if format == tv_tensors .BoundingBoxFormat .XYXYXYXY :
67+ assert tv_tensors .is_rotated_bounding_format (format ) is True
68+ elif format == tv_tensors .BoundingBoxFormat .XYWHR :
69+ assert tv_tensors .is_rotated_bounding_format (format ) is True
70+ elif format == tv_tensors .BoundingBoxFormat .CXCYWHR :
71+ assert tv_tensors .is_rotated_bounding_format (format ) is True
72+ else :
73+ assert tv_tensors .is_rotated_bounding_format (format ) is False
74+
75+
76+ @pytest .mark .parametrize (
77+ "format" ,
78+ [
79+ "XYXY" ,
80+ "XYWH" ,
81+ "CXCYWH" ,
82+ "XYXYXYXY" ,
83+ "XYWHR" ,
84+ "CXCYWHR" ,
85+ tv_tensors .BoundingBoxFormat .XYXY ,
86+ tv_tensors .BoundingBoxFormat .XYWH ,
87+ tv_tensors .BoundingBoxFormat .CXCYWH ,
88+ tv_tensors .BoundingBoxFormat .XYXYXYXY ,
89+ tv_tensors .BoundingBoxFormat .XYWHR ,
90+ tv_tensors .BoundingBoxFormat .CXCYWHR ,
91+ ],
92+ )
93+ def test_bbox_format_scripted (format ):
94+ obj = tv_tensors .is_rotated_bounding_format
95+ try :
96+ fn = torch .jit .script (obj )
97+ except Exception as error :
98+ name = getattr (obj , "__name__" , obj .__class__ .__name__ )
99+ raise AssertionError (f"Trying to `torch.jit.script` `{ name } ` raised the error above." ) from error
100+ if isinstance (format , str ):
101+ format = tv_tensors .BoundingBoxFormat [(format .upper ())]
102+ if format == tv_tensors .BoundingBoxFormat .XYXYXYXY :
103+ assert fn (format ) is True
104+ elif format == tv_tensors .BoundingBoxFormat .XYWHR :
105+ assert fn (format ) is True
106+ elif format == tv_tensors .BoundingBoxFormat .CXCYWHR :
107+ assert fn (format ) is True
108+ else :
109+ assert fn (format ) is False
110+
111+
46112def test_bbox_dim_error ():
47113 data_3d = [[[1 , 2 , 3 , 4 ]]]
48114 with pytest .raises (ValueError , match = "Expected a 1D or 2D tensor, got 3D" ):
0 commit comments