 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_image, to_pil_image
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
 
 
 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -400,6 +400,12 @@ def make_image_pil(*args, **kwargs):
     return to_pil_image(make_image(*args, **kwargs))
 
 
+def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
+    y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device)
+    x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device)
+    return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size)
+
+
 def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
@@ -417,6 +423,13 @@ def sample_position(values, max_value):
         format = tv_tensors.BoundingBoxFormat[format]
 
     dtype = dtype or torch.float32
+    int_dtype = dtype in (
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    )
 
     h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
     y = sample_position(h, canvas_size[0])
@@ -443,20 +456,31 @@ def sample_position(values, max_value):
     elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
         r_rad = r * torch.pi / 180.0
         cos, sin = torch.cos(r_rad), torch.sin(r_rad)
-        x1, y1 = x, y
-        x3 = x1 + w * cos
-        y3 = y1 - w * sin
-        x2 = x3 + h * sin
-        y2 = y3 + h * cos
-        x4 = x1 + h * sin
-        y4 = y1 + h * cos
-        parts = (x1, y1, x3, y3, x2, y2, x4, y4)
+        x1 = torch.round(x) if int_dtype else x
+        y1 = torch.round(y) if int_dtype else y
+        x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
+        y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
+        x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
+        y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
+        x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
+        y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
+        parts = (x1, y1, x2, y2, x3, y3, x4, y4)
     else:
         raise ValueError(f"Format {format} is not supported")
-
-    return tv_tensors.BoundingBoxes(
-        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
-    )
+    out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
+    if tv_tensors.is_rotated_bounding_format(format):
+        # Rotated bounding boxes are not guaranteed to stay within the canvas by design,
+        # so we clamp them. We also keep a buffer // 2 margin to each canvas edge to
+        # avoid numerical issues during testing.
+        buffer = 4
+        out_boxes = clamp_bounding_boxes(
+            out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
+        )
+        if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
+            out_boxes[:, :2] += buffer // 2
+        elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
+            out_boxes[:, :] += buffer // 2
+    return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)
 
 
 def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
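
For context, a minimal usage sketch of the new make_keypoints helper; the call values below are made up, and it assumes the test helpers in this file are importable:

    kp = make_keypoints(canvas_size=(32, 64), num_points=3)
    # kp is a tv_tensors.KeyPoints of shape (3, 2) with (x, y) columns,
    # x drawn from [0, 64) and y drawn from [0, 32)
    print(kp.canvas_size)  # (32, 64)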
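The XYXYXYXY branch walks the rectangle edge by edge from the (x1, y1) anchor, so the corners now come out in sequential order around the box. A standalone sketch of the same trigonometry with made-up numbers, using only the standard library:

    import math

    # hypothetical XYWHR box: anchor (10, 20), width 8, height 4, rotated 30 degrees
    x1, y1, w, h, r = 10.0, 20.0, 8.0, 4.0, 30.0
    rad = math.radians(r)
    cos, sin = math.cos(rad), math.sin(rad)

    x2, y2 = x1 + w * cos, y1 - w * sin  # move along the width edge
    x3, y3 = x2 + h * sin, y2 + h * cos  # then along the height edge
    x4, y4 = x1 + h * sin, y1 + h * cos  # height edge taken from the anchor

    print([(x1, y1), (x2, y2), (x3, y3), (x4, y4)])
    # corners in the sequential XYXYXYXY order used by `parts` above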