Commit 10f3416

surgan12 authored and fmassa committed
Scriptability checks for Tensor Transforms (#1690)
* scriptability checks
* tests upds
* linter upds
* linter upds
* upds
* tuple list changes
* linter updates
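In short: every function in functional_tensor.py gains a mypy-style # type: comment so TorchScript can type its signature, five_crop and ten_crop return lists instead of tuples, and each test compiles the tensor implementation with torch.jit.script and asserts that the scripted output matches eager execution. A minimal sketch of the pattern, using a hypothetical corners helper (not part of this commit) in place of the real transforms:

import torch
from torch import Tensor
from torch.jit.annotations import List


def corners(img):
    # type: (Tensor) -> List[Tensor]
    # Hypothetical stand-in for five_crop/ten_crop: TorchScript wants a precise
    # container annotation, and List[Tensor] avoids a fixed-arity Tuple return.
    return [img.flip(-1), img.flip(-2)]


script_corners = torch.jit.script(corners)  # compiles here, or raises if not scriptable
img = torch.rand(3, 8, 8)
for eager, scripted in zip(corners(img), script_corners(img)):
    assert torch.equal(eager, scripted)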
1 parent 900c88c commit 10f3416

File tree

2 files changed: 84 additions, 26 deletions


test/test_functional_tensor.py

Lines changed: 44 additions & 4 deletions
@@ -1,34 +1,45 @@
 from __future__ import division
 import torch
+from torch import Tensor
 import torchvision.transforms as transforms
 import torchvision.transforms.functional_tensor as F_t
 import torchvision.transforms.functional as F
 import numpy as np
 import unittest
 import random
+from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


 class Tester(unittest.TestCase):

     def test_vflip(self):
+        script_vflip = torch.jit.script(F_t.vflip)
         img_tensor = torch.randn(3, 16, 16)
         img_tensor_clone = img_tensor.clone()
         vflipped_img = F_t.vflip(img_tensor)
         vflipped_img_again = F_t.vflip(vflipped_img)
         self.assertEqual(vflipped_img.shape, img_tensor.shape)
         self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        vflipped_img_script = script_vflip(img_tensor)
+        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))

     def test_hflip(self):
+        script_hflip = torch.jit.script(F_t.hflip)
         img_tensor = torch.randn(3, 16, 16)
         img_tensor_clone = img_tensor.clone()
         hflipped_img = F_t.hflip(img_tensor)
         hflipped_img_again = F_t.hflip(hflipped_img)
         self.assertEqual(hflipped_img.shape, img_tensor.shape)
         self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        hflipped_img_script = script_hflip(img_tensor)
+        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))

     def test_crop(self):
+        script_crop = torch.jit.script(F_t.crop)
         img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         top = random.randint(0, 15)
@@ -42,11 +53,18 @@ def test_crop(self):
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
         self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
                         "functional_tensor crop not working")
+        # scriptable function test
+        cropped_img_script = script_crop(img_tensor, top, left, height, width)
+        self.assertTrue(torch.equal(img_cropped, cropped_img_script))

     def test_adjustments(self):
-        fns = ((F.adjust_brightness, F_t.adjust_brightness),
-               (F.adjust_contrast, F_t.adjust_contrast),
-               (F.adjust_saturation, F_t.adjust_saturation))
+        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
+        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
+        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
+
+        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
+               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
+               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))

         for _ in range(20):
             channels = 3
@@ -60,11 +78,13 @@ def test_adjustments(self):

             factor = 3 * torch.rand(1)
             img_clone = img.clone()
-            for f, ft in fns:
+            for f, ft, sft in fns:

                 ft_img = ft(img, factor)
+                sft_img = sft(img, factor)
                 if not img.dtype.is_floating_point:
                     ft_img = ft_img.to(torch.float) / 255
+                    sft_img = sft_img.to(torch.float) / 255

                 img_pil = transforms.ToPILImage()(img)
                 f_img_pil = f(img_pil, factor)
@@ -73,28 +93,39 @@ def test_adjustments(self):
                 # F uses uint8 and F_t uses float, so there is a small
                 # difference in values caused by (at most 5) truncations.
                 max_diff = (ft_img - f_img).abs().max()
+                max_diff_scripted = (sft_img - f_img).abs().max()
                 self.assertLess(max_diff, 5 / 255 + 1e-5)
+                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
                 self.assertTrue(torch.equal(img, img_clone))

     def test_rgb_to_grayscale(self):
+        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
         img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
         grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
         max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
         self.assertLess(max_diff, 1.0001)
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
+        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))

     def test_center_crop(self):
+        script_center_crop = torch.jit.script(F_t.center_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
         cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
         cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
         self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_center_crop(img_tensor, [10, 10])
+        self.assertTrue(torch.equal(cropped_script, cropped_tensor))

     def test_five_crop(self):
+        script_five_crop = torch.jit.script(F_t.five_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
@@ -110,8 +141,13 @@ def test_five_crop(self):
         self.assertTrue(torch.equal(cropped_tensor[4],
                                     (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_five_crop(img_tensor, [10, 10])
+        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
+            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))

     def test_ten_crop(self):
+        script_ten_crop = torch.jit.script(F_t.ten_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
@@ -137,6 +173,10 @@ def test_ten_crop(self):
         self.assertTrue(torch.equal(cropped_tensor[9],
                                     (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_ten_crop(img_tensor, [10, 10])
+        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
+            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))


 if __name__ == '__main__':
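Each test compiles its target once at the top, so a scriptability regression fails at torch.jit.script time, before any assertion runs. The round trip is easy to reproduce standalone against this revision (a sketch, assuming torchvision is importable):

import torch
import torchvision.transforms.functional_tensor as F_t

script_vflip = torch.jit.script(F_t.vflip)  # raises here if vflip is not scriptable
img = torch.randn(3, 16, 16)
# The scripted function must reproduce the eager result exactly.
assert torch.equal(F_t.vflip(img), script_vflip(img))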

torchvision/transforms/functional_tensor.py

Lines changed: 40 additions & 22 deletions
@@ -1,38 +1,48 @@
+from __future__ import division
 import torch
 import torchvision.transforms.functional as F
+from torch import Tensor
+from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


-def vflip(img_tensor):
+def _is_tensor_a_torch_image(input):
+    return len(input.shape) == 3
+
+
+def vflip(img):
+    # type: (Tensor) -> Tensor
     """Vertically flip the given the Image Tensor.

     Args:
-        img_tensor (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

     Returns:
         Tensor: Vertically flipped image Tensor.
     """
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

-    return img_tensor.flip(-2)
+    return img.flip(-2)


-def hflip(img_tensor):
+def hflip(img):
+    # type: (Tensor) -> Tensor
     """Horizontally flip the given the Image Tensor.

     Args:
-        img_tensor (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

     Returns:
         Tensor: Horizontally flipped image Tensor.
     """
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

-    return img_tensor.flip(-1)
+    return img.flip(-1)


 def crop(img, top, left, height, width):
+    # type: (Tensor, int, int, int, int) -> Tensor
     """Crop the given Image Tensor.

     Args:
@@ -45,13 +55,14 @@ def crop(img, top, left, height, width):
     Returns:
         Tensor: Cropped image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     return img[..., top:top + height, left:left + width]


 def rgb_to_grayscale(img):
+    # type: (Tensor) -> Tensor
     """Convert the given RGB Image Tensor to Grayscale.
     For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
     is L = R * 0.2989 + G * 0.5870 + B * 0.1140
@@ -70,6 +81,7 @@ def rgb_to_grayscale(img):


 def adjust_brightness(img, brightness_factor):
+    # type: (Tensor, float) -> Tensor
     """Adjust brightness of an RGB image.

     Args:
@@ -81,13 +93,14 @@ def adjust_brightness(img, brightness_factor):
     Returns:
         Tensor: Brightness adjusted image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

-    return _blend(img, 0, brightness_factor)
+    return _blend(img, torch.zeros_like(img), brightness_factor)


 def adjust_contrast(img, contrast_factor):
+    # type: (Tensor, float) -> Tensor
     """Adjust contrast of an RGB image.

     Args:
@@ -99,7 +112,7 @@ def adjust_contrast(img, contrast_factor):
     Returns:
         Tensor: Contrast adjusted image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
@@ -108,6 +121,7 @@ def adjust_contrast(img, contrast_factor):


 def adjust_saturation(img, saturation_factor):
+    # type: (Tensor, float) -> Tensor
     """Adjust color saturation of an RGB image.

     Args:
@@ -119,13 +133,14 @@ def adjust_saturation(img, saturation_factor):
     Returns:
         Tensor: Saturation adjusted image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     return _blend(img, rgb_to_grayscale(img), saturation_factor)


 def center_crop(img, output_size):
+    # type: (Tensor, BroadcastingList2[int]) -> Tensor
     """Crop the Image Tensor and resize it to desired size.

     Args:
@@ -136,7 +151,7 @@ def center_crop(img, output_size):
     Returns:
         Tensor: Cropped image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     _, image_width, image_height = img.size()
@@ -148,9 +163,10 @@ def center_crop(img, output_size):


 def five_crop(img, size):
+    # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
     """Crop the given Image Tensor into four corners and the central crop.
     .. Note::
-        This transform returns a tuple of Tensors and there may be a
+        This transform returns a List of Tensors and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
@@ -159,10 +175,10 @@ def five_crop(img, size):
             made.

     Returns:
-        tuple: tuple (tl, tr, bl, br, center)
+        List: List (tl, tr, bl, br, center)
             Corresponding top left, top right, bottom left, bottom right and center crop.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -179,14 +195,15 @@ def five_crop(img, size):
     br = crop(img, image_width - crop_width, image_height - crop_height, image_width, image_height)
     center = center_crop(img, (crop_height, crop_width))

-    return (tl, tr, bl, br, center)
+    return [tl, tr, bl, br, center]


 def ten_crop(img, size, vertical_flip=False):
+    # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
     """Crop the given Image Tensor into four corners and the central crop plus the
     flipped version of these (horizontal flipping is used by default).
     .. Note::
-        This transform returns a tuple of images and there may be a
+        This transform returns a List of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
@@ -196,11 +213,11 @@ def ten_crop(img, size, vertical_flip=False):
         vertical_flip (bool): Use vertical flipping instead of horizontal

     Returns:
-        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+        List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
             Corresponding top left, top right, bottom left, bottom right and center crop
             and same for the flipped image's tensor.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -217,5 +234,6 @@ def ten_crop(img, size, vertical_flip=False):


 def _blend(img1, img2, ratio):
-    bound = 1 if img1.dtype.is_floating_point else 255
+    # type: (Tensor, Tensor, float) -> Tensor
+    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
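Several of the source changes exist purely to satisfy TorchScript's static typing, presumably: the local _is_tensor_a_torch_image helper avoids calling into the PIL-based torchvision.transforms.functional module from scripted code; adjust_brightness passes torch.zeros_like(img) rather than the scalar 0 because _blend's second argument is now annotated as a Tensor; and _blend checks the dtype against an explicit list instead of the is_floating_point attribute. The zeros_like swap is behavior-preserving, as a quick sanity check (not part of the commit) shows:

import torch

img = torch.rand(3, 4, 4)
ratio = 1.3
# Blending against the scalar 0 and against a zero tensor is arithmetically
# identical, so adjust_brightness output is unchanged by the new signature.
old = (ratio * img + (1 - ratio) * 0).clamp(0, 1).to(img.dtype)
new = (ratio * img + (1 - ratio) * torch.zeros_like(img)).clamp(0, 1).to(img.dtype)
assert torch.equal(old, new)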
