'''
Soft Actor-Critic version 2
uses target Q networks instead of a V network: 2 Q nets, 2 target Q nets, 1 policy net
adds an alpha (entropy temperature) loss compared with version 1
paper: https://arxiv.org/pdf/1812.05905.pdf

Discrete version reference:
https://towardsdatascience.com/adapting-soft-actor-critic-for-discrete-action-spaces-a20614d4a50a
'''
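# Objectives implemented in SAC_Trainer.update() below (discrete-action SAC):
#   soft state value:  V(s)  = sum_a pi(a|s) * ( min_i Q_i(s, a) - alpha * log pi(a|s) )
#   Q target:          y     = r + gamma * (1 - done) * V(s')
#   policy loss:       J_pi  = E_s[ sum_a pi(a|s) * ( alpha * log pi(a|s) - min_i Q_i(s, a) ) ]
#   alpha loss:        J_alpha = E[ -log_alpha * ( log pi(a|s) + target_entropy ) ]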


import random
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from IPython.display import clear_output
import matplotlib.pyplot as plt
import argparse
from common.buffers import ReplayBufferPER

GPU = True
device_idx = 0
if GPU:
    device = torch.device("cuda:" + str(device_idx) if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(device)


parser = argparse.ArgumentParser(description='Train or test neural net motor controller.')
parser.add_argument('--train', dest='train', action='store_true', default=False)
parser.add_argument('--test', dest='test', action='store_true', default=False)

args = parser.parse_args()

class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, num_actions)

        self.linear4.weight.data.uniform_(-init_w, init_w)
        self.linear4.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = torch.tanh(self.linear1(state))
        x = torch.tanh(self.linear2(x))
        # x = torch.tanh(self.linear3(x))
        x = self.linear4(x)
        return x


class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # self.linear3 = nn.Linear(hidden_size, hidden_size)
        # self.linear4 = nn.Linear(hidden_size, hidden_size)

        self.output = nn.Linear(hidden_size, num_actions)

        self.num_actions = num_actions

    def forward(self, state, softmax_dim=-1):
        x = torch.tanh(self.linear1(state))
        x = torch.tanh(self.linear2(x))
        # x = torch.tanh(self.linear3(x))
        # x = torch.tanh(self.linear4(x))

        probs = F.softmax(self.output(x), dim=softmax_dim)

        return probs

    def evaluate(self, state, epsilon=1e-6):
        '''
        return the log-probabilities of all actions for the given (batched) states
        under the current policy network
        '''
        probs = self.forward(state, softmax_dim=-1)
        log_probs = torch.log(probs + epsilon)  # epsilon avoids log(0) = -inf for near-zero probabilities
        return log_probs

    def get_action(self, state, deterministic):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        probs = self.forward(state)
        dist = Categorical(probs)

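        # greedy (argmax) action for evaluation, otherwise sample from the
        # categorical policy for exploration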
        if deterministic:
            action = np.argmax(probs.detach().cpu().numpy())
        else:
            action = dist.sample().squeeze().detach().cpu().numpy()
        return action


class SAC_Trainer():
    def __init__(self, replay_buffer, hidden_dim):
        self.replay_buffer = replay_buffer

        self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.log_alpha = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=device)
        print('Soft Q Network (1,2): ', self.soft_q_net1)
        print('Policy Network: ', self.policy_net)

        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(param.data)

        self.soft_q_criterion1 = nn.MSELoss(reduction="none")
        self.soft_q_criterion2 = nn.MSELoss(reduction="none")

        soft_q_lr = 3e-4
        policy_lr = 3e-4
        alpha_lr = 3e-4

        self.soft_q_optimizer1 = optim.Adam(self.soft_q_net1.parameters(), lr=soft_q_lr)
        self.soft_q_optimizer2 = optim.Adam(self.soft_q_net2.parameters(), lr=soft_q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

    def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy=-2, gamma=0.99, soft_tau=1e-2):
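        # one gradient step each for: the two Q nets (soft Bellman residual),
        # the policy, and the entropy temperature alpha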
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
        # print('sample:', state, action, reward, done)

        state = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action = torch.Tensor(action).to(torch.int64).to(device)  # discrete action indices
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)  # rewards are scalars; unsqueeze(1) gives shape (batch, 1)
        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)
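        # the Q nets output one value per discrete action; gather() selects
        # Q(s, a) for the action actually taken in each transition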
        predicted_q_value1 = self.soft_q_net1(state)
        predicted_q_value1 = predicted_q_value1.gather(1, action.unsqueeze(-1))
        predicted_q_value2 = self.soft_q_net2(state)
        predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))
        log_prob = self.policy_net.evaluate(state)
        next_log_prob = self.policy_net.evaluate(next_state)
        # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6)  # normalize with batch mean and std; add a small number to prevent numerical problems

        # Training Q Function
        self.alpha = self.log_alpha.exp()
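        # discrete soft Bellman backup: the expectation over next actions is computed
        # exactly by summing pi(a'|s') * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))
        # over the action dimension, instead of sampling a' as in the continuous case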
        target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)
        target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only the reward remains
        q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach())  # detach: no gradients through the target
        q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())
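        # per-sample TD errors (the smaller of the two squared errors) become the new
        # priorities; update_weights() is assumed to take one scalar per transition,
        # in the same order as the batch returned by sample()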
        weight_update = [min(l1.item(), l2.item()) for l1, l2 in zip(q_value_loss1, q_value_loss2)]
        self.replay_buffer.update_weights(weight_update)  # update sample priorities with the TD error

        self.soft_q_optimizer1.zero_grad()
        q_value_loss1.mean().backward()
        self.soft_q_optimizer1.step()
        self.soft_q_optimizer2.zero_grad()
        q_value_loss2.mean().backward()
        self.soft_q_optimizer2.step()

        # Training Policy Function
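        # policy objective: E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min Q(s, a)) ];
        # with discrete actions the expectation over a is computed in closed form from the probabilities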
        predicted_new_q_value = torch.min(self.soft_q_net1(state), self.soft_q_net2(state))
        policy_loss = (log_prob.exp() * (self.alpha * log_prob - predicted_new_q_value)).sum(dim=-1).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Updating alpha wrt entropy
        # alpha = 0.0  # trade-off between exploration (max entropy) and exploitation (max Q)
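        # temperature update: alpha increases when the policy's log-probabilities indicate
        # less entropy than target_entropy, and decreases otherwise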
        if auto_entropy:
            alpha_loss = -(self.log_alpha * (log_prob + target_entropy).detach()).mean()
            # print('alpha loss: ', alpha_loss)
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
        else:
            self.alpha = 1.
            alpha_loss = 0

        # print('q loss: ', q_value_loss1.mean().item(), q_value_loss2.mean().item())
        # print('policy loss: ', policy_loss.item())

        # Soft update the target value net
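        # Polyak averaging: theta_target <- (1 - tau) * theta_target + tau * theta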
        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(  # copy data value into target parameters
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(  # copy data value into target parameters
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )

        return predicted_new_q_value.mean()


    def save_model(self, path):
        torch.save(self.soft_q_net1.state_dict(), path+'_q1')
        torch.save(self.soft_q_net2.state_dict(), path+'_q2')
        torch.save(self.policy_net.state_dict(), path+'_policy')

    def load_model(self, path):
        self.soft_q_net1.load_state_dict(torch.load(path+'_q1'))
        self.soft_q_net2.load_state_dict(torch.load(path+'_q2'))
        self.policy_net.load_state_dict(torch.load(path+'_policy'))

        self.soft_q_net1.eval()
        self.soft_q_net2.eval()
        self.policy_net.eval()


def plot(rewards):
    clear_output(True)
    plt.figure(figsize=(20, 5))
    plt.plot(rewards)
    plt.savefig('sac_v2.png')
    # plt.show()


replay_buffer_size = int(1e6)  # buffer capacity as an integer
replay_buffer = ReplayBufferPER(replay_buffer_size)
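# ReplayBufferPER (from common.buffers) is assumed to provide push(), sample(),
# update_weights() and __len__() with the call signatures used in this file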

# choose env
env = gym.make('CartPole-v1')

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n  # discrete

# hyper-parameters for RL training
max_episodes = 10000
max_steps = 200
frame_idx = 0
batch_size = 256
update_itr = 1
AUTO_ENTROPY = True
DETERMINISTIC = False
hidden_dim = 64
rewards = []
model_path = './model/sac_discrete_v2'
target_entropy = -1. * action_dim
# target_entropy = 0.98 * -np.log(1 / action_dim)
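# target_entropy = -|A| mirrors the heuristic used for continuous SAC; the commented
# alternative above targets 98% of the maximum entropy log|A| = -log(1/|A|),
# a common choice in discrete-SAC implementations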

sac_trainer = SAC_Trainer(replay_buffer, hidden_dim=hidden_dim)

if __name__ == '__main__':
    if args.train:
        # training loop
        for eps in range(max_episodes):
            state = env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
                next_state, reward, done, _ = env.step(action)
                # env.render()

                replay_buffer.push(state, action, reward, next_state, done)

                state = next_state
                episode_reward += reward
                frame_idx += 1

                if len(replay_buffer) > batch_size:
                    for i in range(update_itr):
                        _ = sac_trainer.update(batch_size, reward_scale=1., auto_entropy=AUTO_ENTROPY, target_entropy=target_entropy)

                if done:
                    break

            if eps % 20 == 0 and eps > 0:  # plot and model saving interval
                plot(rewards)
                np.save('rewards', rewards)
                sac_trainer.save_model(model_path)
            print('Episode: ', eps, '| Episode Reward: ', episode_reward, '| Episode Length: ', step)
            rewards.append(episode_reward)
        sac_trainer.save_model(model_path)

    if args.test:
        sac_trainer.load_model(model_path)
        for eps in range(10):
            state = env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
                next_state, reward, done, _ = env.step(action)
                env.render()

                episode_reward += reward
                state = next_state
                if done:  # end the episode on termination, as in the training loop
                    break

            print('Episode: ', eps, '| Episode Reward: ', episode_reward)