Skip to content

Commit a3d553d

Browse files
committed
test
1 parent 0a3a121 commit a3d553d

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

ppo_gae_continuous.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def train_net(self):
140140
# pi = self.pi(s, softmax_dim=1)
141141
# pi_a = pi.gather(1,a)
142142
ratio = torch.exp(log_pi_a - torch.log(prob_a)) # a/b == exp(log(a)-log(b))
143-
144143
surr1 = ratio * advantage
145144
surr2 = torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * advantage
146145
loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach())
@@ -150,7 +149,8 @@ def train_net(self):
150149
self.optimizer.step()
151150

152151
def main():
153-
env = gym.make('HalfCheetah-v2')
152+
# env = gym.make('HalfCheetah-v2')
153+
env = gym.make('Ant-v2')
154154
state_dim = env.observation_space.shape[0]
155155
action_dim = env.action_space.shape[0]
156156
hidden_dim = 128
@@ -180,7 +180,6 @@ def main():
180180

181181
if done:
182182
break
183-
184183
if n_epi%print_interval==0 and n_epi!=0:
185184
print("# of episode :{}, avg score : {:.1f}".format(n_epi, score/print_interval))
186185
score = 0.0

0 commit comments

Comments
 (0)