@@ -360,6 +360,16 @@ def _clamp_bounding_boxes(
360360 return out_boxes .to (in_dtype )
361361
362362
363+ def _clamp_rotated_bounding_boxes (
364+ bounding_boxes : torch .Tensor , format : BoundingBoxFormat , canvas_size : tuple [int , int ]
365+ ) -> torch .Tensor :
366+ # TODO: For now we are not clamping rotated bounding boxes.
367+ in_dtype = bounding_boxes .dtype
368+ out_boxes = bounding_boxes .clone () if bounding_boxes .is_floating_point () else bounding_boxes .float ()
369+
370+ return out_boxes .to (in_dtype )
371+
372+
363373def clamp_bounding_boxes (
364374 inpt : torch .Tensor ,
365375 format : Optional [BoundingBoxFormat ] = None ,
@@ -373,11 +383,21 @@ def clamp_bounding_boxes(
373383
374384 if format is None or canvas_size is None :
375385 raise ValueError ("For pure tensor inputs, `format` and `canvas_size` have to be passed." )
376- return _clamp_bounding_boxes (inpt , format = format , canvas_size = canvas_size )
386+ if tv_tensors .is_rotated_bounding_format (format ):
387+ return _clamp_rotated_bounding_boxes (inpt , format = format , canvas_size = canvas_size )
388+ else :
389+ return _clamp_bounding_boxes (inpt , format = format , canvas_size = canvas_size )
377390 elif isinstance (inpt , tv_tensors .BoundingBoxes ):
378391 if format is not None or canvas_size is not None :
379392 raise ValueError ("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed." )
380- output = _clamp_bounding_boxes (inpt .as_subclass (torch .Tensor ), format = inpt .format , canvas_size = inpt .canvas_size )
393+ if tv_tensors .is_rotated_bounding_format (inpt .format ):
394+ output = _clamp_rotated_bounding_boxes (
395+ inpt .as_subclass (torch .Tensor ), format = inpt .format , canvas_size = inpt .canvas_size
396+ )
397+ else :
398+ output = _clamp_bounding_boxes (
399+ inpt .as_subclass (torch .Tensor ), format = inpt .format , canvas_size = inpt .canvas_size
400+ )
381401 return tv_tensors .wrap (output , like = inpt )
382402 else :
383403 raise TypeError (
0 commit comments