###
# Compared with ppo_gae_continuous2, using minibatch SGD, separate actor and critic networks.
###

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
# from torch.distributions import Normal
from torch.distributions.normal import Normal
import numpy as np
import wandb
from collections import deque
from torch.utils.tensorboard import SummaryWriter
import argparse
import random

# Hyperparameters
learning_rate = 3e-4
gamma = 0.99
lmbda = 0.95
eps_clip = 0.2
batch_size = 2048
mini_batch = int(batch_size // 32)
K_epoch = 10
T_horizon = 10000
n_epis = 10000
vf_coef = 0.5

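# Rollout/update scheme: each PPO update consumes a buffer of batch_size = 2048 transitions,
# shuffled and split into 32 minibatches of 64, and reuses that buffer for K_epoch = 10 passes.
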
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--run", type=str, default='test',
                        help="the name of this experiment")
    # store_true avoids the argparse pitfall where type=bool treats any non-empty string as True
    parser.add_argument('--wandb_activate', action='store_true', help='whether to log with wandb')
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="the entity (team) of wandb's project")
    args = parser.parse_args()
    print(args)
    return args

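# Orthogonal weight initialization (gain sqrt(2) for the tanh hidden layers); the policy head uses
# a much smaller gain (0.01) and the value head 1.0, a common PPO implementation detail.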
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

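# Gaussian policy: a two-layer tanh MLP outputs the action mean, while the log standard deviation
# is a state-independent learnable parameter shared across states.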
class Actor(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()
        self.mean = nn.Sequential(
            layer_init(nn.Linear(num_inputs, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, num_actions), std=0.01),
        )
        self.logstd = nn.Parameter(torch.zeros(1, num_actions))

    def forward(self, x):
        action_mean = self.mean(x)
        action_logstd = self.logstd.expand_as(action_mean)

        return action_mean.squeeze(), action_logstd.squeeze()

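# State-value critic: a two-layer tanh MLP mapping the observation to a scalar estimate of V(s).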
class Critic(nn.Module):
    def __init__(self, num_inputs, hidden_dim):
        super().__init__()
        self.model = nn.Sequential(
            layer_init(nn.Linear(num_inputs, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, hidden_dim)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.0),
        )

    def forward(self, x):
        return self.model(x)


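# PPO agent with separate actor and critic networks; both are updated by a single Adam optimizer,
# and collected transitions are kept in a fixed-size ring buffer of batch_size steps.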
class PPO():
    def __init__(self, num_inputs, num_actions, hidden_dim):
        self.data = deque(maxlen=batch_size)  # a ring buffer
        self.max_grad_norm = 0.5
        self.v_loss_clip = True

        self.critic = Critic(num_inputs, hidden_dim).to(device)
        self.actor = Actor(num_inputs, num_actions, hidden_dim).to(device)

        self.parameters = list(self.critic.parameters()) + list(self.actor.parameters())
        self.optimizer = optim.Adam(self.parameters, lr=learning_rate, eps=1e-5)

    def pi(self, x):
        return self.actor(x)

    def v(self, x):
        return self.critic(x)

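    # Sample an action from the diagonal Gaussian (or evaluate a given one) and return it together
    # with its log-probability summed over action dimensions and the critic's value estimate.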
    def get_action_and_value(self, x, action=None):
        mean, log_std = self.pi(x)
        std = torch.exp(log_std)
        normal = Normal(mean, std)
        if action is None:
            action = normal.sample()
        log_prob = normal.log_prob(action).sum(-1)
        value = self.v(x)
        return action.cpu().numpy(), log_prob, value

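    # Each stored transition is (s, a, r, s', log pi_old(a|s), V(s), done); make_batch stacks the
    # whole buffer into tensors of shape (batch_size, ...) on the training device.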
    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s, a, r, s_prime, logprob_a, v, done_mask = zip(*self.data)
        s, a, r, s_prime, logprob_a, v, done_mask = \
            torch.tensor(np.array(s), dtype=torch.float), torch.tensor(np.array(a)), \
            torch.tensor(np.array(r), dtype=torch.float).unsqueeze(-1), torch.tensor(np.array(s_prime), dtype=torch.float), \
            torch.tensor(logprob_a).unsqueeze(-1), torch.tensor(v).unsqueeze(-1), \
            torch.tensor(np.array(done_mask), dtype=torch.float).unsqueeze(-1)
        return s.to(device), a.to(device), r.to(device), s_prime.to(device), done_mask.to(device), logprob_a.to(device), v.to(device)

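    # Generalized Advantage Estimation (computed backwards over the rollout):
    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    # The critic's regression target is then td_target = A_t + V(s_t).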
    def train_net(self):
        s, a, r, s_prime, done_mask, logprob_a, v = self.make_batch()
        with torch.no_grad():
            advantage = torch.zeros_like(r).to(device)
            lastgaelam = 0
            for t in reversed(range(s.shape[0])):
                if done_mask[t] or t == s.shape[0] - 1:
                    nextvalues = self.v(s_prime[t])
                else:
                    nextvalues = v[t + 1]
                delta = r[t] + gamma * nextvalues * (1.0 - done_mask[t]) - v[t]
                advantage[t] = lastgaelam = delta + gamma * lmbda * lastgaelam * (1.0 - done_mask[t])
            assert advantage.shape == v.shape
            td_target = advantage + v

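        # PPO clipped surrogate (negated for gradient descent), with the probability ratio
        #   ratio = pi_new(a|s) / pi_old(a|s) = exp(log pi_new - log pi_old):
        #   loss = E[ max(-ratio * A, -clip(ratio, 1 - eps, 1 + eps) * A) ]
        # Advantages are normalized per minibatch (skipped when the std is NaN).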
        # minibatch SGD over the entire buffer (K epochs)
        b_inds = np.arange(batch_size)
        for epoch in range(K_epoch):
            np.random.shuffle(b_inds)
            for start in range(0, batch_size, mini_batch):
                end = start + mini_batch
                minibatch_idx = b_inds[start:end]

                bs, ba, blogprob_a, bv = s[minibatch_idx], a[minibatch_idx], logprob_a[minibatch_idx].reshape(-1), v[minibatch_idx].reshape(-1)
                badvantage, btd_target = advantage[minibatch_idx].reshape(-1), td_target[minibatch_idx].reshape(-1)

                if not torch.isnan(badvantage.std()):
                    badvantage = (badvantage - badvantage.mean()) / (badvantage.std() + 1e-8)

                _, newlogprob_a, new_vs = self.get_action_and_value(bs, ba)
                new_vs = new_vs.reshape(-1)
                ratio = torch.exp(newlogprob_a - blogprob_a)  # a/b == exp(log(a)-log(b))
                surr1 = -ratio * badvantage
                surr2 = -torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * badvantage
                policy_loss = torch.max(surr1, surr2).mean()

                if self.v_loss_clip:  # clipped value loss
                    v_clipped = bv + torch.clamp(new_vs - bv, -eps_clip, eps_clip)
                    value_loss_clipped = (v_clipped - btd_target) ** 2
                    value_loss_unclipped = (new_vs - btd_target) ** 2
                    value_loss_max = torch.max(value_loss_unclipped, value_loss_clipped)
                    value_loss = 0.5 * value_loss_max.mean()
                else:
                    value_loss = F.smooth_l1_loss(new_vs, btd_target)

                loss = policy_loss + vf_coef * value_loss
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
                self.optimizer.step()

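# Environment setup follows the common PPO recipe for MuJoCo: clipped actions, normalized and
# clipped observations and rewards, with RecordEpisodeStatistics applied first so the logged
# episodic return reflects the raw (unnormalized) reward.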
def main():
    args = parse_args()
    env_id = 2
    seed = 1
    env_name = ['HalfCheetah-v2', 'Ant-v2', 'Hopper-v2'][env_id]
    env = gym.make(env_name)
    env = gym.wrappers.RecordEpisodeStatistics(env)  # bypass the reward normalization to record episodic return
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
    env = gym.wrappers.NormalizeReward(env)  # this improves learning significantly
    env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    hidden_dim = 64
    model = PPO(state_dim, action_dim, hidden_dim)
    score = 0.0
    print_interval = 1
    step = 0
    update = 1

    if args.wandb_activate:
        wandb.init(
            project=args.run,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=args.run + f'_{env_name}',
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{args.run}_{env_name}")
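
    # Rollout collection: actions are sampled without gradients; every batch_size environment
    # steps the ring buffer holds one full rollout and train_net() performs a PPO update.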
    for n_epi in range(n_epis):
        s = env.reset()
        done = False
        epi_r = 0.
        ## learning rate schedule
        # frac = 1.0 - (n_epi - 1.0) / n_epis
        # lrnow = frac * learning_rate
        # model.optimizer.param_groups[0]["lr"] = lrnow

        # while not done:
        for t in range(T_horizon):
            step += 1
            with torch.no_grad():
                a, logprob, v = model.get_action_and_value(torch.from_numpy(s).float().unsqueeze(0).to(device))
            s_prime, r, done, info = env.step(a)
            # env.render()

            model.put_data((s, a, r, s_prime, logprob, v.squeeze(-1), done))
            s = s_prime

            score += r

            if step % batch_size == 0 and step > 0:
                model.train_net()
                update += 1
                eff_update = update

            if 'episode' in info.keys():
                epi_r = info['episode']['r']
                print(f"Global steps: {step}, score: {epi_r}")

            if done:
                break

        if n_epi % print_interval == 0 and n_epi != 0:
            # print("Global steps: {}, # of episode :{}, avg score : {:.1f}".format(step, n_epi, score/print_interval))  # this is normalized reward
            writer.add_scalar("charts/episodic_return", epi_r, n_epi)
            writer.add_scalar("charts/episodic_length", t, n_epi)
            writer.add_scalar("charts/update", update, n_epi)

        score = 0.0

    env.close()

if __name__ == '__main__':
    main()
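
# Example invocation (a sketch; assumes gym with the MuJoCo v2 environments installed, and the
# script file name ppo_gae_continuous3.py is hypothetical):
#   python ppo_gae_continuous3.py --run test --wandb_activate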