11use std:: path:: PathBuf ;
22
3- use burn:: {
4- backend:: { libtorch:: LibTorchDevice , Autodiff , LibTorch } ,
5- grad_clipping:: GradientClippingConfig ,
6- optim:: AdamConfig ,
7- } ;
3+ use burn:: { backend:: Autodiff , grad_clipping:: GradientClippingConfig , optim:: AdamConfig } ;
84use sb3_burn:: {
95 common:: {
106 algorithm:: { OfflineAlgParams , OfflineTrainer } ,
@@ -17,16 +13,33 @@ use sb3_burn::{
1713 env:: { base:: Env , probe:: ProbeEnvValueTest } ,
1814} ;
1915
16+ #[ cfg( not( feature = "tch" ) ) ]
17+ use burn:: backend:: { wgpu:: WgpuDevice , Wgpu } ;
18+ #[ cfg( feature = "tch" ) ]
19+ use burn:: backend:: { LibTorch , LibTorchDevice } ;
20+
21+ #[ cfg( not( feature = "tch" ) ) ]
22+ type B = Autodiff < Wgpu > ;
23+ #[ cfg( feature = "tch" ) ]
24+ type B = Autodiff < LibTorch > ;
25+
2026extern crate sb3_burn;
2127
2228fn main ( ) {
2329 // Using parameters from:
2430 // https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531
26- type TrainBackend = Autodiff < LibTorch > ;
27- let train_device = LibTorchDevice :: default ( ) ;
32+ #[ cfg( feature = "tch" ) ]
33+ let train_device = if tch:: utils:: has_cuda ( ) {
34+ LibTorchDevice :: Cuda ( 0 )
35+ } else {
36+ LibTorchDevice :: Cpu
37+ } ;
38+
39+ #[ cfg( not( feature = "tch" ) ) ]
40+ let train_device = WgpuDevice :: default ( ) ;
2841
29- sb3_seed :: < TrainBackend > ( 1234 , & train_device) ;
42+ sb3_seed :: < B > ( 1234 , & train_device) ;
3043
3144 let config_optimizer =
3245 AdamConfig :: new ( ) . with_grad_clipping ( Some ( GradientClippingConfig :: Norm ( 10.0 ) ) ) ;
@@ -42,7 +55,7 @@ fn main() {
4255 . with_evaluate_during_training ( false ) ;
4356
4457 let env = ProbeEnvValueTest :: default ( ) ;
45- let q: LinearDQNNet < TrainBackend > = LinearDQNNet :: init (
58+ let q: LinearDQNNet < B > = LinearDQNNet :: init (
4659 & train_device,
4760 env. observation_space ( ) . shape ( ) . len ( ) ,
4861 env. action_space ( ) . shape ( ) ,
0 commit comments