Skip to content

Commit 91bbf17

Browse files
committed
tau check, module init fn
1 parent a985767 commit 91bbf17

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/common/utils/module_update.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ pub fn update_conv2d<B: Backend>(from: &Conv2d<B>, to: Conv2d<B>, tau: Option<f3
6161
mod test {
6262
use burn::{
6363
backend::NdArray,
64-
module::Module,
64+
module::{Module, Param},
6565
nn::{Linear, LinearConfig},
66-
tensor::backend::Backend,
66+
tensor::{backend::Backend, Float, Tensor},
6767
};
6868

69-
use crate::common::agent::Policy;
69+
use crate::common::{agent::Policy, utils::module_update::soft_update_tensor};
7070

7171
use super::update_linear;
7272

@@ -126,4 +126,20 @@ mod test {
126126

127127
a.update(&b, None);
128128
}
129+
130+
#[test]
fn test_soft_update_value() {
    type B = NdArray;
    // With from = 1.0 and to = 0.0, the soft update
    // tau * from + (1 - tau) * to should come out to exactly tau.
    let from = Param::from_tensor(Tensor::from_floats([1.0], &Default::default()));
    let to = Param::from_tensor(Tensor::from_floats([0.0], &Default::default()));
    let tau = 0.05;

    // `from_floats([1.0], ..)` builds rank-1 tensor data, so the annotation
    // must be rank 1 — the original `Tensor<B, 2, Float>` would mismatch the
    // data's shape at construction time.
    let new_to: Param<Tensor<B, 1, Float>> = soft_update_tensor(&from, to, tau);

    let new_to_f: f32 = new_to.val().into_scalar();

    // Compare with an absolute difference: the original `tau - new_to_f < 1e-6`
    // is vacuously true whenever new_to_f overshoots tau (the difference is
    // negative), so it could never catch an overshooting implementation.
    let diff = (tau - new_to_f).abs();

    assert!(diff < 1e-6);
}
129145
}

src/common/utils/modules.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@ impl<B: Backend> MLP<B> {
1818
let mut layers = Vec::new();
1919

2020
for i in 0..sizes.len() - 1 {
21-
layers.push(LinearConfig::new(sizes[i], sizes[i + 1]).init(device))
21+
layers.push(
22+
LinearConfig::new(sizes[i], sizes[i + 1])
23+
.with_initializer(burn::nn::Initializer::Uniform {
24+
min: -3e-3,
25+
max: 3e-3,
26+
})
27+
.init(device),
28+
)
2229
}
2330

2431
Self { layers }

0 commit comments

Comments
 (0)