11use burn:: {
22 module:: Module ,
3+ nn:: { Linear , LinearConfig } ,
34 prelude:: Backend ,
45 tensor:: { activation:: relu, Tensor } ,
56} ;
67
78use crate :: common:: {
89 agent:: Policy ,
9- distributions:: {
10- action_distribution:: { ActionDistribution , SquashedDiagGaussianDistribution } ,
11- distribution:: BaseDistribution ,
12- normal:: Normal ,
13- } ,
10+ distributions:: { distribution:: BaseDistribution , normal:: Normal } ,
1411 utils:: modules:: MLP ,
1512} ;
1613
@@ -27,8 +24,8 @@ impl<B: Backend> PiModel<B> {
2724 pub fn new ( obs_size : usize , n_actions : usize , device : & B :: Device ) -> Self {
2825 Self {
2926 mlp : MLP :: new ( & [ obs_size, 256 , 256 ] . to_vec ( ) , device) ,
30- scale_head : LinearConfig :: new ( 256 , n_actions) ,
31- loc_head : LinearConfig :: new ( 256 , n_actions) ,
27+ scale_head : LinearConfig :: new ( 256 , n_actions) . init ( device ) ,
28+ loc_head : LinearConfig :: new ( 256 , n_actions) . init ( device ) ,
3229 // dist: SquashedDiagGaussianDistribution::new(256, n_actions, device, 1e-6),
3330 n_actions,
3431 }
@@ -37,41 +34,42 @@ impl<B: Backend> PiModel<B> {
3734
3835impl < B : Backend > PiModel < B > {
3936 fn forward ( & self , obs : Tensor < B , 2 > ) -> ( Tensor < B , 2 > , Tensor < B , 2 > ) {
40- let latent = relu ( self . mlp . forward ( obs. clone ( ) . unsqueeze_dim ( 0 ) ) ) ;
37+ let latent = relu ( self . mlp . forward ( obs. clone ( ) ) ) ;
4138 let loc = self . loc_head . forward ( latent. clone ( ) ) ;
4239 let log_scale = self . loc_head . forward ( latent) ;
4340 let log_scale = log_scale. tanh ( ) ;
4441
45- let min_log_scale = -20 ;
46- let max_log_scale = 2 ;
42+ let min_log_scale = -20.0 ;
43+ let max_log_scale = 2.0 ;
4744
4845 let log_scale = min_log_scale + 0.5 * ( max_log_scale - min_log_scale) * ( log_scale + 1.0 ) ;
4946
5047 ( loc, log_scale)
5148 }
5249 pub fn act ( & mut self , obs : & Tensor < B , 1 > , deterministic : bool ) -> Tensor < B , 1 > {
53- let ( loc, log_scale) = self . forward ( obs. unsqueeze_dim ( 0 ) ) ;
50+ let ( loc, log_scale) = self . forward ( obs. clone ( ) . unsqueeze_dim ( 0 ) ) ;
5451
5552 if deterministic {
56- loc. tanh ( )
53+ loc. tanh ( ) . squeeze ( 0 )
5754 } else {
5855 let scale = log_scale. exp ( ) ;
5956 let dist = Normal :: new ( loc, scale) ;
6057 let x_t = dist. rsample ( ) ;
6158 let action = x_t. tanh ( ) ;
6259
63- action
60+ action. squeeze ( 0 )
6461 }
6562
6663 // self.dist.actions_from_obs(latent, deterministic).squeeze(0)
6764 }
6865
6966 pub fn act_log_prob ( & mut self , obs : Tensor < B , 2 > ) -> ( Tensor < B , 2 > , Tensor < B , 2 > ) {
7067 let ( loc, log_scale) = self . forward ( obs. unsqueeze_dim ( 0 ) ) ;
68+ let scale = log_scale. exp ( ) ;
7169 let dist = Normal :: new ( loc, scale) ;
7270 let x_t = dist. rsample ( ) ;
73- let action = x_t. tanh ( ) ;
74- let log_prob = dist. log_prob ( action ) ;
71+ let action = x_t. clone ( ) . tanh ( ) ;
72+ let log_prob = dist. log_prob ( x_t ) ;
7573
7674 ( action, log_prob)
7775
0 commit comments