
Commit 34f6f70

add another version of ppo_gae_continuous
1 parent ca28b7e commit 34f6f70

File tree

2 files changed: +215 −2 lines changed


ppo_gae_continuous.py

Lines changed: 5 additions & 2 deletions
@@ -60,7 +60,7 @@ def pi(self, x):
        x = F.tanh(self.linear1(x))
        x = F.tanh(self.linear2(x))
-       x1 = F.tanh(self.linear3(x))
+       x1 = F.tanh(self.linear3(x).detach()) # std head: gradients are not backpropagated into the shared features
        x2 = F.tanh(self.linear4(x))

        mean = F.tanh(self.mean_linear(x1))
@@ -144,8 +144,11 @@ def train_net(self):
            advantage_lst.append([advantage])
        advantage_lst.reverse()
        advantage = torch.tensor(advantage_lst, dtype=torch.float)
+
        if not np.isnan(advantage.std()):
-           advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5)
+           advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
+
+       # td_target = advantage + self.v(s)

        for i in range(K_epoch):
            mean, log_std = self.pi(s)
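
The .detach() added above stops gradients from the log-std branch from flowing back into the shared feature layers, so only the mean and value losses shape those features. Below is a minimal sketch of that effect, not part of the commit; the toy feature/std_head layers are stand-ins for the shared linear1/linear2 features and the linear3 -> log_std path.

# Illustrative only: how .detach() keeps the std loss from updating shared layers.
import torch
import torch.nn as nn

feature = nn.Linear(4, 8)    # stands in for the shared linear1/linear2 features
std_head = nn.Linear(8, 2)   # stands in for the linear3 -> log_std path

h = torch.tanh(feature(torch.randn(1, 4)))
log_std = std_head(h.detach())        # gradient flow is blocked at the detach()
log_std.sum().backward()

print(feature.weight.grad)            # None: the shared layer is untouched by the std loss
print(std_head.weight.grad is None)   # False: the std head itself still learns

The second hunk also tightens the epsilon in the advantage normalization from 1e-5 to 1e-8, matching the value used in ppo_gae_continuous2.py below.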

ppo_gae_continuous2.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
###
# Similar to ppo_gae_continuous.py, but the update function is changed
# to follow stable-baselines PPO2 (https://stable-baselines.readthedocs.io/en/master/_modules/stable_baselines/ppo2/ppo2.html#PPO2) and cleanrl (https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action.py):
# it tracks the value of each state during sample collection and thus saves computation.
###

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.distributions import Normal
import numpy as np

# Hyperparameters
learning_rate = 1e-4
gamma = 0.99
lmbda = 0.95
eps_clip = 0.1
batch_size = 1280
K_epoch = 10
T_horizon = 10000

class NormalizedActions(gym.ActionWrapper):
    def _action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        action = low + (action + 1.0) * 0.5 * (high - low)
        action = np.clip(action, low, high)

        return action

    def _reverse_action(self, action):
        low = self.action_space.low
        high = self.action_space.high

        action = 2 * (action - low) / (high - low) - 1
        action = np.clip(action, low, high)

        return action

class PPO(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, action_range=1.):
        super(PPO, self).__init__()
        self.data = []
        self.action_range = action_range

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, hidden_size)
        self.linear5 = nn.Linear(hidden_size, hidden_size)
        self.linear6 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        # self.log_std_param = nn.Parameter(torch.zeros(num_actions, requires_grad=True))

        self.v_linear = nn.Linear(hidden_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x):
        x = F.tanh(self.linear1(x))
        x = F.tanh(self.linear2(x))
        x1 = F.tanh(self.linear3(x).detach()) # std head: gradients are not backpropagated into the shared features
        x2 = F.tanh(self.linear4(x))

        mean = F.tanh(self.mean_linear(x1))
        log_std = self.log_std_linear(x2)
        # log_std = self.log_std_param.expand_as(mean)

        return mean, log_std

    def v(self, x):
        x = F.tanh(self.linear1(x))
        x = F.tanh(self.linear2(x))
        x = F.tanh(self.linear5(x))
        x = F.tanh(self.linear6(x))

        v = self.v_linear(x)
        return v

    def get_action(self, x):
        mean, log_std = self.pi(x)
        std = log_std.exp()
        normal = Normal(mean, std)
        action = normal.sample()
        log_prob = normal.log_prob(action).sum(-1)
        prob = log_prob.exp()

        ## The following way of generating the action is not correct:
        ## all dimensions of the action depend on the same hidden variable z.
        ## In some envs like Ant-v2 this may keep the agent from falling easily, due to the correlation of actions,
        ## but it does not hold true in general and may cause numerical problems (nan) in the update.
        # normal = Normal(0, 1)
        # z = normal.sample()
        # action = mean + std*z
        # log_prob = Normal(mean, std).log_prob(action)
        # log_prob = log_prob.sum(dim=-1, keepdim=True) # reduce dim
        # prob = log_prob.exp()

        action = self.action_range * action # scale the action
        value = self.v(x).detach().numpy()
        return action.detach().numpy(), prob, value

    def get_log_prob(self, mean, log_std, action):
        action = action / self.action_range
        log_prob = Normal(mean, log_std.exp()).log_prob(action)
        log_prob = log_prob.sum(dim=-1, keepdim=True) # reduce dim
        return log_prob

    def put_data(self, transition):
        self.data.append(transition)

    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, value_lst, done_lst = [], [], [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, prob_a, v, done = transition

            s_lst.append(s)
            a_lst.append(a)
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            value_lst.append([v])
            done_mask = 0 if done else 1
            done_lst.append([done_mask])
        s, a, r, s_prime, v, done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
            torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), \
            torch.tensor(value_lst), torch.tensor(done_lst, dtype=torch.float), torch.tensor(prob_a_lst)
        self.data = []
        return s, a, r, s_prime, done_mask, prob_a, v

    def train_net(self):
        s, a, r, s_prime, done_mask, prob_a, v = self.make_batch()
        done_mask_ = torch.flip(done_mask, dims=(0,))
        with torch.no_grad():
            advantage = torch.zeros_like(r)
            lastgaelam = 0
            for t in reversed(range(s.shape[0] - 1)):
                if done_mask[t+1]:
                    nextvalues = self.v(s[t+1])
                else:
                    nextvalues = v[t+1]
                delta = r[t] + gamma * nextvalues * done_mask_[t+1] - v[t]
                advantage[t] = lastgaelam = delta + gamma * lmbda * lastgaelam * done_mask_[t+1]

        if not np.isnan(advantage.std()):
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        td_target = advantage + self.v(s)

        for i in range(K_epoch):
            mean, log_std = self.pi(s)
            log_pi_a = self.get_log_prob(mean, log_std, a)
            # pi = self.pi(s, softmax_dim=1)
            # pi_a = pi.gather(1,a)
            ratio = torch.exp(log_pi_a - torch.log(prob_a)) # a/b == exp(log(a)-log(b))
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target.detach())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

def main():
    # env = gym.make('HalfCheetah-v2')
    env = gym.make('Ant-v2')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    hidden_dim = 128
    model = PPO(state_dim, action_dim, hidden_dim)
    score = 0.0
    print_interval = 2
    step = 0

    for n_epi in range(10000):
        s = env.reset()
        done = False
        # while not done:
        for t in range(T_horizon):
            step += 1
            a, prob, v = model.get_action(torch.from_numpy(s).float())
            s_prime, r, done, info = env.step(a)
            # print(a)
            # env.render()

            model.put_data((s, a, r, s_prime, prob, v, done))
            s = s_prime

            score += r

            if (step + 1) % batch_size == 0:
                model.train_net()

            if done:
                break
        if n_epi % print_interval == 0 and n_epi != 0:
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, score / print_interval))
            score = 0.0

    env.close()

if __name__ == '__main__':
    main()
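
As described in the file's header comment, the values cached by get_action let train_net run the GAE recursion without re-evaluating the critic for every state in the rollout. For reference, here is a minimal, self-contained sketch of that recursion (a NumPy version; the function name compute_gae, its argument layout, and the convention "dones[t] = 1 if the episode terminated at step t" are illustrative assumptions, not part of the commit):

# Illustrative sketch of GAE(lambda) computed from values cached at collection time.
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """rewards/values/dones: length-T arrays from the rollout buffer.
    dones[t] = 1.0 if the episode terminated at step t.
    last_value: critic estimate for the state after the final step (bootstrap)."""
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        nonterminal = 1.0 - dones[t]          # no bootstrapping across episode ends
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
        advantages[t] = lastgaelam
    returns = advantages + values             # value-function regression targets
    return advantages, returns

# Tiny example: three steps, episode terminates on the last one.
adv, ret = compute_gae(np.array([1.0, 1.0, 1.0]),
                       np.array([0.5, 0.4, 0.3]),
                       np.array([0.0, 0.0, 1.0]),
                       last_value=0.0)

In the committed train_net the same backward recursion runs over the batched tensors under torch.no_grad(), using the per-step v cached by get_action, and the resulting td_target serves as the regression target for the smooth-L1 value loss.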
