Skip to content

Commit cfa77f0

Browse files
authored
Merge branch 'recodehive:main' into main
2 parents 965eecd + 7e3d099 commit cfa77f0

File tree

24 files changed

+3430
-189
lines changed

24 files changed

+3430
-189
lines changed
Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,25 @@
1-
More details about project
1+
# **Bird Species Classification** 🐦
22

3+
### 🎯 Goal
4+
The primary goal of this project is to build deep learning models to classify Bird species .
5+
6+
### 🧵 Dataset : https://www.kaggle.com/datasets/akash2907/bird-species-classification/data
7+
8+
### 🧾 Description
9+
This dataset consists of over 170 labeled images of birds, including validation images. Each image belongs to only one bird category. The challenge is to develop models that can accurately classify these images into the correct species.
10+
11+
### 📚 Libraries Needed
12+
- os - Provides functions to interact with the operating system.
13+
- shutil - Offers file operations like copying, moving, and removing files.
14+
- time - Used for time-related functions.
15+
- torch - Core library for PyTorch, used for deep learning.
16+
- torch.nn - Contains neural network layers and loss functions.
17+
- torchvision - Provides datasets, models, and image transformation tools for computer vision.
18+
- torchvision.transforms - Contains common image transformation operations.
19+
- torch.optim - Optimizers for training neural networks.
20+
- matplotlib.pyplot - Used for data visualization, like plotting graphs.
21+
22+
## EDA Result 👉 [Classified Bird Species](https://github.com/Archi20876/machine-learning-repos/blob/main/Classification%20Models/Bird%20species%20classification/bird-species-classification.ipynb)
23+
24+
25+

ML Hub - Learning Resource(1)(1).png

1.12 MB
Loading

ML Hub - Learning Resource.jpg

