@@ -166,8 +166,7 @@ def make_batch(self):
166
166
s_prime_lst .append (s_prime )
167
167
prob_a_lst .append ([prob_a ])
168
168
value_lst .append (v )
169
- done_mask = 0 if done else 1
170
- done_lst .append ([done_mask ])
169
+ done_lst .append ([done ])
171
170
s ,a ,r ,s_prime ,v ,done_mask ,prob_a = torch .tensor (s_lst , dtype = torch .float ), torch .tensor (a_lst ), \
172
171
torch .tensor (r_lst , dtype = torch .float ), torch .tensor (s_prime_lst , dtype = torch .float ), \
173
172
torch .tensor (value_lst ), torch .tensor (done_lst , dtype = torch .float ), torch .tensor (prob_a_lst )
@@ -176,20 +175,21 @@ def make_batch(self):
176
175
177
176
def train_net (self ):
178
177
s , a , r , s_prime , done_mask , prob_a , v = self .make_batch ()
179
- done_mask_ = torch .flip (done_mask , dims = (0 ,))
180
178
with torch .no_grad ():
181
179
advantage = torch .zeros_like (r )
182
180
lastgaelam = 0
183
- for t in reversed (range (s .shape [0 ]- 1 )):
184
- if done_mask [t + 1 ]:
185
- nextvalues = self .v (s [t + 1 ])
181
+
182
+ for t in reversed (range (s .shape [0 ])):
183
+ if done_mask [t ] or t == s .shape [0 ]- 1 :
184
+ nextvalues = self .v (s_prime [t ])
186
185
else :
187
186
nextvalues = v [t + 1 ]
188
- delta = r [t ] + gamma * nextvalues * done_mask_ [ t + 1 ] - v [t ]
189
- advantage [t ] = lastgaelam = delta + gamma * lmbda * lastgaelam * done_mask_ [ t + 1 ]
187
+ delta = r [t ] + gamma * nextvalues - v [t ]
188
+ advantage [t ] = lastgaelam = delta + gamma * lmbda * lastgaelam
190
189
191
190
if not np .isnan (advantage .std ()):
192
191
advantage = (advantage - advantage .mean ()) / (advantage .std () + 1e-8 )
192
+ assert advantage .shape == v .shape
193
193
td_target = advantage + v
194
194
195
195
for i in range (K_epoch ):
0 commit comments