Skip to content

Commit c7db3c3

Browse files
committed
Update utils: add image/mask loading helpers and configurable resize/crop transforms
1 parent e51f155 commit c7db3c3

File tree

1 file changed

+38
-24
lines changed

1 file changed

+38
-24
lines changed

utils.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,18 @@
66

77
from PIL import Image
88

9-
# set Mean and Std of RGB channels of IMAGENET to use pre-trained VGG net
10-
IMAGENET_MEAN = (0.485, 0.456, 0.406)
11-
IMAGENET_STD = (0.229, 0.224, 0.225)
12-
13-
# normalize a image with mean, std
14-
normalize = transforms.Normalize(mean=IMAGENET_MEAN,
15-
std=IMAGENET_STD)
16-
17-
# denormalize a output image
18-
denormalize = transforms.Normalize(mean=[-mean/std for mean, std in zip(IMAGENET_MEAN, IMAGENET_STD)],
19-
std=[1/std for std in IMAGENET_STD])
9+
def lastest_arverage_value(values, length=100):
    """Return the mean of the last *length* entries of *values*.

    Args:
        values: sequence of numbers (e.g. a running loss history).
        length: window size; when fewer entries exist, all are averaged.

    Returns:
        The arithmetic mean of the trailing window as a float, or 0.0 for
        an empty sequence (the original raised ZeroDivisionError there).
    """
    if not values:
        # Guard: an empty history previously caused division by zero.
        return 0.0
    length = min(length, len(values))
    return sum(values[-length:]) / length
2013

2114
class ImageFolder(torch.utils.data.Dataset):
    """Dataset yielding normalized RGB image tensors from a flat directory.

    Args:
        root_path: directory whose entries are the image files (sorted order).
        imsize: optional resize target forwarded to ``_transformer``.
        cropsize: optional crop size forwarded to ``_transformer``.
        cencrop: if True use a center crop, otherwise a random crop.
    """

    def __init__(self, root_path, imsize=None, cropsize=None, cencrop=False):
        super(ImageFolder, self).__init__()

        # NOTE(review): assumes every entry in root_path is an image file.
        self.file_names = sorted(os.listdir(root_path))
        self.root_path = root_path
        self.transform = _transformer(imsize, cropsize, cencrop)

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        # Fix: use a two-argument os.path.join so the separator is inserted;
        # the original concatenated strings inside a single-arg join and broke
        # when root_path lacked a trailing "/".
        image_path = os.path.join(self.root_path, self.file_names[index])
        image = Image.open(image_path).convert("RGB")
        return self.transform(image)
3528

36-
def get_transformer(imsize=None, cropsize=None):
29+
def _normalizer(denormalize=False):
    """Build a Normalize transform for ImageNet channel statistics.

    Args:
        denormalize: when True, return the inverse transform that maps a
            normalized tensor back to the original [0, 1] range.

    Returns:
        A ``transforms.Normalize`` instance.
    """
    # RGB mean/std of ImageNet, as expected by pre-trained VGG networks.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    if denormalize:
        # Inverse mapping: x = (x_norm - (-m/s)) / (1/s)
        mean = [-m / s for m, s in zip(mean, std)]
        std = [1.0 / s for s in std]

    return transforms.Normalize(mean=mean, std=std)
39+
40+
def _transformer(imsize=None, cropsize=None, cencrop=False):
    """Compose an image-to-tensor pipeline with optional resize and crop.

    Args:
        imsize: optional target for ``transforms.Resize``.
        cropsize: optional crop size; skipped when falsy.
        cencrop: choose center crop over random crop when True.

    Returns:
        A ``transforms.Compose`` ending in ToTensor + ImageNet normalization.
    """
    steps = []
    if imsize:
        steps.append(transforms.Resize(imsize))
    if cropsize:
        crop = transforms.CenterCrop(cropsize) if cencrop else transforms.RandomCrop(cropsize)
        steps.append(crop)

    steps.append(transforms.ToTensor())
    steps.append(_normalizer())
    return transforms.Compose(steps)
4554

4655
def imsave(tensor, path):
    """Denormalize an image tensor (grid of a batch) and write it to *path*.

    Args:
        tensor: normalized image tensor; moved to CPU if on a CUDA device.
        path: destination file path for the saved image.
    """
    if tensor.is_cuda:
        tensor = tensor.cpu()
    grid = torchvision.utils.make_grid(tensor)
    denorm = _normalizer(denormalize=True)
    # Clamp to [0, 1] so rounding error outside the range does not wrap.
    torchvision.utils.save_image(denorm(grid).clamp_(0.0, 1.0), path)
    return None
5262

53-
def imload(path, imsize=None, cropsize=None):
54-
transformer = get_transformer(imsize, cropsize)
63+
def imload(path, imsize=None, cropsize=None, cencrop=False):
    """Load an RGB image from *path* as a normalized 4-D tensor (1, C, H, W).

    Args:
        path: image file path.
        imsize, cropsize, cencrop: forwarded to ``_transformer``.
    """
    image = Image.open(path).convert("RGB")
    transform = _transformer(imsize, cropsize, cencrop)
    return transform(image).unsqueeze(0)
5666

57-
def extract_features(model, x, layer_index):
58-
features = []
59-
for i, layer in enumerate(model):
60-
x = layer(x)
61-
if i in layer_index:
62-
features.append(x)
63-
return features
67+
def imshow(tensor):
    """Convert a normalized (1, C, H, W) image tensor back to a PIL image.

    Args:
        tensor: normalized image tensor; moved to CPU if on a CUDA device.

    Returns:
        A PIL image suitable for display.
    """
    denorm = _normalizer(denormalize=True)
    if tensor.is_cuda:
        tensor = tensor.cpu()
    grid = torchvision.utils.make_grid(denorm(tensor.squeeze(0)))
    # Clamp to [0, 1] before PIL conversion to avoid out-of-range values.
    return transforms.functional.to_pil_image(grid.clamp_(0.0, 1.0))
74+
75+
def maskload(path):
    """Load a grayscale mask from *path* as a (1, 1, H, W) float tensor."""
    mask = Image.open(path).convert('L')
    tensor = transforms.functional.to_tensor(mask)
    return tensor.unsqueeze(0)

0 commit comments

Comments
 (0)