Skip to content

Commit a4b9533

Browse files
committed
model init
1 parent 91bbf17 commit a4b9533

File tree

5 files changed

+23
-12
lines changed

5 files changed

+23
-12
lines changed

examples/sac_pendulum.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ fn main() {
5353

5454
let offline_params = OfflineAlgParams::new()
5555
.with_batch_size(256)
56-
.with_memory_size(50000)
56+
.with_memory_size(100_000)
5757
.with_gamma(0.99)
5858
.with_n_steps(100_000)
5959
.with_warmup_steps(10000)
@@ -75,8 +75,11 @@ fn main() {
7575
true,
7676
None,
7777
Some(1e-3),
78-
Some(0.01),
79-
Box::new(BoxSpace::from(([-1.0, -1.0, -1.0].to_vec(), [1.0, 1.0, 1.0].to_vec()))),
78+
Some(0.005),
79+
Box::new(BoxSpace::from((
80+
vec![-1.0, -1.0, -1.0],
81+
vec![1.0, 1.0, 1.0],
82+
))),
8083
Box::new(BoxSpace::from(([-1.0].to_vec(), [1.0].to_vec()))),
8184
);
8285

src/common/distributions/action_distribution.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,18 @@ impl<B: Backend> DiagGaussianDistribution<B> {
6868
let dist: Normal<B, 2> = Normal::new(loc, std);
6969

7070
Self {
71-
means: LinearConfig::new(latent_dim, action_dim).init(device),
72-
log_std: LinearConfig::new(latent_dim, action_dim).init(device),
71+
means: LinearConfig::new(latent_dim, action_dim)
72+
.with_initializer(burn::nn::Initializer::Uniform {
73+
min: -3e-3,
74+
max: 3e-3,
75+
})
76+
.init(device),
77+
log_std: LinearConfig::new(latent_dim, action_dim)
78+
.with_initializer(burn::nn::Initializer::Uniform {
79+
min: -3e-3,
80+
max: 3e-3,
81+
})
82+
.init(device),
7383
dist,
7484
}
7585
}

src/common/utils/module_update.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ mod test {
134134
let to = Param::from_tensor(Tensor::from_floats([0.0], &Default::default()));
135135
let tau = 0.05;
136136

137-
let new_to: Param<Tensor<B, 2, Float>> = soft_update_tensor(&from, to, tau);
137+
let new_to: Param<Tensor<B, 1, Float>> = soft_update_tensor(&from, to, tau);
138138

139139
let new_to_f: f32 = new_to.val().into_scalar();
140140

src/common/utils/modules.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl<B: Backend> MLP<B> {
2121
layers.push(
2222
LinearConfig::new(sizes[i], sizes[i + 1])
2323
.with_initializer(burn::nn::Initializer::Uniform {
24-
min: 3e-3,
24+
min: -3e-3,
2525
max: 3e-3,
2626
})
2727
.init(device),

src/sac/agent.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,9 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
371371

372372
// train entropy coeficient if required to do so
373373
let t_ent0 = std::time::Instant::now();
374-
let (ent_coef, ent_coef_loss) = self
375-
.ent_coef
376-
.train_step(log_prob.clone().flatten(0, 1), self.ent_lr, train_device);
374+
let (ent_coef, ent_coef_loss) =
375+
self.ent_coef
376+
.train_step(log_prob.clone().flatten(0, 1), self.ent_lr, train_device);
377377
self.profiler
378378
.record("ent_coef", t_ent0.elapsed().as_secs_f64());
379379

@@ -427,8 +427,6 @@ impl<B: AutodiffBackend> Agent<B, Vec<f32>, Vec<f32>> for SACAgent<B> {
427427
self.last_update = global_step;
428428
}
429429

430-
// panic!();
431-
432430
(None, log_dict)
433431
}
434432

0 commit comments

Comments
 (0)