@@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
194194 if not inplace :
195195 cxcywhr = cxcywhr .clone ()
196196
197- dtype = cxcywhr .dtype
198- need_cast = not cxcywhr .is_floating_point ()
199- if need_cast :
200- cxcywhr = cxcywhr .float ()
201-
202197 half_wh = cxcywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None if cxcywhr .is_floating_point () else "floor" ).abs_ ()
203198 r_rad = cxcywhr [..., 4 ].mul (torch .pi ).div (180.0 )
204199 cos , sin = r_rad .cos (), r_rad .sin ()
@@ -207,22 +202,13 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
207202 # (cy + width / 2 * sin - height / 2 * cos) = y1
208203 cxcywhr [..., 1 ].add_ (half_wh [..., 0 ].mul (sin )).sub_ (half_wh [..., 1 ].mul (cos ))
209204
210- if need_cast :
211- cxcywhr .round_ ()
212- cxcywhr = cxcywhr .to (dtype )
213-
214205 return cxcywhr
215206
216207
217208def _xywhr_to_cxcywhr (xywhr : torch .Tensor , inplace : bool ) -> torch .Tensor :
218209 if not inplace :
219210 xywhr = xywhr .clone ()
220211
221- dtype = xywhr .dtype
222- need_cast = not xywhr .is_floating_point ()
223- if need_cast :
224- xywhr = xywhr .float ()
225-
226212 half_wh = xywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None if xywhr .is_floating_point () else "floor" ).abs_ ()
227213 r_rad = xywhr [..., 4 ].mul (torch .pi ).div (180.0 )
228214 cos , sin = r_rad .cos (), r_rad .sin ()
@@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
231217 # (y1 - width / 2 * sin + height / 2 * cos) = cy
232218 xywhr [..., 1 ].sub_ (half_wh [..., 0 ].mul (sin )).add_ (half_wh [..., 1 ].mul (cos ))
233219
234- if need_cast :
235- xywhr .round_ ()
236- xywhr = xywhr .to (dtype )
237-
238220 return xywhr
239221
240222
@@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
243225 if not inplace :
244226 xywhr = xywhr .clone ()
245227
246- dtype = xywhr .dtype
247- need_cast = not xywhr .is_floating_point ()
248- if need_cast :
249- xywhr = xywhr .float ()
250-
251228 wh = xywhr [..., 2 :- 1 ]
252229 r_rad = xywhr [..., 4 ].mul (torch .pi ).div (180.0 )
253230 cos , sin = r_rad .cos (), r_rad .sin ()
@@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
265242 # y1 + h * cos = y4
266243 xywhr [..., 7 ].add_ (wh [..., 1 ].mul (cos ))
267244
268- if need_cast :
269- xywhr .round_ ()
270- xywhr = xywhr .to (dtype )
271-
272245 return xywhr
273246
274247
@@ -278,7 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
278251 xyxyxyxy = xyxyxyxy .clone ()
279252
280253 dtype = xyxyxyxy .dtype
281- need_cast = not xyxyxyxy .is_floating_point ()
254+ acceptable_dtypes = [torch .float32 ] # Ensure consistency between CPU and GPU.
255+ need_cast = dtype not in acceptable_dtypes
256+ if need_cast :
257+ # Up-case to avoid overflow for square operations
258+ xyxyxyxy = xyxyxyxy .to (torch .float32 )
282259 if need_cast :
283260 xyxyxyxy = xyxyxyxy .float ()
284261
@@ -293,7 +270,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
293270 xyxyxyxy [..., 4 ] = r_rad .div_ (torch .pi ).mul_ (180.0 )
294271
295272 if need_cast :
296- xyxyxyxy .round_ ()
297273 xyxyxyxy = xyxyxyxy .to (dtype )
298274
299275 return xyxyxyxy [..., :5 ]
0 commit comments