@@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
194194    if  not  inplace :
195195        cxcywhr  =  cxcywhr .clone ()
196196
197-     dtype  =  cxcywhr .dtype 
198-     need_cast  =  not  cxcywhr .is_floating_point ()
199-     if  need_cast :
200-         cxcywhr  =  cxcywhr .float ()
201- 
202197    half_wh  =  cxcywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None  if  cxcywhr .is_floating_point () else  "floor" ).abs_ ()
203198    r_rad  =  cxcywhr [..., 4 ].mul (torch .pi ).div (180.0 )
204199    cos , sin  =  r_rad .cos (), r_rad .sin ()
@@ -207,22 +202,13 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
207202    # (cy + width / 2 * sin - height / 2 * cos) = y1 
208203    cxcywhr [..., 1 ].add_ (half_wh [..., 0 ].mul (sin )).sub_ (half_wh [..., 1 ].mul (cos ))
209204
210-     if  need_cast :
211-         cxcywhr .round_ ()
212-         cxcywhr  =  cxcywhr .to (dtype )
213- 
214205    return  cxcywhr 
215206
216207
217208def  _xywhr_to_cxcywhr (xywhr : torch .Tensor , inplace : bool ) ->  torch .Tensor :
218209    if  not  inplace :
219210        xywhr  =  xywhr .clone ()
220211
221-     dtype  =  xywhr .dtype 
222-     need_cast  =  not  xywhr .is_floating_point ()
223-     if  need_cast :
224-         xywhr  =  xywhr .float ()
225- 
226212    half_wh  =  xywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None  if  xywhr .is_floating_point () else  "floor" ).abs_ ()
227213    r_rad  =  xywhr [..., 4 ].mul (torch .pi ).div (180.0 )
228214    cos , sin  =  r_rad .cos (), r_rad .sin ()
@@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
231217    # (y1 - width / 2 * sin + height / 2 * cos) = cy 
232218    xywhr [..., 1 ].sub_ (half_wh [..., 0 ].mul (sin )).add_ (half_wh [..., 1 ].mul (cos ))
233219
234-     if  need_cast :
235-         xywhr .round_ ()
236-         xywhr  =  xywhr .to (dtype )
237- 
238220    return  xywhr 
239221
240222
@@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
243225    if  not  inplace :
244226        xywhr  =  xywhr .clone ()
245227
246-     dtype  =  xywhr .dtype 
247-     need_cast  =  not  xywhr .is_floating_point ()
248-     if  need_cast :
249-         xywhr  =  xywhr .float ()
250- 
251228    wh  =  xywhr [..., 2 :- 1 ]
252229    r_rad  =  xywhr [..., 4 ].mul (torch .pi ).div (180.0 )
253230    cos , sin  =  r_rad .cos (), r_rad .sin ()
@@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
265242    # y1 + h * cos = y4 
266243    xywhr [..., 7 ].add_ (wh [..., 1 ].mul (cos ))
267244
268-     if  need_cast :
269-         xywhr .round_ ()
270-         xywhr  =  xywhr .to (dtype )
271- 
272245    return  xywhr 
273246
274247
@@ -278,9 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
278251        xyxyxyxy  =  xyxyxyxy .clone ()
279252
280253    dtype  =  xyxyxyxy .dtype 
281-     need_cast  =  not  xyxyxyxy .is_floating_point ()
254+     acceptable_dtypes  =  [torch .float32 , torch .float64 ]  # Ensure consistency between CPU and GPU. 
255+     need_cast  =  dtype  not  in acceptable_dtypes 
282256    if  need_cast :
283-         xyxyxyxy  =  xyxyxyxy .float ()
257+         # Up-case to avoid overflow for square operations 
258+         xyxyxyxy  =  xyxyxyxy .to (torch .float32 )
284259
285260    r_rad  =  torch .atan2 (xyxyxyxy [..., 1 ].sub (xyxyxyxy [..., 3 ]), xyxyxyxy [..., 2 ].sub (xyxyxyxy [..., 0 ]))
286261    # x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4 
@@ -293,7 +268,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
293268    xyxyxyxy [..., 4 ] =  r_rad .div_ (torch .pi ).mul_ (180.0 )
294269
295270    if  need_cast :
296-         xyxyxyxy .round_ ()
297271        xyxyxyxy  =  xyxyxyxy .to (dtype )
298272
299273    return  xyxyxyxy [..., :5 ]
0 commit comments