@@ -181,7 +181,8 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
181181 cxcywhr = cxcywhr .clone ()
182182
183183 dtype = cxcywhr .dtype
184- if not cxcywhr .is_floating_point ():
184+ need_cast = not cxcywhr .is_floating_point ()
185+ if need_cast :
185186 cxcywhr = cxcywhr .float ()
186187
187188 half_wh = cxcywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None if cxcywhr .is_floating_point () else "floor" ).abs_ ()
@@ -192,15 +193,20 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
192193 # (cy + width / 2 * sin - height / 2 * cos) = y1
193194 cxcywhr [..., 1 ].add_ (half_wh [..., 0 ].mul (sin )).sub_ (half_wh [..., 1 ].mul (cos ))
194195
195- return cxcywhr .to (dtype )
196+ if need_cast :
197+ cxcywhr .round_ ()
198+ cxcywhr = cxcywhr .to (dtype )
199+
200+ return cxcywhr
196201
197202
198203def _xywhr_to_cxcywhr (xywhr : torch .Tensor , inplace : bool ) -> torch .Tensor :
199204 if not inplace :
200205 xywhr = xywhr .clone ()
201206
202207 dtype = xywhr .dtype
203- if not xywhr .is_floating_point ():
208+ need_cast = not xywhr .is_floating_point ()
209+ if need_cast :
204210 xywhr = xywhr .float ()
205211
206212 half_wh = xywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None if xywhr .is_floating_point () else "floor" ).abs_ ()
@@ -211,7 +217,11 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
211217 # (y1 - width / 2 * sin + height / 2 * cos) = cy
212218 xywhr [..., 1 ].sub_ (half_wh [..., 0 ].mul (sin )).add_ (half_wh [..., 1 ].mul (cos ))
213219
214- return xywhr .to (dtype )
220+ if need_cast :
221+ xywhr .round_ ()
222+ xywhr = xywhr .to (dtype )
223+
224+ return xywhr
215225
216226
217227def _xywhr_to_xyxyxyxy (xywhr : torch .Tensor , inplace : bool ) -> torch .Tensor :
@@ -220,7 +230,8 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
220230 xywhr = xywhr .clone ()
221231
222232 dtype = xywhr .dtype
223- if not xywhr .is_floating_point ():
233+ need_cast = not xywhr .is_floating_point ()
234+ if need_cast :
224235 xywhr = xywhr .float ()
225236
226237 wh = xywhr [..., 2 :- 1 ]
@@ -239,7 +250,12 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
239250 xywhr [..., 6 ].add_ (wh [..., 1 ].mul (sin ))
240251 # y1 + h * cos = y4
241252 xywhr [..., 7 ].add_ (wh [..., 1 ].mul (cos ))
242- return xywhr .to (dtype )
253+
254+ if need_cast :
255+ xywhr .round_ ()
256+ xywhr = xywhr .to (dtype )
257+
258+ return xywhr
243259
244260
245261def _xyxyxyxy_to_xywhr (xyxyxyxy : torch .Tensor , inplace : bool ) -> torch .Tensor :
@@ -248,7 +264,8 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
248264 xyxyxyxy = xyxyxyxy .clone ()
249265
250266 dtype = xyxyxyxy .dtype
251- if not xyxyxyxy .is_floating_point ():
267+ need_cast = not xyxyxyxy .is_floating_point ()
268+ if need_cast :
252269 xyxyxyxy = xyxyxyxy .float ()
253270
254271 r_rad = torch .atan2 (xyxyxyxy [..., 1 ].sub (xyxyxyxy [..., 3 ]), xyxyxyxy [..., 2 ].sub (xyxyxyxy [..., 0 ]))
@@ -260,7 +277,12 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
260277 # sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) = h
261278 xyxyxyxy [..., 3 ] = xyxyxyxy [..., 4 ].pow (2 ).add (xyxyxyxy [..., 5 ].pow (2 )).sqrt ()
262279 xyxyxyxy [..., 4 ] = r_rad .div_ (torch .pi ).mul_ (180.0 )
263- return xyxyxyxy [..., :5 ].to (dtype )
280+
281+ if need_cast :
282+ xyxyxyxy .round_ ()
283+ xyxyxyxy = xyxyxyxy .to (dtype )
284+
285+ return xyxyxyxy [..., :5 ]
264286
265287
266288def _convert_bounding_box_format (
@@ -423,14 +445,14 @@ def _clamp_along_y_axis(
423445 case_d = torch .zeros_like (case_c )
424446 case_e = torch .cat ([x .unsqueeze (1 ) for x in [x1 .clamp (0 ), y1 , x2 .clamp (0 ), y2 , x3 , y3 , x4 , y4 ]], dim = 1 )
425447
426- cond_a = x1 . lt ( 0 ).logical_and (x2 . ge ( 0 )) .logical_and (x3 . ge ( 0 )) .logical_and (x4 . ge ( 0 ) )
448+ cond_a = ( x1 < 0 ).logical_and (x2 >= 0 ) .logical_and (x3 >= 0 ) .logical_and (x4 >= 0 )
427449 cond_a = cond_a .logical_and (_area (case_a ) > _area (case_b ))
428- cond_a = cond_a .logical_or (x1 . lt ( 0 ).logical_and (x2 . ge ( 0 )) .logical_and (x3 . ge ( 0 )) .logical_and (x4 . le ( 0 ) ))
429- cond_b = x1 . lt ( 0 ).logical_and (x2 . ge ( 0 )) .logical_and (x3 . ge ( 0 )) .logical_and (x4 . ge ( 0 ) )
450+ cond_a = cond_a .logical_or (( x1 < 0 ).logical_and (x2 >= 0 ) .logical_and (x3 >= 0 ) .logical_and (x4 <= 0 ))
451+ cond_b = ( x1 < 0 ).logical_and (x2 >= 0 ) .logical_and (x3 >= 0 ) .logical_and (x4 >= 0 )
430452 cond_b = cond_b .logical_and (_area (case_a ) <= _area (case_b ))
431- cond_b = cond_b .logical_or (x1 . lt ( 0 ).logical_and (x2 . le ( 0 )) .logical_and (x3 . ge ( 0 )) .logical_and (x4 . ge ( 0 ) ))
432- cond_c = x1 . lt ( 0 ).logical_and (x2 . le ( 0 )) .logical_and (x3 . ge ( 0 )) .logical_and (x4 . le ( 0 ) )
433- cond_d = x1 . lt ( 0 ).logical_and (x2 . le ( 0 )) .logical_and (x3 . le ( 0 )) .logical_and (x4 . le ( 0 ) )
453+ cond_b = cond_b .logical_or (( x1 < 0 ).logical_and (x2 <= 0 ) .logical_and (x3 >= 0 ) .logical_and (x4 >= 0 ))
454+ cond_c = ( x1 < 0 ).logical_and (x2 <= 0 ) .logical_and (x3 >= 0 ) .logical_and (x4 <= 0 )
455+ cond_d = ( x1 < 0 ).logical_and (x2 <= 0 ) .logical_and (x3 <= 0 ) .logical_and (x4 <= 0 )
434456 cond_e = x1 .isclose (x2 )
435457
436458 for cond , case in zip (
@@ -465,15 +487,17 @@ def _clamp_rotated_bounding_boxes(
465487 torch.Tensor: Clamped bounding boxes in the original format and shape
466488 """
467489 original_shape = bounding_boxes .shape
468- original_dtype = bounding_boxes .dtype
469- bounding_boxes = bounding_boxes .clone () if bounding_boxes .is_floating_point () else bounding_boxes .float ()
490+ dtype = bounding_boxes .dtype
491+ acceptable_dtypes = [torch .float64 ] # Ensure consistency between CPU and GPU.
492+ need_cast = dtype not in acceptable_dtypes
493+ bounding_boxes = bounding_boxes .to (torch .float64 ) if need_cast else bounding_boxes .clone ()
470494 out_boxes = (
471495 convert_bounding_box_format (
472496 bounding_boxes , old_format = format , new_format = tv_tensors .BoundingBoxFormat .XYXYXYXY , inplace = True
473497 )
474498 ).reshape (- 1 , 8 )
475499
476- for _ in range (4 ):
500+ for _ in range (4 ): # Iterate over the 4 vertices.
477501 indices , out_boxes = _order_bounding_boxes_points (out_boxes )
478502 out_boxes = _clamp_along_y_axis (out_boxes )
479503 _ , out_boxes = _order_bounding_boxes_points (out_boxes , indices )
@@ -488,7 +512,10 @@ def _clamp_rotated_bounding_boxes(
488512 out_boxes , old_format = tv_tensors .BoundingBoxFormat .XYXYXYXY , new_format = format , inplace = True
489513 ).reshape (original_shape )
490514
491- out_boxes = out_boxes .to (original_dtype )
515+ if need_cast :
516+ if dtype in (torch .uint8 , torch .int8 , torch .int16 , torch .int32 , torch .int64 ):
517+ out_boxes .round_ ()
518+ out_boxes = out_boxes .to (dtype )
492519 return out_boxes
493520
494521
0 commit comments