@@ -488,6 +488,84 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
488488        return  self ._call_kernel (F .pad , inpt , padding = self .padding , fill = fill , padding_mode = self .padding_mode )  # type: ignore[arg-type] 
489489
490490
class PadToSquare(Transform):
    """Pad a non-square input to make it square by padding the shorter side to match the longer side.

    The padding is split as evenly as possible between the two opposite borders of the shorter
    dimension. When the required padding is odd, the extra pixel goes to the bottom (when padding
    the height) or to the right (when padding the width). Inputs that are already square receive
    zero padding on every side.

    Args:
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is "constant".

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        ValueError: If ``padding_mode`` is not one of ``"constant"``, ``"edge"``, ``"reflect"``
            or ``"symmetric"``.

    Example:
        >>> import torch
        >>> from torchvision.transforms.v2 import PadToSquare
        >>> rectangular_image = torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8)
        >>> transform = PadToSquare(padding_mode='constant', fill=0)
        >>> square_image = transform(rectangular_image)
        >>> print(square_image.size())
        torch.Size([3, 224, 224])
    """

    def __init__(
        self,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ):
        super().__init__()

        _check_padding_mode_arg(padding_mode)

        # Explicit guard kept alongside the shared helper so this class always reports its own
        # error message for an unsupported mode.
        if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
            raise ValueError("`padding_mode` must be one of 'constant', 'edge', 'reflect' or 'symmetric'.")
        self.padding_mode = padding_mode
        self.fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Compute the per-side padding that makes the queried input size square.

        Args:
            flat_inputs: Flattened inputs of the current call, passed to ``query_size`` to obtain
                the ``(height, width)`` to square up.

        Returns:
            A dict with a single ``"padding"`` entry laid out as ``[left, top, right, bottom]``.
        """
        orig_height, orig_width = query_size(flat_inputs)

        # The longer side defines the edge length of the resulting square.
        target_size = max(orig_height, orig_width)

        # At most one of these deficits is non-zero, because the longer side already equals
        # ``target_size``; both are zero for an input that is already square.
        pad_height = target_size - orig_height
        pad_width = target_size - orig_width

        # Floor-divide for the leading border so any odd remainder lands on the bottom/right.
        pad_top = pad_height // 2
        pad_left = pad_width // 2
        pad_bottom = pad_height - pad_top
        pad_right = pad_width - pad_left

        # The padding needs to be in the format [left, top, right, bottom]
        return dict(padding=[pad_left, pad_top, pad_right, pad_bottom])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Pad ``inpt`` with the padding from ``params`` using the fill resolved for its type.

        Args:
            inpt: The individual input to pad.
            params: The dict produced by ``_get_params``; only ``params["padding"]`` is read.

        Returns:
            The result of dispatching ``F.pad`` on ``inpt`` through ``self._call_kernel``.
        """
        fill = _get_fill(self.fill, type(inpt))
        return self._call_kernel(F.pad, inpt, padding=params["padding"], padding_mode=self.padding_mode, fill=fill)
567+ 
568+ 
491569class  RandomZoomOut (_RandomApplyTransform ):
492570    """ "Zoom out" transformation from 
493571    `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_. 
0 commit comments