Skip to content

Commit a394f48

Browse files

Committed with message: "work"

1 parent f2969fb · commit a394f48

File tree

2 files changed: +35 additions, −22 deletions

src/sac/agent.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,21 @@ impl<B: AutodiffBackend> SACAgent<B> {
206206
let (next_action_sampled_raw, next_action_log_prob_raw) =
207207
self.pi.act_log_prob(next_states.clone());
208208

209-
let (next_action_sampled, action_scale, _) =
210-
scale_actions_to_env(next_action_sampled_raw, &self.action_space, train_device);
211-
let next_action_log_prob = next_action_log_prob_raw - action_scale.log().sum_dim(1);
209+
let (next_action_sampled, action_scale, _) = scale_actions_to_env(
210+
next_action_sampled_raw.clone(),
211+
&self.action_space,
212+
train_device,
213+
);
214+
// let next_action_log_prob =
215+
// next_action_log_prob_raw.clone() - action_scale.clone().log().sum_dim(1);
216+
217+
let next_action_log_prob_scaled = next_action_log_prob_raw
218+
- (next_action_sampled_raw.powi_scalar(2).neg().add_scalar(1.0))
219+
.mul(action_scale)
220+
.add_scalar(1e-6)
221+
.log();
212222

223+
let next_action_log_prob_scaled = next_action_log_prob_scaled.sum_dim(1);
213224
// disp_tensorf("next_action_sampled", &next_action_sampled);
214225
// disp_tensorf("next_action_log_prob", &next_action_log_prob);
215226

@@ -225,7 +236,7 @@ impl<B: AutodiffBackend> SACAgent<B> {
225236
// disp_tensorf("2next_q_vals", &next_q_vals);
226237

227238
// add the entropy term
228-
let next_q_vals = next_q_vals - next_action_log_prob.mul_scalar(ent_coef);
239+
let next_q_vals = next_q_vals - next_action_log_prob_scaled.mul_scalar(ent_coef);
229240
// disp_tensorf("3next_q_vals", &next_q_vals);
230241

231242
// td error + entropy term
@@ -358,12 +369,16 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
358369
let t_policy0 = std::time::Instant::now();
359370
let (actions_pi_raw, log_prob_raw) = self.pi.act_log_prob(states.clone());
360371

361-
let (actions_pi_scaled, action_scale, action_bias) =
362-
scale_actions_to_env(actions_pi_raw, &self.action_space, train_device);
372+
let (actions_pi_scaled, action_scale, _) =
373+
scale_actions_to_env(actions_pi_raw.clone(), &self.action_space, train_device);
363374

364375
// let log_prob = log_prob_raw - action_scale.log().sum_dim(1);
365376

366-
let log_prob_scaled = log_prob - ((1 - actions_pi_raw.powi(2)) * action_scale + 1e-6).log();
377+
let log_prob_scaled = log_prob_raw
378+
- (actions_pi_raw.powi_scalar(2).neg().add_scalar(1.0))
379+
.mul(action_scale)
380+
.add_scalar(1e-6)
381+
.log();
367382
let log_prob_scaled = log_prob_scaled.sum_dim(1);
368383

369384
self.profiler

src/sac/models.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
use burn::{
22
module::Module,
3+
nn::{Linear, LinearConfig},
34
prelude::Backend,
45
tensor::{activation::relu, Tensor},
56
};
67

78
use crate::common::{
89
agent::Policy,
9-
distributions::{
10-
action_distribution::{ActionDistribution, SquashedDiagGaussianDistribution},
11-
distribution::BaseDistribution,
12-
normal::Normal,
13-
},
10+
distributions::{distribution::BaseDistribution, normal::Normal},
1411
utils::modules::MLP,
1512
};
1613

@@ -27,8 +24,8 @@ impl<B: Backend> PiModel<B> {
2724
pub fn new(obs_size: usize, n_actions: usize, device: &B::Device) -> Self {
2825
Self {
2926
mlp: MLP::new(&[obs_size, 256, 256].to_vec(), device),
30-
scale_head: LinearConfig::new(256, n_actions),
31-
loc_head: LinearConfig::new(256, n_actions),
27+
scale_head: LinearConfig::new(256, n_actions).init(device),
28+
loc_head: LinearConfig::new(256, n_actions).init(device),
3229
// dist: SquashedDiagGaussianDistribution::new(256, n_actions, device, 1e-6),
3330
n_actions,
3431
}
@@ -37,41 +34,42 @@ impl<B: Backend> PiModel<B> {
3734

3835
impl<B: Backend> PiModel<B> {
3936
fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
40-
let latent = relu(self.mlp.forward(obs.clone().unsqueeze_dim(0)));
37+
let latent = relu(self.mlp.forward(obs.clone()));
4138
let loc = self.loc_head.forward(latent.clone());
4239
let log_scale = self.loc_head.forward(latent);
4340
let log_scale = log_scale.tanh();
4441

45-
let min_log_scale = -20;
46-
let max_log_scale = 2;
42+
let min_log_scale = -20.0;
43+
let max_log_scale = 2.0;
4744

4845
let log_scale = min_log_scale + 0.5 * (max_log_scale - min_log_scale) * (log_scale + 1.0);
4946

5047
(loc, log_scale)
5148
}
5249
pub fn act(&mut self, obs: &Tensor<B, 1>, deterministic: bool) -> Tensor<B, 1> {
53-
let (loc, log_scale) = self.forward(obs.unsqueeze_dim(0));
50+
let (loc, log_scale) = self.forward(obs.clone().unsqueeze_dim(0));
5451

5552
if deterministic {
56-
loc.tanh()
53+
loc.tanh().squeeze(0)
5754
} else {
5855
let scale = log_scale.exp();
5956
let dist = Normal::new(loc, scale);
6057
let x_t = dist.rsample();
6158
let action = x_t.tanh();
6259

63-
action
60+
action.squeeze(0)
6461
}
6562

6663
// self.dist.actions_from_obs(latent, deterministic).squeeze(0)
6764
}
6865

6966
pub fn act_log_prob(&mut self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
7067
let (loc, log_scale) = self.forward(obs.unsqueeze_dim(0));
68+
let scale = log_scale.exp();
7169
let dist = Normal::new(loc, scale);
7270
let x_t = dist.rsample();
73-
let action = x_t.tanh();
74-
let log_prob = dist.log_prob(action);
71+
let action = x_t.clone().tanh();
72+
let log_prob = dist.log_prob(x_t);
7573

7674
(action, log_prob)
7775

0 commit comments

Comments (0)