Skip to content

Commit 14828ed

Browse files
committed
fix ppo_gae_continuous2
1 parent 791354a commit 14828ed

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ppo_gae_continuous2.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -166,8 +166,7 @@ def make_batch(self):
166166
s_prime_lst.append(s_prime)
167167
prob_a_lst.append([prob_a])
168168
value_lst.append(v)
169-
done_mask = 0 if done else 1
170-
done_lst.append([done_mask])
169+
done_lst.append([done])
171170
s,a,r,s_prime,v,done_mask,prob_a = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
172171
torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), \
173172
torch.tensor(value_lst), torch.tensor(done_lst, dtype=torch.float), torch.tensor(prob_a_lst)
@@ -176,20 +175,21 @@ def make_batch(self):
176175

177176
def train_net(self):
178177
s, a, r, s_prime, done_mask, prob_a, v = self.make_batch()
179-
done_mask_ = torch.flip(done_mask, dims=(0,))
180178
with torch.no_grad():
181179
advantage = torch.zeros_like(r)
182180
lastgaelam = 0
183-
for t in reversed(range(s.shape[0]-1)):
184-
if done_mask[t+1]:
185-
nextvalues = self.v(s[t+1])
181+
182+
for t in reversed(range(s.shape[0])):
183+
if done_mask[t] or t == s.shape[0]-1:
184+
nextvalues = self.v(s_prime[t])
186185
else:
187186
nextvalues = v[t+1]
188-
delta = r[t] + gamma * nextvalues * done_mask_[t+1] - v[t]
189-
advantage[t] = lastgaelam = delta + gamma * lmbda * lastgaelam * done_mask_[t+1]
187+
delta = r[t] + gamma * nextvalues - v[t]
188+
advantage[t] = lastgaelam = delta + gamma * lmbda * lastgaelam
190189

191190
if not np.isnan(advantage.std()):
192191
advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
192+
assert advantage.shape == v.shape
193193
td_target = advantage + v
194194

195195
for i in range(K_epoch):

0 commit comments

Comments (0)