@@ -104,16 +104,10 @@ def horizontal_flip_bounding_boxes(
104104 bounding_boxes [:, 0 ::2 ].sub_ (canvas_size [1 ]).neg_ ()
105105 bounding_boxes = bounding_boxes [:, [2 , 3 , 0 , 1 , 6 , 7 , 4 , 5 ]]
106106 elif format == tv_tensors .BoundingBoxFormat .XYWHR :
107-
108- dtype = bounding_boxes .dtype
109- if not torch .is_floating_point (bounding_boxes ):
110- # Casting to float to support cos and sin computations.
111- bounding_boxes = bounding_boxes .to (torch .float32 )
112107 angle_rad = bounding_boxes [:, 4 ].mul (torch .pi ).div (180 )
113108 bounding_boxes [:, 0 ].add_ (bounding_boxes [:, 2 ].mul (angle_rad .cos ())).sub_ (canvas_size [1 ]).neg_ ()
114109 bounding_boxes [:, 1 ].sub_ (bounding_boxes [:, 2 ].mul (angle_rad .sin ()))
115110 bounding_boxes [:, 4 ].neg_ ()
116- bounding_boxes = bounding_boxes .to (dtype )
117111 else : # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
118112 bounding_boxes [:, 0 ].sub_ (canvas_size [1 ]).neg_ ()
119113 bounding_boxes [:, 4 ].neg_ ()
@@ -192,15 +186,10 @@ def vertical_flip_bounding_boxes(
192186 bounding_boxes [:, 1 ::2 ].sub_ (canvas_size [0 ]).neg_ ()
193187 bounding_boxes = bounding_boxes [:, [2 , 3 , 0 , 1 , 6 , 7 , 4 , 5 ]]
194188 elif format == tv_tensors .BoundingBoxFormat .XYWHR :
195- dtype = bounding_boxes .dtype
196- if not torch .is_floating_point (bounding_boxes ):
197- # Casting to float to support cos and sin computations.
198- bounding_boxes = bounding_boxes .to (torch .float64 )
199189 angle_rad = bounding_boxes [:, 4 ].mul (torch .pi ).div (180 )
200190 bounding_boxes [:, 1 ].sub_ (bounding_boxes [:, 2 ].mul (angle_rad .sin ())).sub_ (canvas_size [0 ]).neg_ ()
201191 bounding_boxes [:, 0 ].add_ (bounding_boxes [:, 2 ].mul (angle_rad .cos ()))
202192 bounding_boxes [:, 4 ].neg_ ().add_ (180 )
203- bounding_boxes = bounding_boxes .to (dtype )
204193 else : # format == tv_tensors.BoundingBoxFormat.CXCYWHR:
205194 bounding_boxes [:, 1 ].sub_ (canvas_size [0 ]).neg_ ()
206195 bounding_boxes [:, 4 ].neg_ ().add_ (180 )
@@ -1102,9 +1091,8 @@ def _affine_bounding_boxes_with_expand(
11021091
11031092 original_shape = bounding_boxes .shape
11041093 dtype = bounding_boxes .dtype
1105- acceptable_dtypes = [torch .float64 ] # Ensure consistency between CPU and GPU.
1106- need_cast = dtype not in acceptable_dtypes
1107- bounding_boxes = bounding_boxes .to (torch .float64 ) if need_cast else bounding_boxes .clone ()
1094+ need_cast = not bounding_boxes .is_floating_point ()
1095+ bounding_boxes = bounding_boxes .float () if need_cast else bounding_boxes .clone ()
11081096 device = bounding_boxes .device
11091097 is_rotated = tv_tensors .is_rotated_bounding_format (format )
11101098 intermediate_format = tv_tensors .BoundingBoxFormat .XYXYXYXY if is_rotated else tv_tensors .BoundingBoxFormat .XYXY
0 commit comments