@@ -194,11 +194,6 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
194
194
if not inplace :
195
195
cxcywhr = cxcywhr .clone ()
196
196
197
- dtype = cxcywhr .dtype
198
- need_cast = not cxcywhr .is_floating_point ()
199
- if need_cast :
200
- cxcywhr = cxcywhr .float ()
201
-
202
197
half_wh = cxcywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None if cxcywhr .is_floating_point () else "floor" ).abs_ ()
203
198
r_rad = cxcywhr [..., 4 ].mul (torch .pi ).div (180.0 )
204
199
cos , sin = r_rad .cos (), r_rad .sin ()
@@ -207,22 +202,13 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
207
202
# (cy + width / 2 * sin - height / 2 * cos) = y1
208
203
cxcywhr [..., 1 ].add_ (half_wh [..., 0 ].mul (sin )).sub_ (half_wh [..., 1 ].mul (cos ))
209
204
210
- if need_cast :
211
- cxcywhr .round_ ()
212
- cxcywhr = cxcywhr .to (dtype )
213
-
214
205
return cxcywhr
215
206
216
207
217
208
def _xywhr_to_cxcywhr (xywhr : torch .Tensor , inplace : bool ) -> torch .Tensor :
218
209
if not inplace :
219
210
xywhr = xywhr .clone ()
220
211
221
- dtype = xywhr .dtype
222
- need_cast = not xywhr .is_floating_point ()
223
- if need_cast :
224
- xywhr = xywhr .float ()
225
-
226
212
half_wh = xywhr [..., 2 :- 1 ].div (- 2 , rounding_mode = None if xywhr .is_floating_point () else "floor" ).abs_ ()
227
213
r_rad = xywhr [..., 4 ].mul (torch .pi ).div (180.0 )
228
214
cos , sin = r_rad .cos (), r_rad .sin ()
@@ -231,10 +217,6 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
231
217
# (y1 - width / 2 * sin + height / 2 * cos) = cy
232
218
xywhr [..., 1 ].sub_ (half_wh [..., 0 ].mul (sin )).add_ (half_wh [..., 1 ].mul (cos ))
233
219
234
- if need_cast :
235
- xywhr .round_ ()
236
- xywhr = xywhr .to (dtype )
237
-
238
220
return xywhr
239
221
240
222
@@ -243,11 +225,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
243
225
if not inplace :
244
226
xywhr = xywhr .clone ()
245
227
246
- dtype = xywhr .dtype
247
- need_cast = not xywhr .is_floating_point ()
248
- if need_cast :
249
- xywhr = xywhr .float ()
250
-
251
228
wh = xywhr [..., 2 :- 1 ]
252
229
r_rad = xywhr [..., 4 ].mul (torch .pi ).div (180.0 )
253
230
cos , sin = r_rad .cos (), r_rad .sin ()
@@ -265,10 +242,6 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
265
242
# y1 + h * cos = y4
266
243
xywhr [..., 7 ].add_ (wh [..., 1 ].mul (cos ))
267
244
268
- if need_cast :
269
- xywhr .round_ ()
270
- xywhr = xywhr .to (dtype )
271
-
272
245
return xywhr
273
246
274
247
@@ -278,9 +251,11 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
278
251
xyxyxyxy = xyxyxyxy .clone ()
279
252
280
253
dtype = xyxyxyxy .dtype
281
- need_cast = not xyxyxyxy .is_floating_point ()
254
+ acceptable_dtypes = [torch .float32 , torch .float64 ] # Ensure consistency between CPU and GPU.
255
+ need_cast = dtype not in acceptable_dtypes
282
256
if need_cast :
283
- xyxyxyxy = xyxyxyxy .float ()
257
+ # Up-case to avoid overflow for square operations
258
+ xyxyxyxy = xyxyxyxy .to (torch .float32 )
284
259
285
260
r_rad = torch .atan2 (xyxyxyxy [..., 1 ].sub (xyxyxyxy [..., 3 ]), xyxyxyxy [..., 2 ].sub (xyxyxyxy [..., 0 ]))
286
261
# x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4
@@ -293,7 +268,6 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
293
268
xyxyxyxy [..., 4 ] = r_rad .div_ (torch .pi ).mul_ (180.0 )
294
269
295
270
if need_cast :
296
- xyxyxyxy .round_ ()
297
271
xyxyxyxy = xyxyxyxy .to (dtype )
298
272
299
273
return xyxyxyxy [..., :5 ]
0 commit comments