11use std:: path:: PathBuf ;
22
3- use burn:: {
4- backend:: { wgpu:: WgpuDevice , Autodiff } ,
5- grad_clipping:: GradientClippingConfig ,
6- optim:: AdamConfig ,
7- } ;
3+ use burn:: { backend:: Autodiff , grad_clipping:: GradientClippingConfig , optim:: AdamConfig } ;
84use sb3_burn:: {
95 common:: {
106 algorithm:: { OfflineAlgParams , OfflineTrainer } ,
@@ -17,14 +13,14 @@ use sb3_burn::{
1713 env:: { base:: Env , classic_control:: cartpole:: CartpoleEnv } ,
1814} ;
1915
20- #[ cfg( not ( feature = "tch" ) ) ]
21- use burn:: backend:: Wgpu ;
22- #[ cfg( feature = "tch" ) ]
23- use burn:: backend:: { LibTorch , LibTorchDevice } ;
16+ #[ cfg( feature = "sb3- tch" ) ]
17+ use burn:: backend:: { libtorch :: LibTorchDevice , LibTorch } ;
18+ #[ cfg( not ( feature = "sb3- tch" ) ) ]
19+ use burn:: backend:: { wgpu :: WgpuDevice , Wgpu } ;
2420
25- #[ cfg( not( feature = "tch" ) ) ]
21+ #[ cfg( not( feature = "sb3- tch" ) ) ]
2622type B = Autodiff < Wgpu > ;
27- #[ cfg( feature = "tch" ) ]
23+ #[ cfg( feature = "sb3- tch" ) ]
2824type B = Autodiff < LibTorch > ;
2925
3026extern crate sb3_burn;
@@ -33,14 +29,16 @@ fn main() {
3329 // Using parameters from:
3430 // https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3531
36- #[ cfg( feature = "tch" ) ]
37- let train_device = if tch:: utils:: has_cuda ( ) ( ) {
32+ #[ cfg( feature = "sb3-tch" ) ]
33+ let train_device = if tch:: utils:: has_cuda ( ) {
34+ println ! ( "Using LibTorch (GPU)" ) ;
3835 LibTorchDevice :: Cuda ( 0 )
3936 } else {
37+ println ! ( "Using LibTorch (CPU)" ) ;
4038 LibTorchDevice :: Cpu
4139 } ;
4240
43- #[ cfg( not( feature = "tch" ) ) ]
41+ #[ cfg( not( feature = "sb3- tch" ) ) ]
4442 let train_device = WgpuDevice :: default ( ) ;
4543
4644 sb3_seed :: < B > ( 1234 , & train_device) ;
@@ -103,7 +101,7 @@ fn main() {
103101 buffer,
104102 Box :: new ( logger) ,
105103 None ,
106- EvalConfig :: new ( ) . with_n_eval_episodes ( 100 ) ,
104+ EvalConfig :: new ( ) . with_n_eval_episodes ( 5 ) ,
107105 & train_device,
108106 ) ;
109107
0 commit comments