
Commit 56f0160

fix pmoe_ppo
1 parent fd11b43 commit 56f0160

File tree

4 files changed: +48 -19 lines changed


pmoe_ppo.py

Lines changed: 46 additions & 19 deletions
@@ -1,6 +1,11 @@
-###
-# Compared with ppo_gae_continuous2, using minibatch SGD, separate actor and critic networks.
-###
+'''
+Probabilistic Mixture-of-Experts (PMOE)
+paper: https://arxiv.org/abs/2104.09122
+
+Core features:
+It replaces the diagonal Gaussian distribution with a (differentiable) Gaussian mixture model for policy function approximation, which is more expressive.
+This version is based on the on-policy PPO algorithm.
+'''
 
 import gym
 import torch
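
For reference, a minimal sketch of the idea in the new docstring: a policy head that outputs K Gaussian components plus mixing coefficients instead of a single diagonal Gaussian. Shapes, layer sizes and the class name GMMPolicyHead are illustrative assumptions, not the repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GMMPolicyHead(nn.Module):
    """Toy mixture-of-Gaussians policy head with K experts."""
    def __init__(self, state_dim, action_dim, hidden_dim, k):
        super().__init__()
        self.k, self.action_dim = k, action_dim
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.Tanh())
        self.mean = nn.Linear(hidden_dim, k * action_dim)        # K component means
        self.log_std = nn.Parameter(torch.zeros(k, action_dim))  # state-independent log-stds
        self.mix_logits = nn.Linear(hidden_dim, k)               # mixing coefficients (pre-softmax)

    def forward(self, x):
        h = self.body(x)
        mean = self.mean(h).view(*x.shape[:-1], self.k, self.action_dim)
        mix_coef = F.softmax(self.mix_logits(h), dim=-1)
        return mean, self.log_std.expand_as(mean), mix_coef

head = GMMPolicyHead(state_dim=3, action_dim=1, hidden_dim=64, k=5)
mean, log_std, mix_coef = head(torch.randn(3))
print(mean.shape, log_std.shape, mix_coef.shape)  # (5, 1), (5, 1), (5,)
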
@@ -134,18 +139,28 @@ def get_action_and_value(self, x, action=None, select_from_mixture=True, track_g
         mean, log_std, mix_coef = self.pi(x)
         std = torch.exp(log_std)
         normal = Normal(mean, std)
+
+        full_action = normal.sample()
+        if select_from_mixture:
+            mix_dist = Categorical(mix_coef)
+            index = mix_dist.sample()
+            a = full_action[index]
+        else:
+            a = full_action
+
         if action is None:
-            full_action = normal.sample()
-            if select_from_mixture:
-                index = Categorical(mix_coef).sample()
-                action = full_action[index]
-            # log_prob = normal.log_prob(full_action).sum((-1, -2))  # TODO prob of other entries?
-            log_prob = normal.log_prob(action).sum(-1)  # TODO prob of other entries?
+            a_for_prob = a.unsqueeze(-2)  # to (1, action_dim), matching mean and std of shape (K, action_dim)
+            log_prob = (mix_coef @ normal.log_prob(a_for_prob).sum(-1).exp()).log()
+        else:  # works for a batch of given actions
+            a_for_prob = action.unsqueeze(-2)  # use the given action for calculating its probability
+            log_prob = torch.einsum('ij,ij->i', mix_coef, normal.log_prob(a_for_prob).sum(-1).exp()).log()  # prob of the action under the whole GMM, including the mixing coefficients
+
         value = self.v(x)
         if track_grad:
-            return action, log_prob, value, mix_coef
+            return a, log_prob, value, mix_coef
         else:
-            return action.cpu().numpy(), log_prob, value, mix_coef
+            return a.cpu().numpy(), log_prob, value, mix_coef
 
     def put_data(self, transition):
         self.data.append(transition)
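
The new log_prob lines above compute the density of an action under the full Gaussian mixture, sum_k mix_coef_k * prod_d N(a_d | mean_kd, std_kd), rather than under a single component. A sanity check with illustrative values (not the repository's tensors) against torch.distributions.MixtureSameFamily:

import torch
from torch.distributions import Normal, Categorical, Independent, MixtureSameFamily

K, D = 5, 2                                   # experts, action dimensions
mean, std = torch.randn(K, D), torch.rand(K, D) + 0.1
mix_coef = torch.softmax(torch.randn(K), dim=-1)
a = torch.randn(D)

# manual GMM log-density, mirroring the expression in the diff
normal = Normal(mean, std)
manual = (mix_coef @ normal.log_prob(a.unsqueeze(-2)).sum(-1).exp()).log()

# reference value from torch's built-in mixture distribution
gmm = MixtureSameFamily(Categorical(probs=mix_coef), Independent(Normal(mean, std), 1))
print(torch.allclose(manual, gmm.log_prob(a)))  # True (up to numerical error)

One caveat of the exp()-then-log() form is that it can underflow for very unlikely actions; a logsumexp over log mix_coef plus the component log-probabilities is the numerically safer equivalent.
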
@@ -159,6 +174,7 @@ def make_batch(self,):
 
     def train_net(self):
         s, a, r, s_prime, done_mask, logprob_a, v = self.make_batch()
+        loss_list = []
         with torch.no_grad():
             advantage = torch.zeros_like(r).to(device)
             lastgaelam = 0
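
For context, the advantage / lastgaelam initialisation above feeds the standard GAE backward recursion. A self-contained sketch of that recursion (the names gamma, lmbda and the done-mask convention are assumptions, not taken from this file):

import torch

def gae(r, v, v_next, done_mask, gamma=0.99, lmbda=0.95):
    """Generalised Advantage Estimation over a rollout of length T."""
    T = r.shape[0]
    advantage = torch.zeros_like(r)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        delta = r[t] + gamma * v_next[t] * done_mask[t] - v[t]          # TD error
        lastgaelam = delta + gamma * lmbda * done_mask[t] * lastgaelam  # discounted accumulation
        advantage[t] = lastgaelam
    return advantage, advantage + v  # advantages and the corresponding value targets

T = 8
r, v, v_next = torch.randn(T), torch.randn(T), torch.randn(T)
done_mask = torch.ones(T)  # 1.0 while the episode continues, 0.0 at terminal steps
adv, td_target = gae(r, v, v_next, done_mask)
print(adv.shape, td_target.shape)  # torch.Size([8]) torch.Size([8])
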
@@ -189,20 +205,19 @@ def train_net(self):
                 # get mixing coefficients loss
                 new_a, newlogprob_a, new_vs, new_mix_coef = self.get_action_and_value(bs, ba, select_from_mixture=False, track_grad=True)
                 new_q = self.q_net(bs, new_a, match_shape=True)
-                _, best_index = new_q.max(1)
-                print(new_q.shape, best_index.shape, new_mix_coef.shape)
-                coef_loss = F.mse_loss(new_mix_coef, F.one_hot(best_index, self.mix_num).float())
+                _, best_index = new_q.max(-1)
+                coef_loss = F.mse_loss(new_mix_coef, F.one_hot(best_index, self.mix_num).float()).mean()
 
                 # Q-net loss
-                pred_q = self.q_net(bs, ba)
-                q_loss = F.mse_loss(pred_q, btd_target)
-
+                pred_q = self.q_net(bs, ba).squeeze()
+                q_loss = F.mse_loss(pred_q, btd_target).mean()
 
                 new_vs = new_vs.reshape(-1)
                 ratio = torch.exp(newlogprob_a - blogprob_a)  # a/b == exp(log(a)-log(b))
                 surr1 = -ratio * badvantage
                 surr2 = -torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * badvantage
                 policy_loss = torch.max(surr1, surr2).mean()
 
                 if self.v_loss_clip:  # clipped value loss
                     v_clipped = bv + torch.clamp(new_vs - bv, -eps_clip, eps_clip)
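
A small illustration (batch size and expert count are assumed, not the repository's exact tensors) of the mixing-coefficient loss above: for each state the target is a one-hot vector over the K experts selecting the one whose sampled action the Q-network scores highest, and the predicted mixing coefficients are regressed onto it.

import torch
import torch.nn.functional as F

batch, K = 4, 5
new_q = torch.randn(batch, K)                    # Q(s, a_k) for each expert's sampled action
new_mix_coef = torch.softmax(torch.randn(batch, K), dim=-1)

_, best_index = new_q.max(-1)                    # (batch,) index of the best-scoring expert
target = F.one_hot(best_index, K).float()        # (batch, K) one-hot targets
coef_loss = F.mse_loss(new_mix_coef, target)     # already a scalar mean over the batch
print(coef_loss)

Since F.mse_loss reduces to a mean by default, the extra .mean() added in the diff is harmless but redundant.
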
@@ -218,6 +233,8 @@ def train_net(self):
                 loss.backward()
                 nn.utils.clip_grad_norm_(self.parameters, self.max_grad_norm)
                 self.optimizer.step()
+        loss_list = [coef_loss.item(), q_loss.item(), policy_loss.item(), value_loss.item()]
+        return loss_list
 
 def main():
     args = parse_args()
@@ -234,18 +251,21 @@ def main():
     env.seed(seed)
     env.action_space.seed(seed)
     env.observation_space.seed(seed)
+    print(env.observation_space, env.action_space)
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
     state_dim = env.observation_space.shape[0]
     action_dim = env.action_space.shape[0]
     hidden_dim = 64
-    model = PMOE_PPO(state_dim, action_dim, hidden_dim)
+    mix_num = 5  # number of experts in the mixture
+    model = PMOE_PPO(state_dim, action_dim, hidden_dim, mix_num)
     score = 0.0
     print_interval = 1
     step = 0
     update = 1
+    loss_list = []
 
     if args.wandb_activate:
         wandb.init(
@@ -277,12 +297,13 @@ def main():
             # env.render()
 
             model.put_data((s, a, r, s_prime, logprob, v.squeeze(-1), done))
+
             s = s_prime
 
             score += r
 
            if step % batch_size == 0 and step > 0:
-                model.train_net()
+                loss_list = model.train_net()
                 update += 1
                 eff_update = update
 
@@ -298,6 +319,12 @@ def main():
         writer.add_scalar("charts/episodic_return", epi_r, n_epi)
         writer.add_scalar("charts/episodic_length", t, n_epi)
         writer.add_scalar("charts/update", update, n_epi)
+        if len(loss_list) > 0:
+            writer.add_scalar("charts/coeff_loss", loss_list[0], n_epi)
+            writer.add_scalar("charts/Q_loss", loss_list[1], n_epi)
+            writer.add_scalar("charts/policy_loss", loss_list[2], n_epi)
+            writer.add_scalar("charts/value_loss", loss_list[3], n_epi)
 
         score = 0.0
 
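
A minimal usage sketch of the logging pattern added above (the log directory and the dummy loss values are illustrative): train_net() returns [coef_loss, q_loss, policy_loss, value_loss], and each entry is written to TensorBoard once at least one update has run.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/pmoe_ppo_demo")   # assumed log directory
loss_list = [0.12, 0.34, 0.05, 0.27]           # dummy values standing in for train_net()'s output
n_epi = 0
if len(loss_list) > 0:
    for tag, value in zip(["coeff_loss", "Q_loss", "policy_loss", "value_loss"], loss_list):
        writer.add_scalar(f"charts/{tag}", value, n_epi)
writer.close()
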
pmoe.py renamed to pmoe_sac.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 Core features:
 It replaces the diagonal Gaussian distribution with a (differentiable) Gaussian mixture model for policy function approximation, which is more expressive.
+This version is based on the off-policy SAC algorithm.
 '''
 
 import argparse

ppo_continuous_multiprocess2.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@
 ##################### hyper parameters ####################
 
 ENV_NAME = 'LunarLanderContinuous-v2'  # environment name: LunarLander-v2, Pendulum-v0
+
 RANDOMSEED = 2  # random seed
 
 EP_MAX = 1000  # total number of episodes for training

ppo_multi.png

55.5 KB
