11use std:: path:: PathBuf ;
22
3- use burn:: {
4- backend:: { libtorch:: LibTorchDevice , Autodiff , LibTorch } ,
5- grad_clipping:: GradientClippingConfig ,
6- optim:: AdamConfig ,
7- } ;
3+ use burn:: { backend:: Autodiff , grad_clipping:: GradientClippingConfig , optim:: AdamConfig } ;
84use sb3_burn:: {
95 common:: {
106 algorithm:: { OfflineAlgParams , OfflineTrainer } ,
@@ -17,16 +13,35 @@ use sb3_burn::{
1713 env:: { base:: Env , classic_control:: cartpole:: CartpoleEnv } ,
1814} ;
1915
16+ #[ cfg( feature = "sb3-tch" ) ]
17+ use burn:: backend:: { libtorch:: LibTorchDevice , LibTorch } ;
18+ #[ cfg( not( feature = "sb3-tch" ) ) ]
19+ use burn:: backend:: { wgpu:: WgpuDevice , Wgpu } ;
20+
// Backend alias `B`, used by `main` for seeding and for the Q-network type.
// The two `cfg` predicates are complementary, so exactly one definition is
// compiled: `Autodiff<Wgpu>` by default, or `Autodiff<LibTorch>` when the
// `sb3-tch` feature is enabled.
21+ #[ cfg( not( feature = "sb3-tch" ) ) ]
22+ type B = Autodiff < Wgpu > ;
23+ #[ cfg( feature = "sb3-tch" ) ]
24+ type B = Autodiff < LibTorch > ;
25+
2026extern crate sb3_burn;
2127
2228fn main ( ) {
2329 // Using parameters from:
2430 // https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531
26- type TrainDevice = Autodiff < LibTorch > ;
27- let train_device = LibTorchDevice :: Cuda ( 0 ) ;
32+ #[ cfg( feature = "sb3-tch" ) ]
33+ let train_device = if tch:: utils:: has_cuda ( ) {
34+ println ! ( "Using LibTorch (GPU)" ) ;
35+ LibTorchDevice :: Cuda ( 0 )
36+ } else {
37+ println ! ( "Using LibTorch (CPU)" ) ;
38+ LibTorchDevice :: Cpu
39+ } ;
40+
41+ #[ cfg( not( feature = "sb3-tch" ) ) ]
42+ let train_device = WgpuDevice :: default ( ) ;
2843
29- sb3_seed :: < TrainDevice > ( 1234 , & train_device) ;
44+ sb3_seed :: < B > ( 1234 , & train_device) ;
3045
3146 let config_optimizer =
3247 AdamConfig :: new ( ) . with_grad_clipping ( Some ( GradientClippingConfig :: Norm ( 10.0 ) ) ) ;
@@ -46,7 +61,7 @@ fn main() {
4661 . with_train_every ( 256 ) ;
4762
4863 let env = CartpoleEnv :: new ( 500 ) ;
49- let q: LinearAdvDQNNet < TrainDevice > = LinearAdvDQNNet :: init (
64+ let q: LinearAdvDQNNet < B > = LinearAdvDQNNet :: init (
5065 & train_device,
5166 env. observation_space ( ) . shape ( ) . len ( ) ,
5267 env. action_space ( ) . shape ( ) ,
@@ -86,7 +101,7 @@ fn main() {
86101 buffer,
87102 Box :: new ( logger) ,
88103 None ,
89- EvalConfig :: new ( ) . with_n_eval_episodes ( 100 ) ,
104+ EvalConfig :: new ( ) . with_n_eval_episodes ( 5 ) ,
90105 & train_device,
91106 ) ;
92107
0 commit comments