import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import os
import numpy as np

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
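# The KMP_DUPLICATE_LIB_OK flag above is a common workaround for the "OMP: Error #15"
# crash caused by duplicate OpenMP runtimes (e.g. PyTorch plus MKL in a conda
# environment); it can be removed on setups that do not hit that error.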
10+
def plot_figure(folder_name, generated_samples, sample1, sample2, pref):
    '''Scatter the generated samples against the two target Gaussian samples.'''
    plt.scatter(generated_samples[:, 0], generated_samples[:, 1], label='Generated', s=50)
    plt.scatter(sample1[:, 0], sample1[:, 1], label='Sample 1', s=25, alpha=0.5)
    plt.scatter(sample2[:, 0], sample2[:, 1], label='Sample 2', s=25, alpha=0.5)
    if abs(pref[0]) < 1e-6:  # Only draw the legend on the pref0 = 0 figure
        plt.legend(fontsize=20, loc='lower right')

    plt.xlabel('$X_1$', fontsize=25)
    plt.ylabel('$X_2$', fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.axis('equal')
    # Line segment connecting the two Gaussian means, (0, 0) and (4, 4)
    plt.plot([0, 4], [0, 4], linewidth=2, color='black')
    fig_name = os.path.join(folder_name, 'res_{:.2f}.pdf'.format(pref[0]))
    plt.savefig(fig_name, bbox_inches='tight')
    print('Save fig to {}'.format(fig_name))
27+
28+
# Generator: Transforms random noise into samples resembling the target distribution
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim)
        )

    def forward(self, z):
        return self.model(z)
43+
44+
# Discriminator: Classifies whether samples are real (from target Gaussian) or fake (from generator)
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
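# Note: the final Sigmoid pairs with the nn.BCELoss used in MOGANTrainer; an
# equivalent, more numerically stable alternative would be to drop the Sigmoid
# and train with nn.BCEWithLogitsLoss instead.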
60+
61+
# Sample from the two target Gaussian distributions (one per objective)
def sample_multiple_gaussian(batch_size, dim, mean=0.0, std=1.0):
    dist1 = torch.normal(mean=mean, std=std, size=(batch_size, dim))
    dist2 = torch.normal(mean=mean + 4, std=std, size=(batch_size, dim))
    return dist1, dist2
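# For example, sample_multiple_gaussian(64, 2) returns two (64, 2) tensors: the first
# drawn from N(0, I) and the second from N((4, 4), I), matching the two modes plotted
# in plot_figure.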
67+
68+
class MOGANTrainer:
    def __init__(self, lr, num_epochs, batch_size, n_obj, pref, input_dim, output_dim):
        '''
        :param lr, float: learning rate.
        :param num_epochs, int: number of epochs.
        :param batch_size, int: batch size.
        :param n_obj, int: number of objectives.
        :param pref, np.array: preference vector.
        :param input_dim, int: dimension of the latent noise vector.
        :param output_dim, int: dimension of the generated samples.
        '''
        self.lr = lr
        self.pref = pref
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.n_obj = n_obj
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Use the constructor arguments rather than the global `args`
        self.generator = Generator(input_dim, output_dim)
        self.discriminator_arr = [Discriminator(output_dim) for _ in range(n_obj)]
        self.d_optimizer_arr = [optim.Adam(discriminator.parameters(), lr=lr)
                                for discriminator in self.discriminator_arr]
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
        self.criterion = nn.BCELoss()
92+
    def train(self):
        d_loss_arr = []
        for epoch in range(self.num_epochs):
            # Discriminator training; raising the step count above 1 can improve stability
            for _ in range(1):
                # Real samples from the two target Gaussian distributions
                real_samples_1, real_samples_2 = sample_multiple_gaussian(self.batch_size, self.output_dim)
                real_samples_arr = [real_samples_1, real_samples_2]
                z = torch.randn(self.batch_size, self.input_dim)  # Random noise
                fake_samples = self.generator(z)  # Fake samples from the generator

                for idx, discriminator in enumerate(self.discriminator_arr):
                    real_samples = real_samples_arr[idx]
                    d_real = discriminator(real_samples)
                    # Detach to avoid backpropagating through the generator
                    d_fake = discriminator(fake_samples.detach())

                    real_loss = self.criterion(d_real, torch.ones_like(d_real))
                    fake_loss = self.criterion(d_fake, torch.zeros_like(d_fake))
                    d_loss = (real_loss + fake_loss) / 2

                    self.d_optimizer_arr[idx].zero_grad()
                    d_loss.backward()
                    self.d_optimizer_arr[idx].step()

            # Generator training
            z = torch.randn(self.batch_size, self.input_dim)
            fake_samples = self.generator(z)

            g_loss_arr = []
            for idx, discriminator in enumerate(self.discriminator_arr):
                d_fake = discriminator(fake_samples)
                g_loss = self.criterion(d_fake, torch.ones_like(d_fake))
                g_loss_arr.append(g_loss)

            g_loss_arr = torch.stack(g_loss_arr)
            self.g_optimizer.zero_grad()
            scalar_loss = torch.dot(g_loss_arr, torch.tensor(self.pref, dtype=torch.float32))
            scalar_loss.backward()
            self.g_optimizer.step()
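            # The generator step above uses linear scalarization: the stacked losses are
            # collapsed to pref[0] * g_loss_1 + pref[1] * g_loss_2, so the preference
            # vector controls how strongly the generator is pulled toward each Gaussian.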
            # Logging
            if (epoch + 1) % 100 == 0:
                print(f'Epoch [{epoch + 1}/{self.num_epochs}], '
                      f'd_loss: {d_loss.item():.4f}, scalarized g_loss: {scalar_loss.item():.4f}')
            d_loss_arr.append(d_loss.item())
137+
    def generate_samples(self, test_size):
        with torch.no_grad():
            z = torch.randn(test_size, self.input_dim)
            generated_samples = self.generator(z)
            # Real samples from the two target Gaussian distributions, for comparison
            real_samples_1, real_samples_2 = sample_multiple_gaussian(test_size, self.output_dim)
        return generated_samples.numpy(), real_samples_1.numpy(), real_samples_2.numpy()
145+
146+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='example script')
    parser.add_argument('--input-dim', type=int, default=10)  # Dimension of the latent noise vector z
    parser.add_argument('--output-dim', type=int, default=2)  # Dimension of the generated samples
    parser.add_argument('--n-obj', type=int, default=2)

    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--test-size', type=int, default=64)
    parser.add_argument('--num-epochs', type=int, default=500)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--pref0', type=float, default=0.2)

    # Hyperparameters
    args = parser.parse_args()
    pref = np.array([args.pref0, 1 - args.pref0])
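    # pref is a convex combination over the two objectives: pref0 = 1.0 trains the
    # generator purely against the first Gaussian (mean 0), pref0 = 0.0 purely against
    # the second (mean 4), and intermediate values trade off between the two modes.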
    print('Preference: ', pref)
    # Build the trainer (model, optimizers, and loss function)
    trainer = MOGANTrainer(lr=args.lr, num_epochs=args.num_epochs,
                           n_obj=args.n_obj, pref=pref, batch_size=args.batch_size,
                           input_dim=args.input_dim, output_dim=args.output_dim)

    trainer.train()
    generated_samples, sample1, sample2 = trainer.generate_samples(args.test_size)

    folder_name = 'D:\\pycharm_project\\libmoon\\Output\\divergence'
    os.makedirs(folder_name, exist_ok=True)
    plot_figure(folder_name, generated_samples, sample1, sample2, pref)