4646 }
4747}
4848
49- /// Continuous actions are usually considered to be independent,
50- /// so we can sum components of the ``log_prob`` or the entropy.
51- ///
52- /// # Shapes
53- /// t: (batch, n_actions) or (batch)
54- /// return: (batch) for (batch, n_actions) input, or (1) for (batch) input
55- // fn sum_independent_dims<B: Backend>(t: Tensor<B, 1>) -> Tensor<B, 1>{
56- // t.sum()
57- // }
58-
59- // fn sum_independent_dims_batched<B: Backend>(t: Tensor<B, 1>) -> Tensor<B, 1>{
60- // t.sum_dim(1).squeeze(1)
61- // }
62-
6349#[ derive( Debug , Module ) ]
6450pub struct DiagGaussianDistribution < B : Backend > {
6551 means : Linear < B > ,
@@ -87,10 +73,11 @@ impl<B: Backend> DiagGaussianDistribution<B> {
8773
8874impl < B : Backend > ActionDistribution < B > for DiagGaussianDistribution < B > {
8975 fn log_prob ( & self , sample : Tensor < B , 2 > ) -> Tensor < B , 2 > {
// Per-component log-density of `sample` under the current diagonal Gaussian.
76+ // (B, N)
9077 let log_prob = self . dist . log_prob ( sample) ;
9178
92- // TODO: add sum_independent_dims when multi-dim actions are supported
93- log_prob
79+ // (B, 1)
// Action components are modelled as independent (diagonal covariance),
// so the joint log-prob is the sum over the action dimension — this is
// the `sum_independent_dims` behavior described in the file-top comment.
80+ log_prob. sum_dim ( 1 )
9481 }
9582
9683 fn entropy ( & self ) -> Tensor < B , 2 > {
@@ -107,13 +94,12 @@ impl<B: Backend> ActionDistribution<B> for DiagGaussianDistribution<B> {
10794
// Sample (or take the mean of) an action for `obs`, refreshing `self.dist`.
10895 fn actions_from_obs ( & mut self , obs : Tensor < B , 2 > , deterministic : bool ) -> Tensor < B , 2 > {
10996 let loc = self . means . forward ( obs. clone ( ) ) ;
// log-std is clamped to [-20, 2] before exp — presumably to bound the
// scale for numerical stability (common SAC convention); TODO confirm.
97+ let scale: Tensor < B , 2 > = self . log_std . forward ( obs) . clamp ( -20 , 2 ) . exp ( ) ;
// NOTE(review): the diff moves the `dist` refresh out of the stochastic
// branch, so after this call log_prob()/entropy() reflect THIS obs even
// on the deterministic path — previously only the else-branch updated it.
98+ self . dist = Normal :: new ( loc. clone ( ) , scale) ;
11099
111100 if deterministic {
112- loc
101+ self . dist . mean ( )
113102 } else {
114- let scale: Tensor < B , 2 > = self . log_std . forward ( obs) . clamp ( -20 , 2 ) . exp ( ) ;
115- self . dist = Normal :: new ( loc. clone ( ) , scale) ;
116-
// rsample: reparameterized draw, so gradients flow through loc/scale.
117103 self . dist . rsample ( )
118104 }
119105 }
@@ -141,22 +127,29 @@ impl<B: Backend> SquashedDiagGaussianDistribution<B> {
141127 }
142128
// Tanh-squash Jacobian correction (from the original SAC implementation,
// per the comment removed in this diff):
//   ln p(a) = ln p(u) - sum_i ln(1 - a_i^2 + eps),  a = tanh(u).
// `epsilon` guards the log against a_i^2 == 1.
143129 fn log_prob_correction ( & self , ln_u : Tensor < B , 2 > , a : Tensor < B , 2 > ) -> Tensor < B , 2 > {
144- ln_u - ( ( 1.0 - a. powi_scalar ( 2.0 ) + self . epsilon ) as Tensor < B , 2 > )
130+ // ln_u: (B, 1)
131+ // a: (B, N)
132+
133+ // (B, 1)
// NOTE(review): `( 1.0 - ... ) as Tensor < B , 2 >` is not valid Rust as
// written — `1.0 - tensor` needs a tensor op and `as` cannot cast here;
// likely a scrape artifact. Verify against the real source.
134+ let correction = ( ( 1.0 - a. powi_scalar ( 2.0 ) + self . epsilon ) as Tensor < B , 2 > )
145135 . log ( )
146- . sum_dim ( 1 )
// Summing over the action dim matches the independent-components
// convention used by the diagonal Gaussian's log_prob.
136+ . sum_dim ( 1 ) ;
137+
138+ // (B, 1)
139+ ln_u - correction
147140 }
148141}
149142
150143impl < B : Backend > ActionDistribution < B > for SquashedDiagGaussianDistribution < B > {
// Joint log-density of a tanh-squashed action `a`: invert the squashing,
// score the pre-squash sample under the diagonal Gaussian, then apply the
// change-of-variables correction for tanh.
151144 fn log_prob ( & self , a : Tensor < B , 2 > ) -> Tensor < B , 2 > {
145+ // (B, N)
// u = atanh(a): the unsquashed Gaussian sample corresponding to `a`.
152146 let u = tanh_bijector_inverse ( a. clone ( ) ) ;
153- let ln_u = self . diag_gaus_dist . log_prob ( u) ;
154147
155- // Squash correction (from original SAC implementation)
156- // this comes from the fact that tanh is bijective and differentiable
157- let ln_a = self . log_prob_correction ( ln_u, a) ;
148+ // (B, 1)
149+ let ln_u = self . diag_gaus_dist . log_prob ( u) ;
158150
159- ln_a
// Squash correction: tanh is bijective and differentiable, so
// ln p(a) = ln p(u) - ln |det J_tanh(u)| (see log_prob_correction).
151+ // (B, 1)
152+ self . log_prob_correction ( ln_u, a)
160153 }
161154
162155 fn entropy ( & self ) -> Tensor < B , 2 > {
0 commit comments