Skip to content

Commit 4b19cd7

Browse files
committed
comments
1 parent 7b64017 commit 4b19cd7

File tree

4 files changed

+21
-54
lines changed

4 files changed

+21
-54
lines changed

src/common/distributions/action_distribution.rs

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,6 @@ where
4646
}
4747
}
4848

49-
/// Continuous actions are usually considered to be independent,
50-
/// so we can sum components of the ``log_prob`` or the entropy.
51-
///
52-
/// # Shapes
53-
/// t: (batch, n_actions) or (batch)
54-
/// return: (batch) for (batch, n_actions) input, or (1) for (batch) input
55-
// fn sum_independent_dims<B: Backend>(t: Tensor<B, 1>) -> Tensor<B, 1>{
56-
// t.sum()
57-
// }
58-
59-
// fn sum_independent_dims_batched<B: Backend>(t: Tensor<B, 1>) -> Tensor<B, 1>{
60-
// t.sum_dim(1).squeeze(1)
61-
// }
62-
6349
#[derive(Debug, Module)]
6450
pub struct DiagGaussianDistribution<B: Backend> {
6551
means: Linear<B>,
@@ -87,10 +73,11 @@ impl<B: Backend> DiagGaussianDistribution<B> {
8773

8874
impl<B: Backend> ActionDistribution<B> for DiagGaussianDistribution<B> {
8975
fn log_prob(&self, sample: Tensor<B, 2>) -> Tensor<B, 2> {
76+
// (B, N)
9077
let log_prob = self.dist.log_prob(sample);
9178

92-
// TODO: add sum_independent_dims when multi-dim actions are supported
93-
log_prob
79+
// (B, 1)
80+
log_prob.sum_dim(1)
9481
}
9582

9683
fn entropy(&self) -> Tensor<B, 2> {
@@ -107,13 +94,12 @@ impl<B: Backend> ActionDistribution<B> for DiagGaussianDistribution<B> {
10794

10895
fn actions_from_obs(&mut self, obs: Tensor<B, 2>, deterministic: bool) -> Tensor<B, 2> {
10996
let loc = self.means.forward(obs.clone());
97+
let scale: Tensor<B, 2> = self.log_std.forward(obs).clamp(-20, 2).exp();
98+
self.dist = Normal::new(loc.clone(), scale);
11099

111100
if deterministic {
112-
loc
101+
self.dist.mean()
113102
} else {
114-
let scale: Tensor<B, 2> = self.log_std.forward(obs).clamp(-20, 2).exp();
115-
self.dist = Normal::new(loc.clone(), scale);
116-
117103
self.dist.rsample()
118104
}
119105
}
@@ -141,22 +127,29 @@ impl<B: Backend> SquashedDiagGaussianDistribution<B> {
141127
}
142128

143129
fn log_prob_correction(&self, ln_u: Tensor<B, 2>, a: Tensor<B, 2>) -> Tensor<B, 2> {
144-
ln_u - ((1.0 - a.powi_scalar(2.0) + self.epsilon) as Tensor<B, 2>)
130+
// ln_u: (B, 1)
131+
// a: (B, N)
132+
133+
// (B, 1)
134+
let correction = ((1.0 - a.powi_scalar(2.0) + self.epsilon) as Tensor<B, 2>)
145135
.log()
146-
.sum_dim(1)
136+
.sum_dim(1);
137+
138+
// (B, 1)
139+
ln_u - correction
147140
}
148141
}
149142

150143
impl<B: Backend> ActionDistribution<B> for SquashedDiagGaussianDistribution<B> {
151144
fn log_prob(&self, a: Tensor<B, 2>) -> Tensor<B, 2> {
145+
// (B, N)
152146
let u = tanh_bijector_inverse(a.clone());
153-
let ln_u = self.diag_gaus_dist.log_prob(u);
154147

155-
// Squash correction (from original SAC implementation)
156-
// this comes from the fact that tanh is bijective and differentiable
157-
let ln_a = self.log_prob_correction(ln_u, a);
148+
// (B, 1)
149+
let ln_u = self.diag_gaus_dist.log_prob(u);
158150

159-
ln_a
151+
// (B, 1)
152+
self.log_prob_correction(ln_u, a)
160153
}
161154

162155
fn entropy(&self) -> Tensor<B, 2> {

src/common/distributions/normal.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,6 @@ impl<B: Backend, const D: usize> BaseDistribution<B, D> for Normal<B, D> {
8989
let var = self.variance();
9090

9191
-(value - self.loc.clone()).powi_scalar(2) / (2 * var) - log_scale - (2.0 * PI).sqrt().ln()
92-
93-
// log_scale
94-
// .mul_scalar(-1.0)
95-
// .add_scalar(-0.5 * (2.0 * PI).log(E))
96-
// .sub(
97-
// (value - self.loc.clone())
98-
// .powi_scalar(2)
99-
// .div(var)
100-
// .mul_scalar(0.5),
101-
// )
10292
}
10393

10494
fn cdf(&self, _value: Tensor<B, D>) -> Tensor<B, D> {

src/sac/agent.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -296,21 +296,7 @@ impl<B: AutodiffBackend> SACAgent<B> {
296296

297297
let actor_loss_back = actor_loss.backward();
298298
let actor_grads = GradientsParams::from_grads(actor_loss_back, &self.pi);
299-
300-
// do some checks to see that pi is actually updating
301-
// let mut pre_pi_summary = ModuleParamSummary::default();
302-
// self.pi.visit(&mut pre_pi_summary);
303-
304299
self.pi = self.pi_optim.step(lr, self.pi.clone(), actor_grads);
305-
// let mut post_pi_summary = ModuleParamSummary::default();
306-
// self.pi.visit(&mut post_pi_summary);
307-
308-
// println!("Pi Summary pre-step");
309-
// pre_pi_summary.print();
310-
// println!("Pi Summary post-step");
311-
// post_pi_summary.print();
312-
//
313-
// panic!();
314300

315301
log_dict
316302
}
@@ -407,8 +393,6 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
407393
),
408394
);
409395

410-
let log_prob = log_prob.sum_dim(1);
411-
412396
self.profiler
413397
.record("policy", t_policy0.elapsed().as_secs_f64());
414398

src/sac/models.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl<B: Backend> PiModel<B> {
5353

5454
pub fn act_log_prob(&mut self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
5555
let latent = self.mlp.forward(obs.clone().unsqueeze());
56-
self.dist.actions_from_obs_with_log_probs(latent, deterministic)
56+
self.dist.actions_from_obs_with_log_probs(latent, false)
5757
}
5858
}
5959

0 commit comments

Comments
 (0)