Commit 4dda8fa

make max bound SAC
1 parent 451e50b commit 4dda8fa

File tree: 4 files changed, +352 -4 lines changed
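
In short, this commit adds a "max bound" variant of SAC (BSA1C) whose critic loss penalizes Q-estimates that exceed an upper bound computed by robot_nav.utils.get_max_bound (that helper is not part of this diff). Below is a minimal sketch of just that penalty term, using placeholder tensors in place of the critic output, the Bellman target, and the bound:

import torch
import torch.nn.functional as F

# Placeholder tensors: in BSA1C these come from the critic, the SAC Bellman
# target, and robot_nav.utils.get_max_bound respectively.
current_Q = torch.tensor([[10.0], [120.0], [50.0]])
target_Q = torch.tensor([[12.0], [90.0], [55.0]])
max_bound = torch.tensor([[100.0], [100.0], [100.0]])
bound_weight = 0.25  # default weight used in this commit

# Only Q-estimates above the bound contribute to the penalty.
max_excess_Q = F.relu(current_Q - max_bound)
max_bound_loss = bound_weight * (max_excess_Q ** 2).mean()

critic_loss = F.mse_loss(current_Q, target_Q) + max_bound_loss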

robot_nav/models/SAC/BSA1C.py

Lines changed: 316 additions & 0 deletions
@@ -0,0 +1,316 @@
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from statistics import mean
import robot_nav.models.SAC.SAC_utils as utils
from robot_nav.models.SAC.BSA1C_critic import QCritic as critic_model
from robot_nav.models.SAC.SAC_actor import DiagGaussianActor as actor_model
from torch.utils.tensorboard import SummaryWriter
from robot_nav.utils import get_max_bound


class BSA1C(object):
    """SAC algorithm with an upper-bound penalty on the critic's Q-estimates."""

    def __init__(
        self,
        state_dim,
        action_dim,
        device,
        max_action,
        discount=0.99,
        init_temperature=0.1,
        alpha_lr=1e-4,
        alpha_betas=(0.9, 0.999),
        actor_lr=1e-4,
        actor_betas=(0.9, 0.999),
        actor_update_frequency=1,
        critic_lr=1e-4,
        critic_betas=(0.9, 0.999),
        critic_tau=0.005,
        critic_target_update_frequency=2,
        learnable_temperature=True,
        save_every=0,
        load_model=False,
        log_dist_and_hist=False,
        save_directory=Path("robot_nav/models/SAC/checkpoint"),
        model_name="BSAC",
        load_directory=Path("robot_nav/models/SAC/checkpoint"),
        bound_weight=0.25,
    ):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_range = (-max_action, max_action)
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.learnable_temperature = learnable_temperature
        self.save_every = save_every
        self.model_name = model_name
        self.save_directory = save_directory
        self.log_dist_and_hist = log_dist_and_hist
        self.bound_weight = bound_weight

        self.train_metrics_dict = {
            "train_critic/loss_av": [],
            "train_actor/loss_av": [],
            "train_actor/target_entropy_av": [],
            "train_actor/entropy_av": [],
            "train_alpha/loss_av": [],
            "train_alpha/value_av": [],
            "train/batch_reward_av": [],
        }

        self.critic = critic_model(
            obs_dim=self.state_dim,
            action_dim=action_dim,
            hidden_dim=400,
            hidden_depth=2,
        ).to(self.device)
        self.critic_target = critic_model(
            obs_dim=self.state_dim,
            action_dim=action_dim,
            hidden_dim=400,
            hidden_depth=2,
        ).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.actor = actor_model(
            obs_dim=self.state_dim,
            action_dim=action_dim,
            hidden_dim=400,
            hidden_depth=2,
            log_std_bounds=[-5, 2],
        ).to(self.device)

        if load_model:
            self.load(filename=model_name, directory=load_directory)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_dim

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=actor_betas
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=critic_betas
        )

        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_lr, betas=alpha_betas
        )

        self.critic_target.train()

        self.actor.train(True)
        self.critic.train(True)
        self.step = 0
        self.writer = SummaryWriter(comment=model_name)

    def save(self, filename, directory):
        Path(directory).mkdir(parents=True, exist_ok=True)
        torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
        torch.save(self.critic.state_dict(), "%s/%s_critic.pth" % (directory, filename))
        torch.save(
            self.critic_target.state_dict(),
            "%s/%s_critic_target.pth" % (directory, filename),
        )

    def load(self, filename, directory):
        self.actor.load_state_dict(
            torch.load("%s/%s_actor.pth" % (directory, filename))
        )
        self.critic.load_state_dict(
            torch.load("%s/%s_critic.pth" % (directory, filename))
        )
        self.critic_target.load_state_dict(
            torch.load("%s/%s_critic_target.pth" % (directory, filename))
        )
        print(f"Loaded weights from: {directory}")

    def train(self, replay_buffer, iterations, batch_size):
        for _ in range(iterations):
            self.update(
                replay_buffer=replay_buffer, step=self.step, batch_size=batch_size
            )

        for key, value in self.train_metrics_dict.items():
            if len(value):
                self.writer.add_scalar(key, mean(value), self.step)
            self.train_metrics_dict[key] = []
        self.step += 1

        if self.save_every > 0 and self.step % self.save_every == 0:
            self.save(filename=self.model_name, directory=self.save_directory)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def get_action(self, obs, add_noise):
        if add_noise:
            return (
                self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
            ).clip(self.action_range[0], self.action_range[1])
        else:
            return self.act(obs)

    def act(self, obs, sample=False):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)
        dist = self.actor(obs)
        action = dist.sample() if sample else dist.mean
        action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, action, reward, next_obs, done, step):
        dist = self.actor(next_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_q = self.critic_target(next_obs, next_action)
        target_V = target_q - self.alpha.detach() * log_prob
        target_Q = reward + ((1 - done) * self.discount * target_V)
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q = self.critic(obs, action)

        max_bound = get_max_bound(
            next_obs, self.discount, 0.5, 1, 0.3, 10, 100, reward, done, self.device
        )

        max_excess_Q = F.relu(current_Q - max_bound)
        max_bound_loss = (max_excess_Q**2).mean()
        max_bound_loss = self.bound_weight * max_bound_loss
        critic_loss = (
            F.mse_loss(current_Q, target_Q)
            + max_bound_loss
        )
        self.train_metrics_dict["train_critic/loss_av"].append(critic_loss.item())
        self.writer.add_scalar("train_critic/loss", critic_loss, step)
        self.writer.add_scalar("train_critic/max_bound_loss", max_bound_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        if self.log_dist_and_hist:
            self.critic.log(self.writer, step)

    def update_actor_and_alpha(self, obs, step):
        dist = self.actor(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q = self.critic(obs, action)

        # actor_Q = torch.min(actor_Q, max_bound)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()
        self.train_metrics_dict["train_actor/loss_av"].append(actor_loss.item())
        self.train_metrics_dict["train_actor/target_entropy_av"].append(
            self.target_entropy
        )
        self.train_metrics_dict["train_actor/entropy_av"].append(
            -log_prob.mean().item()
        )
        self.writer.add_scalar("train_actor/loss", actor_loss, step)
        self.writer.add_scalar("train_actor/target_entropy", self.target_entropy, step)
        self.writer.add_scalar("train_actor/entropy", -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        if self.log_dist_and_hist:
            self.actor.log(self.writer, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (
                self.alpha * (-log_prob - self.target_entropy).detach()
            ).mean()
            self.train_metrics_dict["train_alpha/loss_av"].append(alpha_loss.item())
            self.train_metrics_dict["train_alpha/value_av"].append(self.alpha.item())
            self.writer.add_scalar("train_alpha/loss", alpha_loss, step)
            self.writer.add_scalar("train_alpha/value", self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(
        self,
        replay_buffer,
        step,
        batch_size,
        max_lin_vel=0.5,
        max_ang_vel=1,
        goal_reward=100,
        distance_norm=10,
        time_step=0.3,
    ):
        (
            batch_states,
            batch_actions,
            batch_rewards,
            batch_dones,
            batch_next_states,
        ) = replay_buffer.sample_batch(batch_size)

        state = torch.Tensor(batch_states).to(self.device)
        next_state = torch.Tensor(batch_next_states).to(self.device)
        action = torch.Tensor(batch_actions).to(self.device)
        reward = torch.Tensor(batch_rewards).to(self.device)
        done = torch.Tensor(batch_dones).to(self.device)

        self.train_metrics_dict["train/batch_reward_av"].append(
            batch_rewards.mean().item()
        )
        self.writer.add_scalar("train/batch_reward", batch_rewards.mean(), step)

        self.update_critic(state, action, reward, next_state, done, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(state, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)

    def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
        # update the returned data from ROS into a form used for learning in the current model
        latest_scan = np.array(latest_scan)

        inf_mask = np.isinf(latest_scan)
        latest_scan[inf_mask] = 7.0

        max_bins = self.state_dim - 5
        bin_size = int(np.ceil(len(latest_scan) / max_bins))

        # Initialize the list to store the minimum values of each bin
        min_values = []

        # Loop through the data and create bins
        for i in range(0, len(latest_scan), bin_size):
            # Get the current bin
            bin = latest_scan[i : i + min(bin_size, len(latest_scan) - i)]
            # Find the minimum value in the current bin and append it to the min_values list
            min_values.append(min(bin) / 7)

        # Normalize to [0, 1] range
        distance /= 10
        lin_vel = action[0] * 2
        ang_vel = (action[1] + 1) / 2
        state = min_values + [distance, cos, sin] + [lin_vel, ang_vel]

        assert len(state) == self.state_dim
        terminal = 1 if collision or goal else 0

        return state, terminal
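
For orientation, here is a hypothetical usage sketch of the agent defined above; the state layout, scan length, and run name are illustrative, and the replay buffer mentioned in the comments is assumed to expose sample_batch(batch_size) returning (states, actions, rewards, dones, next_states) as consumed by update():

import numpy as np
import torch

from robot_nav.models.SAC.BSA1C import BSA1C

# Illustrative state layout: 20 binned laser ranges + [distance, cos, sin, lin_vel, ang_vel].
model = BSA1C(
    state_dim=25,
    action_dim=2,
    max_action=1.0,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name="BSA1Cw025exp1",
    bound_weight=0.25,  # weight of the max-bound critic penalty
)

# prepare_state() turns raw ROS-style readings into the flat state vector
# (the scan and pose values here are made up for illustration).
latest_scan = np.full(180, 3.0)
state, terminal = model.prepare_state(
    latest_scan, distance=4.0, cos=1.0, sin=0.0,
    collision=False, goal=False, action=[0.3, 0.0],
)
action = model.get_action(np.array(state), add_noise=True)

# After transitions have been collected into a replay buffer that implements
# sample_batch(batch_size), one training cycle would be:
# model.train(replay_buffer, iterations=80, batch_size=64)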
robot_nav/models/SAC/BSA1C_critic.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
import torch
from torch import nn

import robot_nav.models.SAC.SAC_utils as utils


class QCritic(nn.Module):
    """Critic network with a single Q-function."""

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth):
        super().__init__()

        self.Q1 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, action):
        assert obs.size(0) == action.size(0)

        obs_action = torch.cat([obs, action], dim=-1)
        q1 = self.Q1(obs_action)

        self.outputs["q1"] = q1

        return q1

    def log(self, writer, step):
        for k, v in self.outputs.items():
            writer.add_histogram(f"train_critic/{k}_hist", v, step)
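
A quick shape check for this single-Q critic, with illustrative dimensions (the mlp and weight_init helpers live in SAC_utils, which is not part of this diff):

import torch
from robot_nav.models.SAC.BSA1C_critic import QCritic

critic = QCritic(obs_dim=25, action_dim=2, hidden_dim=400, hidden_depth=2)
obs = torch.randn(8, 25)     # batch of 8 observations
action = torch.randn(8, 2)   # matching batch of actions
q = critic(obs, action)
assert q.shape == (8, 1)     # one Q-value per (obs, action) pair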

robot_nav/test_random.py

Lines changed: 3 additions & 2 deletions
@@ -7,6 +7,7 @@
 from robot_nav.models.BPG.BTD3 import BTD3
 from robot_nav.models.CNNTD3.CNNTD3 import CNNTD3
 from robot_nav.models.SAC.BSAC import BSAC
+from robot_nav.models.SAC.BSA1C import BSA1C
 import statistics
 import numpy as np
 import tqdm
@@ -28,13 +29,13 @@ def main(args=None):
     max_steps = 300 # maximum number of steps in single episode
     test_scenarios = 1000

-    model = BSAC(
+    model = BSA1C(
         state_dim=state_dim,
         action_dim=action_dim,
         max_action=max_action,
         device=device,
         load_model=True,
-        model_name="BSACw025exp1",
+        model_name="BSA1Cw025exp1",
     ) # instantiate a model

     sim = SIM_ENV(

robot_nav/train.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 from robot_nav.models.BPG.BCNNPG import BCNNPG
 from robot_nav.models.SAC.SAC import SAC
 from robot_nav.models.SAC.BSAC import BSAC
+from robot_nav.models.SAC.BSA1C import BSA1C
 from robot_nav.models.HCM.hardcoded_model import HCM
 from robot_nav.models.PPO.PPO import PPO
 from robot_nav.models.CNNTD3.CNNTD3 import CNNTD3
@@ -40,14 +41,14 @@ def main(args=None):
     )
     save_every = 5 # save the model every n training cycles

-    model = BSAC(
+    model = BSA1C(
         state_dim=state_dim,
         action_dim=action_dim,
         max_action=max_action,
         device=device,
         save_every=save_every,
         load_model=False,
-        model_name="BSACw025exp1",
+        model_name="BSA1Cw025exp1",
         # bound_weight=0.0,
     ) # instantiate a model

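
The commented-out bound_weight=0.0 line above hints at an ablation switch: with bound_weight=0.0 the max_bound_loss term in BSA1C.update_critic() vanishes and the critic loss reduces to the plain SAC MSE term. A hypothetical ablation configuration (dimensions and run name are illustrative):

from robot_nav.models.SAC.BSA1C import BSA1C

model = BSA1C(
    state_dim=25,
    action_dim=2,
    max_action=1.0,
    device="cpu",
    model_name="BSA1Cw0exp1",  # hypothetical run name for the no-bound ablation
    bound_weight=0.0,          # disables the upper-bound penalty
)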
