66
77from PIL import Image
88
9- # set Mean and Std of RGB channels of IMAGENET to use pre-trained VGG net
10- IMAGENET_MEAN = (0.485 , 0.456 , 0.406 )
11- IMAGENET_STD = (0.229 , 0.224 , 0.225 )
12-
13- # normalize a image with mean, std
14- normalize = transforms .Normalize (mean = IMAGENET_MEAN ,
15- std = IMAGENET_STD )
16-
17- # denormalize a output image
18- denormalize = transforms .Normalize (mean = [- mean / std for mean , std in zip (IMAGENET_MEAN , IMAGENET_STD )],
19- std = [1 / std for std in IMAGENET_STD ])
9+ def lastest_arverage_value (values , length = 100 ):
10+ if len (values ) < length :
11+ length = len (values )
12+ return sum (values [- length :])/ length
2013
2114class ImageFolder (torch .utils .data .Dataset ):
22- def __init__ (self , root_path , transform ):
15+ def __init__ (self , root_path , imsize = None , cropsize = None , cencrop = False ):
2316 super (ImageFolder , self ).__init__ ()
2417
2518 self .file_names = sorted (os .listdir (root_path ))
2619 self .root_path = root_path
27- self .transform = transform
20+ self .transform = _transformer ( imsize , cropsize , cencrop )
2821
2922 def __len__ (self ):
3023 return len (self .file_names )
@@ -33,31 +26,52 @@ def __getitem__(self, index):
3326 image = Image .open (os .path .join (self .root_path + self .file_names [index ])).convert ("RGB" )
3427 return self .transform (image )
3528
36- def get_transformer (imsize = None , cropsize = None ):
29+ def _normalizer (denormalize = False ):
30+ # set Mean and Std of RGB channels of IMAGENET to use pre-trained VGG net
31+ MEAN = [0.485 , 0.456 , 0.406 ]
32+ STD = [0.229 , 0.224 , 0.225 ]
33+
34+ if denormalize :
35+ MEAN = [- mean / std for mean , std in zip (MEAN , STD )]
36+ STD = [1 / std for std in STD ]
37+
38+ return transforms .Normalize (mean = MEAN , std = STD )
39+
40+ def _transformer (imsize = None , cropsize = None , cencrop = False ):
41+ normalize = _normalizer ()
3742 transformer = []
3843 if imsize :
3944 transformer .append (transforms .Resize (imsize ))
4045 if cropsize :
41- transformer .append (transforms .RandomCrop (cropsize )),
46+ if cencrop :
47+ transformer .append (transforms .CenterCrop (cropsize ))
48+ else :
49+ transformer .append (transforms .RandomCrop (cropsize ))
50+
4251 transformer .append (transforms .ToTensor ())
4352 transformer .append (normalize )
4453 return transforms .Compose (transformer )
4554
4655def imsave (tensor , path ):
56+ denormalize = _normalizer (denormalize = True )
4757 if tensor .is_cuda :
4858 tensor = tensor .cpu ()
4959 tensor = torchvision .utils .make_grid (tensor )
5060 torchvision .utils .save_image (denormalize (tensor ).clamp_ (0.0 , 1.0 ), path )
5161 return None
5262
53- def imload (path , imsize = None , cropsize = None ):
54- transformer = get_transformer (imsize , cropsize )
63+ def imload (path , imsize = None , cropsize = None , cencrop = False ):
64+ transformer = _transformer (imsize , cropsize , cencrop )
5565 return transformer (Image .open (path ).convert ("RGB" )).unsqueeze (0 )
5666
57- def extract_features (model , x , layer_index ):
58- features = []
59- for i , layer in enumerate (model ):
60- x = layer (x )
61- if i in layer_index :
62- features .append (x )
63- return features
67+ def imshow (tensor ):
68+ denormalize = _normalizer (denormalize = True )
69+ if tensor .is_cuda :
70+ tensor = tensor .cpu ()
71+ tensor = torchvision .utils .make_grid (denormalize (tensor .squeeze (0 )))
72+ image = transforms .functional .to_pil_image (tensor .clamp_ (0.0 , 1.0 ))
73+ return image
74+
75+ def maskload (path ):
76+ mask = Image .open (path ).convert ('L' )
77+ return transforms .functional .to_tensor (mask ).unsqueeze (0 )
0 commit comments