Skip to content

Commit 5b67465

Browse files
authored
Merge pull request #1230 from ananyakaligal/main
Added SRGAN
2 parents 0772013 + e2148e8 commit 5b67465

File tree

17 files changed

+682
-0
lines changed

17 files changed

+682
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
MANIFEST
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
.pytest_cache/
49+
50+
# Translations
51+
*.mo
52+
*.pot
53+
54+
# Django stuff:
55+
*.log
56+
local_settings.py
57+
db.sqlite3
58+
59+
# Flask stuff:
60+
instance/
61+
.webassets-cache
62+
63+
# Scrapy stuff:
64+
.scrapy
65+
66+
# Sphinx documentation
67+
docs/_build/
68+
69+
# PyBuilder
70+
target/
71+
72+
# Jupyter Notebook
73+
.ipynb_checkpoints
74+
75+
# pyenv
76+
.python-version
77+
78+
# celery beat schedule file
79+
celerybeat-schedule
80+
81+
# SageMath parsed files
82+
*.sage.py
83+
84+
# Environments
85+
.env
86+
.venv
87+
env/
88+
venv/
89+
ENV/
90+
env.bak/
91+
venv.bak/
92+
93+
# Spyder project settings
94+
.spyderproject
95+
.spyproject
96+
97+
# Rope project settings
98+
.ropeproject
99+
100+
# mkdocs documentation
101+
/site
102+
103+
# mypy
104+
.mypy_cache/
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
This is a Super Resolution Generative Adversarial Network (SRGAN) whose purpose is to upscale image resolution by a factor of two using deep learning. This way, a picture that initially appears pixelated and/or blurry can be enhanced so that its features become much more distinguishable. The model is trained on the COCO unlabeled2017 dataset. Download [here](http://cocodataset.org/#download).
2+
3+
## Requirements
4+
- Tensorflow 2.0
5+
- Scipy, Numpy
6+
- PIL
7+
- Matplotlib
8+
- MS COCO unlabeled2017 Dataset (for training)
9+
10+
## Usage
11+
To train the model (which we highly recommend training further):
12+
```
13+
python srgan.py
14+
```
15+
To run the model on an image:
16+
```
17+
python srgan.py -p image.jpg
18+
```
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import scipy.misc
2+
from glob import glob
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
6+
class DataLoader():
    """Load and preprocess image batches for SRGAN training/inference.

    High-resolution (HR) images are resized to ``img_res``; low-resolution
    (LR) counterparts are half that size in each dimension.  All pixel
    values are rescaled from [0, 255] to [-1, 1] (tanh range).

    NOTE(review): this class relies on ``scipy.misc.imread``/``imresize``,
    which were removed in SciPy 1.2/1.3 — it requires SciPy < 1.2 with
    Pillow installed, or porting to ``PIL.Image`` directly.
    """

    def __init__(self, dataset_name, img_res=(200, 200)):
        # dataset_name: directory name under ../../datasets/ containing images.
        # img_res: (height, width) target size for HR images.
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        """Return a random batch as ``(imgs_hr, imgs_lr)`` float arrays.

        Parameters:
            batch_size: number of images to sample (with replacement).
            is_testing: when False, each pair is horizontally flipped with
                probability 0.5 as data augmentation.

        Returns:
            Tuple of arrays of shape (batch, H, W, 3) and (batch, H/2, W/2, 3),
            values in [-1, 1].
        """
        path = glob('../../datasets/%s/*' % (self.dataset_name))
        batch_images = np.random.choice(path, size=batch_size)

        h, w = self.img_res
        low_h, low_w = h // 2, w // 2

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = self.imread(img_path)

            img_hr = scipy.misc.imresize(img, self.img_res)
            img_lr = scipy.misc.imresize(img, (low_h, low_w))

            # If training => do random horizontal flip (same flip for both
            # resolutions so the pair stays aligned).
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        return imgs_hr, imgs_lr

    def load_pred(self, path):
        """Load a single image for prediction.

        Returns ``(imgs_hr, imgs_lr)`` with batch dimension 1, normalized
        to [-1, 1], matching the shapes produced by :meth:`load_data`.
        """
        img = self.imread(path)
        h, w = self.img_res
        low_h, low_w = h // 2, w // 2

        img_hr = scipy.misc.imresize(img, self.img_res)
        img_lr = scipy.misc.imresize(img, (low_h, low_w))

        imgs_hr = np.array([img_hr]) / 127.5 - 1.
        imgs_lr = np.array([img_lr]) / 127.5 - 1.
        return imgs_hr, imgs_lr

    def load_resize(self, path):
        """Load one image and return normalized HR/LR pairs, with the HR
        batch coerced to shape (-1, 400, 400, 3).

        Fixes the original implementation, which referenced the undefined
        variable ``img_lr`` before assignment, never appended to its lists,
        and returned ``None``.

        NOTE(review): ``np.resize`` repeats/truncates data rather than
        interpolating; the 400x400 target only makes sense when
        ``img_res == (400, 400)`` — confirm intended behavior.
        """
        img = self.imread(path)
        h, w = self.img_res
        low_h, low_w = h // 2, w // 2

        img_hr = scipy.misc.imresize(img, self.img_res)
        img_lr = scipy.misc.imresize(img, (low_h, low_w))

        imgs_hr = np.array([img_hr]) / 127.5 - 1.
        imgs_lr = np.array([img_lr]) / 127.5 - 1.
        imgs_hr = np.resize(imgs_hr, (-1, 400, 400, 3))
        return imgs_hr, imgs_lr

    def imread(self, path):
        """Read an image from ``path`` as an RGB float64 array."""
        # ``float`` instead of ``np.float``: the np.float alias was removed
        # in NumPy 1.24.
        return scipy.misc.imread(path, mode='RGB').astype(float)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import os
4+
import math
5+
path = '../../datasets/unlabeled2017/000000002272.jpg'
6+
7+
class ImageSlicer(object):
    """Slice one image (or a directory of images) into fixed-size tiles.

    Tiles of ``size`` are extracted at the given ``strides`` (defaulting to
    non-overlapping tiles when a stride is None).  With ``PADDING=True``
    images are zero-padded first so the tiling covers the full image.

    Parameters:
        source: path to a single image file, or (with ``BATCH=True``) a
            directory of images.
        size: (height, width) of each tile.
        strides: (stride_x, stride_y); a None entry defaults to the tile size.
        BATCH: treat ``source`` as a directory of images.
        PADDING: zero-pad images so tiling covers them completely.
    """

    def __init__(self, source, size, strides=(None, None), BATCH=False, PADDING=False):
        self.source = source
        self.size = size
        # Copy into a fresh list: transform() fills in missing strides via
        # item assignment, and the original mutable default ([None, None])
        # was shared across instances, so one call corrupted the default
        # for every later construction.
        self.strides = list(strides)
        self.BATCH = BATCH
        self.PADDING = PADDING

    def __read_images(self):
        """Read every image in the ``source`` directory, sorted by name."""
        Images = []
        image_names = sorted(os.listdir(self.source))
        for im in image_names:
            # Fixed: the original referenced an undefined global ``dir_path``.
            image = plt.imread(os.path.join(self.source, im))
            Images.append(image)
        return Images

    def __offset_op(self, input_length, output_length, stride):
        """Return the leftover pixels after the last full stride fits."""
        offset = (input_length) - (stride * ((input_length - output_length) // stride) + output_length)
        return offset

    def __padding_op(self, Image):
        """Zero-pad ``Image`` so the stride/tile grid covers it exactly.

        Uses ``self.offset_x``/``self.offset_y`` computed by transform();
        padding is split evenly between both edges.
        """
        if self.offset_x > 0:
            padding_x = self.strides[0] - self.offset_x
        else:
            padding_x = 0
        if self.offset_y > 0:
            padding_y = self.strides[1] - self.offset_y
        else:
            padding_y = 0
        Padded_Image = np.zeros(shape=(Image.shape[0] + padding_x, Image.shape[1] + padding_y, Image.shape[2]), dtype=Image.dtype)
        Padded_Image[padding_x // 2:(padding_x // 2) + (Image.shape[0]), padding_y // 2:(padding_y // 2) + Image.shape[1], :] = Image
        return Padded_Image

    def __convolution_op(self, Image):
        """Slide a ``size`` window over ``Image`` at ``strides`` and collect tiles.

        Also records ``self.n_rows``/``self.n_columns`` for save_images().
        """
        start_x = 0
        start_y = 0
        self.n_rows = Image.shape[0] // self.strides[0] + 1
        self.n_columns = Image.shape[1] // self.strides[1] + 1
        small_images = []
        for i in range(self.n_rows - 1):
            for j in range(self.n_columns - 1):
                new_start_x = start_x + i * self.strides[0]
                new_start_y = start_y + j * self.strides[1]
                small_images.append(Image[new_start_x:new_start_x + self.size[0], new_start_y:new_start_y + self.size[1], :])
        return small_images

    def transform(self):
        """Slice the source image(s) into tiles.

        Returns:
            dict mapping the image index (as a string) to its list of tiles.

        Raises:
            Exception: if ``source`` does not exist, or a stride exceeds the
                corresponding image dimension.
        """
        if not (os.path.exists(self.source)):
            raise Exception("Path does not exist!")

        if self.source and not (self.BATCH):
            Image = plt.imread(self.source)
            Images = [Image]
        else:
            Images = self.__read_images()

        transformed_images = dict()
        # NOTE(review): np.array(Images) assumes every image in a batch has
        # the same shape — confirm for heterogeneous directories.
        Images = np.array(Images)

        if self.PADDING:
            padded_images = []

            # Fill in missing strides, compute the leftover offsets, then pad.
            if self.strides[0] is None and self.strides[1] is None:
                self.strides[0] = self.size[0]
                self.strides[1] = self.size[1]
                self.offset_x = Images.shape[1] % self.size[0]
                self.offset_y = Images.shape[2] % self.size[1]
                padded_images = list(map(self.__padding_op, Images))

            elif self.strides[0] is None and self.strides[1] is not None:
                self.strides[0] = self.size[0]
                self.offset_x = Images.shape[1] % self.size[0]
                if self.strides[1] <= Images.shape[2]:
                    self.offset_y = self.__offset_op(Images.shape[2], self.size[1], self.strides[1])
                else:
                    raise Exception("stride_y must be between {0} and {1}".format(1, Images.shape[2]))
                padded_images = list(map(self.__padding_op, Images))

            elif self.strides[0] is not None and self.strides[1] is None:
                self.strides[1] = self.size[1]
                self.offset_y = Images.shape[2] % self.size[1]
                if self.strides[0] <= Images.shape[1]:
                    self.offset_x = self.__offset_op(Images.shape[1], self.size[0], self.strides[0])
                else:
                    raise Exception("stride_x must be between {0} and {1}".format(1, Images.shape[1]))
                padded_images = list(map(self.__padding_op, Images))

            else:
                if self.strides[0] > Images.shape[1]:
                    raise Exception("stride_x must be between {0} and {1}".format(1, Images.shape[1]))
                elif self.strides[1] > Images.shape[2]:
                    raise Exception("stride_y must be between {0} and {1}".format(1, Images.shape[2]))
                else:
                    self.offset_x = self.__offset_op(Images.shape[1], self.size[0], self.strides[0])
                    self.offset_y = self.__offset_op(Images.shape[2], self.size[1], self.strides[1])
                    padded_images = list(map(self.__padding_op, Images))

            for i, Image in enumerate(padded_images):
                transformed_images[str(i)] = self.__convolution_op(Image)

        else:
            # No padding: only validate/default the strides, then tile.
            if self.strides[0] is None and self.strides[1] is None:
                self.strides[0] = self.size[0]
                self.strides[1] = self.size[1]

            elif self.strides[0] is None and self.strides[1] is not None:
                if self.strides[1] > Images.shape[2]:
                    raise Exception("stride_y must be between {0} and {1}".format(1, Images.shape[2]))
                self.strides[0] = self.size[0]

            elif self.strides[0] is not None and self.strides[1] is None:
                if self.strides[0] > Images.shape[1]:
                    raise Exception("stride_x must be between {0} and {1}".format(1, Images.shape[1]))
                self.strides[1] = self.size[1]
            else:
                if self.strides[0] > Images.shape[1]:
                    raise Exception("stride_x must be between {0} and {1}".format(1, Images.shape[1]))
                elif self.strides[1] > Images.shape[2]:
                    raise Exception("stride_y must be between {0} and {1}".format(1, Images.shape[2]))

            for i, Image in enumerate(Images):
                transformed_images[str(i)] = self.__convolution_op(Image)

        return transformed_images

    def save_images(self, transformed):
        """Normalize the tile sets in ``transformed`` to [0, 1].

        NOTE(review): despite the name, nothing is written to disk; only the
        last normalized tile set is returned. Must be called after
        transform() (relies on n_rows/n_columns).
        """
        self.r, self.c = self.n_rows - 1, self.n_columns - 1
        val = None
        for key, val in transformed.items():
            # [0, 255] -> [-1, 1] -> [0, 1]
            val = np.array(val, dtype=np.float64) / 127.5 - 1.
            val = .5 * val + 0.5
        return val

0 commit comments

Comments
 (0)