1+ from __future__ import division
12import torch
23import torchvision .transforms .functional as F
4+ from torch import Tensor
5+ from torch .jit .annotations import Optional , List , BroadcastingList2 , Tuple
36
47
5- def vflip (img_tensor ):
8+ def _is_tensor_a_torch_image (input ):
9+ return len (input .shape ) == 3
10+
11+
def vflip(img):
    # type: (Tensor) -> Tensor
    """Vertically flip the given Image Tensor.

    Args:
        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

    Returns:
        Tensor: Vertically flipped image Tensor.

    Raises:
        TypeError: if ``img`` is not a 3-D image tensor.
    """
    if _is_tensor_a_torch_image(img):
        # Dim -2 is the height axis of a [C, H, W] tensor.
        return img.flip(-2)
    raise TypeError('tensor is not a torch image.')
1826
1927
20- def hflip (img_tensor ):
def hflip(img):
    # type: (Tensor) -> Tensor
    """Horizontally flip the given Image Tensor.

    Args:
        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

    Returns:
        Tensor: Horizontally flipped image Tensor.

    Raises:
        TypeError: if ``img`` is not a 3-D image tensor.
    """
    if _is_tensor_a_torch_image(img):
        # Dim -1 is the width axis of a [C, H, W] tensor.
        return img.flip(-1)
    raise TypeError('tensor is not a torch image.')
3342
3443
def crop(img, top, left, height, width):
    # type: (Tensor, int, int, int, int) -> Tensor
    """Crop the given Image Tensor.

    Args:
        img (Tensor): image tensor to crop; last two dims are treated as
            height and width.
        top (int): vertical offset of the crop's top-left corner.
        left (int): horizontal offset of the crop's top-left corner.
        height (int): height of the crop.
        width (int): width of the crop.

    Returns:
        Tensor: Cropped image.

    Raises:
        TypeError: if ``img`` is not a 3-D image tensor.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    bottom = top + height
    right = left + width
    return img[..., top:bottom, left:right]
5262
5363
5464def rgb_to_grayscale (img ):
65+ # type: (Tensor) -> Tensor
5566 """Convert the given RGB Image Tensor to Grayscale.
5667 For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
5768 is L = R * 0.2989 + G * 0.5870 + B * 0.1140
@@ -70,6 +81,7 @@ def rgb_to_grayscale(img):
7081
7182
def adjust_brightness(img, brightness_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust brightness of an RGB image.

    Args:
        img (Tensor): image tensor to adjust.
        brightness_factor (float): How much to adjust the brightness.
            0 gives a black image, 1 the original image, values > 1 a
            brighter image.

    Returns:
        Tensor: Brightness adjusted image.

    Raises:
        TypeError: if ``img`` is not a 3-D image tensor.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Blending toward an all-zero (black) image scales every pixel by the
    # factor, which is exactly a brightness adjustment.
    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)
88100
89101
90102def adjust_contrast (img , contrast_factor ):
103+ # type: (Tensor, float) -> Tensor
91104 """Adjust contrast of an RGB image.
92105
93106 Args:
@@ -99,7 +112,7 @@ def adjust_contrast(img, contrast_factor):
99112 Returns:
100113 Tensor: Contrast adjusted image.
101114 """
102- if not F . _is_tensor_image (img ):
115+ if not _is_tensor_a_torch_image (img ):
103116 raise TypeError ('tensor is not a torch image.' )
104117
105118 mean = torch .mean (rgb_to_grayscale (img ).to (torch .float ))
@@ -108,6 +121,7 @@ def adjust_contrast(img, contrast_factor):
108121
109122
def adjust_saturation(img, saturation_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust color saturation of an RGB image.

    Args:
        img (Tensor): image tensor to adjust.
        saturation_factor (float): How much to adjust the saturation.
            0 gives a grayscale image, 1 the original image, values > 1
            an over-saturated image.

    Returns:
        Tensor: Saturation adjusted image.

    Raises:
        TypeError: if ``img`` is not a 3-D image tensor.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Blending toward the grayscale version of the image washes color out
    # (factor < 1) or boosts it (factor > 1).
    gray = rgb_to_grayscale(img)
    return _blend(img, gray, saturation_factor)
126140
127141
128142def center_crop (img , output_size ):
143+ # type: (Tensor, BroadcastingList2[int]) -> Tensor
129144 """Crop the Image Tensor and resize it to desired size.
130145
131146 Args:
@@ -136,7 +151,7 @@ def center_crop(img, output_size):
136151 Returns:
137152 Tensor: Cropped image.
138153 """
139- if not F . _is_tensor_image (img ):
154+ if not _is_tensor_a_torch_image (img ):
140155 raise TypeError ('tensor is not a torch image.' )
141156
142157 _ , image_width , image_height = img .size ()
@@ -148,9 +163,10 @@ def center_crop(img, output_size):
148163
149164
150165def five_crop (img , size ):
166+ # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
151167 """Crop the given Image Tensor into four corners and the central crop.
152168 .. Note::
153- This transform returns a tuple of Tensors and there may be a
169+ This transform returns a List of Tensors and there may be a
154170 mismatch in the number of inputs and targets your ``Dataset`` returns.
155171
156172 Args:
@@ -159,10 +175,10 @@ def five_crop(img, size):
159175 made.
160176
161177 Returns:
162- tuple: tuple (tl, tr, bl, br, center)
178+ List: List (tl, tr, bl, br, center)
163179 Corresponding top left, top right, bottom left, bottom right and center crop.
164180 """
165- if not F . _is_tensor_image (img ):
181+ if not _is_tensor_a_torch_image (img ):
166182 raise TypeError ('tensor is not a torch image.' )
167183
168184 assert len (size ) == 2 , "Please provide only two dimensions (h, w) for size."
@@ -179,14 +195,15 @@ def five_crop(img, size):
179195 br = crop (img , image_width - crop_width , image_height - crop_height , image_width , image_height )
180196 center = center_crop (img , (crop_height , crop_width ))
181197
182- return ( tl , tr , bl , br , center )
198+ return [ tl , tr , bl , br , center ]
183199
184200
185201def ten_crop (img , size , vertical_flip = False ):
202+ # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
186203 """Crop the given Image Tensor into four corners and the central crop plus the
187204 flipped version of these (horizontal flipping is used by default).
188205 .. Note::
189- This transform returns a tuple of images and there may be a
206+ This transform returns a List of images and there may be a
190207 mismatch in the number of inputs and targets your ``Dataset`` returns.
191208
192209 Args:
@@ -196,11 +213,11 @@ def ten_crop(img, size, vertical_flip=False):
196213 vertical_flip (bool): Use vertical flipping instead of horizontal
197214
198215 Returns:
199- tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
216+ List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
200217 Corresponding top left, top right, bottom left, bottom right and center crop
201218 and same for the flipped image's tensor.
202219 """
203- if not F . _is_tensor_image (img ):
220+ if not _is_tensor_a_torch_image (img ):
204221 raise TypeError ('tensor is not a torch image.' )
205222
206223 assert len (size ) == 2 , "Please provide only two dimensions (h, w) for size."
@@ -217,5 +234,6 @@ def ten_crop(img, size, vertical_flip=False):
217234
218235
219236def _blend (img1 , img2 , ratio ):
220- bound = 1 if img1 .dtype .is_floating_point else 255
237+ # type: (Tensor, Tensor, float) -> Tensor
238+ bound = 1 if img1 .dtype in [torch .half , torch .float32 , torch .float64 ] else 255
221239 return (ratio * img1 + (1 - ratio ) * img2 ).clamp (0 , bound ).to (img1 .dtype )
0 commit comments