282 KB
Loading
439 KB
Loading
154 KB
Loading
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
from tools.ops import *
2+
from tools.utils import *
3+
from glob import glob
4+
import time
5+
import numpy as np
6+
from net import generator
7+
from net.discriminator import D_net
8+
from tools.data_loader import ImageGenerator
9+
from tools.vgg19 import Vgg19
10+
11+
class AnimeGANv2(object) :
12+
def __init__(self, sess, args):
13+
self.model_name = 'AnimeGANv2'
14+
self.sess = sess
15+
self.checkpoint_dir = args.checkpoint_dir
16+
self.log_dir = args.log_dir
17+
self.dataset_name = args.dataset
18+
19+
self.epoch = args.epoch
20+
self.init_epoch = args.init_epoch # args.epoch // 20
21+
22+
self.gan_type = args.gan_type
23+
self.batch_size = args.batch_size
24+
self.save_freq = args.save_freq
25+
26+
self.init_lr = args.init_lr
27+
self.d_lr = args.d_lr
28+
self.g_lr = args.g_lr
29+
30+
""" Weight """
31+
self.g_adv_weight = args.g_adv_weight
32+
self.d_adv_weight = args.d_adv_weight
33+
self.con_weight = args.con_weight
34+
self.sty_weight = args.sty_weight
35+
self.color_weight = args.color_weight
36+
self.tv_weight = args.tv_weight
37+
38+
self.training_rate = args.training_rate
39+
self.ld = args.ld
40+
41+
self.img_size = args.img_size
42+
self.img_ch = args.img_ch
43+
44+
""" Discriminator """
45+
self.n_dis = args.n_dis
46+
self.ch = args.ch
47+
self.sn = args.sn
48+
49+
self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
50+
check_folder(self.sample_dir)
51+
52+
self.real = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch], name='real_A')
53+
self.anime = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch], name='anime_A')
54+
self.anime_smooth = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch], name='anime_smooth_A')
55+
self.test_real = tf.placeholder(tf.float32, [1, None, None, self.img_ch], name='test_input')
56+
57+
self.anime_gray = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch],name='anime_B')
58+
59+
60+
self.real_image_generator = ImageGenerator('./dataset/train_photo', self.img_size, self.batch_size)
61+
self.anime_image_generator = ImageGenerator('./dataset/{}'.format(self.dataset_name + '/style'), self.img_size, self.batch_size)
62+
self.anime_smooth_generator = ImageGenerator('./dataset/{}'.format(self.dataset_name + '/smooth'), self.img_size, self.batch_size)
63+
self.dataset_num = max(self.real_image_generator.num_images, self.anime_image_generator.num_images)
64+
65+
self.vgg = Vgg19()
66+
67+
print()
68+
print("##### Information #####")
69+
print("# gan type : ", self.gan_type)
70+
print("# dataset : ", self.dataset_name)
71+
print("# max dataset number : ", self.dataset_num)
72+
print("# batch_size : ", self.batch_size)
73+
print("# epoch : ", self.epoch)
74+
print("# init_epoch : ", self.init_epoch)
75+
print("# training image size [H, W] : ", self.img_size)
76+
print("# g_adv_weight,d_adv_weight,con_weight,sty_weight,color_weight,tv_weight : ", self.g_adv_weight,self.d_adv_weight,self.con_weight,self.sty_weight,self.color_weight,self.tv_weight)
77+
print("# init_lr,g_lr,d_lr : ", self.init_lr,self.g_lr,self.d_lr)
78+
print(f"# training_rate G -- D: {self.training_rate} : 1" )
79+
print()
80+
81+
##################################################################################
82+
# Generator
83+
##################################################################################
84+
85+
def generator(self, x_init, reuse=False, scope="generator"):
86+
with tf.variable_scope(scope, reuse=reuse):
87+
G = generator.G_net(x_init)
88+
return G.fake
89+
90+
##################################################################################
91+
# Discriminator
92+
##################################################################################
93+
94+
def discriminator(self, x_init, reuse=False, scope="discriminator"):
95+
D = D_net(x_init, self.ch, self.n_dis, self.sn, reuse=reuse, scope=scope)
96+
return D
97+
98+
##################################################################################
99+
# Model
100+
##################################################################################
101+
def gradient_panalty(self, real, fake, scope="discriminator"):
102+
if self.gan_type.__contains__('dragan') :
103+
eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
104+
_, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
105+
x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
106+
107+
fake = real + 0.5 * x_std * eps
108+
109+
alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
110+
interpolated = real + alpha * (fake - real)
111+
112+
logit, _= self.discriminator(interpolated, reuse=True, scope=scope)
113+
114+
grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated)
115+
grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
116+
117+
GP = 0
118+
# WGAN - LP
119+
if self.gan_type.__contains__('lp'):
120+
GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
121+
122+
elif self.gan_type.__contains__('gp') or self.gan_type == 'dragan' :
123+
GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
124+
125+
return GP
126+
127+
def build_model(self):
128+
129+
""" Define Generator, Discriminator """
130+
self.generated = self.generator(self.real)
131+
self.test_generated = self.generator(self.test_real, reuse=True)
132+
133+
134+
anime_logit = self.discriminator(self.anime)
135+
anime_gray_logit = self.discriminator(self.anime_gray, reuse=True)
136+
137+
generated_logit = self.discriminator(self.generated, reuse=True)
138+
smooth_logit = self.discriminator(self.anime_smooth, reuse=True)
139+
140+
""" Define Loss """
141+
if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan') :
142+
GP = self.gradient_panalty(real=self.anime, fake=self.generated)
143+
else :
144+
GP = 0.0
145+
146+
# init pharse
147+
init_c_loss = con_loss(self.vgg, self.real, self.generated)
148+
init_loss = self.con_weight * init_c_loss
149+
150+
self.init_loss = init_loss
151+
152+
# gan
153+
c_loss, s_loss = con_sty_loss(self.vgg, self.real, self.anime_gray, self.generated)
154+
tv_loss = self.tv_weight * total_variation_loss(self.generated)
155+
t_loss = self.con_weight * c_loss + self.sty_weight * s_loss + color_loss(self.real,self.generated) * self.color_weight + tv_loss
156+
157+
g_loss = self.g_adv_weight * generator_loss(self.gan_type, generated_logit)
158+
d_loss = self.d_adv_weight * discriminator_loss(self.gan_type, anime_logit, anime_gray_logit, generated_logit, smooth_logit) + GP
159+
160+
self.Generator_loss = t_loss + g_loss
161+
self.Discriminator_loss = d_loss
162+
163+
""" Training """
164+
t_vars = tf.trainable_variables()
165+
G_vars = [var for var in t_vars if 'generator' in var.name]
166+
D_vars = [var for var in t_vars if 'discriminator' in var.name]
167+
168+
self.init_optim = tf.train.AdamOptimizer(self.init_lr, beta1=0.5, beta2=0.999).minimize(self.init_loss, var_list=G_vars)
169+
self.G_optim = tf.train.AdamOptimizer(self.g_lr , beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
170+
self.D_optim = tf.train.AdamOptimizer(self.d_lr , beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
171+
172+
"""" Summary """
173+
self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
174+
self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
175+
176+
self.G_gan = tf.summary.scalar("G_gan", g_loss)
177+
self.G_vgg = tf.summary.scalar("G_vgg", t_loss)
178+
self.G_init_loss = tf.summary.scalar("G_init", init_loss)
179+
180+
self.V_loss_merge = tf.summary.merge([self.G_init_loss])
181+
self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg, self.G_init_loss])
182+
self.D_loss_merge = tf.summary.merge([self.D_loss])
183+
184+
def train(self):
185+
# initialize all variables
186+
self.sess.run(tf.global_variables_initializer())
187+
188+
# saver to save model
189+
self.saver = tf.train.Saver(max_to_keep=self.epoch)
190+
191+
# summary writer
192+
self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
193+
194+
""" Input Image"""
195+
real_img_op, anime_img_op, anime_smooth_op = self.real_image_generator.load_images(), self.anime_image_generator.load_images(), self.anime_smooth_generator.load_images()
196+
197+
198+
# restore check-point if it exits
199+
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
200+
if could_load:
201+
start_epoch = checkpoint_counter + 1
202+
203+
print(" [*] Load SUCCESS")
204+
else:
205+
start_epoch = 0
206+
207+
print(" [!] Load failed...")
208+
209+
# loop for epoch
210+
init_mean_loss = []
211+
mean_loss = []
212+
# training times , G : D = self.training_rate : 1
213+
j = self.training_rate
214+
for epoch in range(start_epoch, self.epoch):
215+
for idx in range(int(self.dataset_num / self.batch_size)):
216+
anime, anime_smooth, real = self.sess.run([anime_img_op, anime_smooth_op, real_img_op])
217+
train_feed_dict = {
218+
self.real:real[0],
219+
self.anime:anime[0],
220+
self.anime_gray:anime[1],
221+
self.anime_smooth:anime_smooth[1]
222+
}
223+
224+
if epoch < self.init_epoch :
225+
# Init G
226+
start_time = time.time()
227+
228+
real_images, generator_images, _, v_loss, summary_str = self.sess.run([self.real, self.generated,
229+
self.init_optim,
230+
self.init_loss, self.V_loss_merge], feed_dict = train_feed_dict)
231+
self.writer.add_summary(summary_str, epoch)
232+
init_mean_loss.append(v_loss)
233+
234+
print("Epoch: %3d Step: %5d / %5d time: %f s init_v_loss: %.8f mean_v_loss: %.8f" % (epoch, idx,int(self.dataset_num / self.batch_size), time.time() - start_time, v_loss, np.mean(init_mean_loss)))
235+
if (idx+1)%200 ==0:
236+
init_mean_loss.clear()
237+
else :
238+
start_time = time.time()
239+
240+
if j == self.training_rate:
241+
# Update D
242+
_, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss_merge],
243+
feed_dict=train_feed_dict)
244+
self.writer.add_summary(summary_str, epoch)
245+
246+
# Update G
247+
real_images, generator_images, _, g_loss, summary_str = self.sess.run([self.real, self.generated,self.G_optim,
248+
self.Generator_loss, self.G_loss_merge], feed_dict = train_feed_dict)
249+
self.writer.add_summary(summary_str, epoch)
250+
251+
mean_loss.append([d_loss, g_loss])
252+
if j == self.training_rate:
253+
254+
print(
255+
"Epoch: %3d Step: %5d / %5d time: %f s d_loss: %.8f, g_loss: %.8f -- mean_d_loss: %.8f, mean_g_loss: %.8f" % (
256+
epoch, idx, int(self.dataset_num / self.batch_size), time.time() - start_time, d_loss, g_loss, np.mean(mean_loss, axis=0)[0],
257+
np.mean(mean_loss, axis=0)[1]))
258+
else:
259+
print(
260+
"Epoch: %3d Step: %5d / %5d time: %f s , g_loss: %.8f -- mean_g_loss: %.8f" % (
261+
epoch, idx, int(self.dataset_num / self.batch_size), time.time() - start_time, g_loss, np.mean(mean_loss, axis=0)[1]))
262+
263+
if (idx + 1) % 200 == 0:
264+
mean_loss.clear()
265+
266+
j = j - 1
267+
if j < 1:
268+
j = self.training_rate
269+
270+
271+
if (epoch + 1) >= self.init_epoch and np.mod(epoch + 1, self.save_freq) == 0:
272+
self.save(self.checkpoint_dir, epoch)
273+
274+
if epoch >= self.init_epoch -1:
275+
""" Result Image """
276+
val_files = glob('./dataset/{}/*.*'.format('val'))
277+
save_path = './{}/{:03d}/'.format(self.sample_dir, epoch)
278+
check_folder(save_path)
279+
for i, sample_file in enumerate(val_files):
280+
print('val: '+ str(i) + sample_file)
281+
sample_image = np.asarray(load_test_data(sample_file, self.img_size))
282+
test_real,test_generated = self.sess.run([self.test_real,self.test_generated],feed_dict = {self.test_real:sample_image} )
283+
save_images(test_real, save_path+'{:03d}_a.jpg'.format(i), None)
284+
save_images(test_generated, save_path+'{:03d}_b.jpg'.format(i), None)
285+
286+
@property
287+
def model_dir(self):
288+
return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name,
289+
self.gan_type,
290+
int(self.g_adv_weight), int(self.d_adv_weight),
291+
int(self.con_weight), int(self.sty_weight),
292+
int(self.color_weight), int(self.tv_weight))
293+
294+
295+
def save(self, checkpoint_dir, step):
296+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
297+
if not os.path.exists(checkpoint_dir):
298+
os.makedirs(checkpoint_dir)
299+
self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
300+
301+
def load(self, checkpoint_dir):
302+
print(" [*] Reading checkpoints...")
303+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
304+
305+
ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information
306+
307+
if ckpt and ckpt.model_checkpoint_path:
308+
ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
309+
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
310+
counter = int(ckpt_name.split('-')[-1])
311+
print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name)))
312+
return True, counter
313+
else:
314+
print(" [*] Failed to find a checkpoint")
315+
return False, 0
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 RAMESWAR BISOYI
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

0 commit comments

Comments
 (0)