Skip to content

Commit 1af7966

Browse files
committed
updating examples
1 parent fbd39ec commit 1af7966

File tree

13 files changed

+285
-117
lines changed

13 files changed

+285
-117
lines changed

examples/dqn_cartpole.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ use burn::backend::Wgpu;
2323
use burn::backend::{LibTorch, LibTorchDevice};
2424

2525
#[cfg(not(feature = "tch"))]
26-
type TrainDevice = Autodiff<Wgpu>;
26+
type B = Autodiff<Wgpu>;
2727
#[cfg(feature = "tch")]
28-
type TrainDevice = Autodiff<LibTorch>;
28+
type B = Autodiff<LibTorch>;
2929

3030
extern crate sb3_burn;
3131

@@ -34,7 +34,7 @@ fn main() {
3434
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3535

3636
#[cfg(feature = "tch")]
37-
let train_device = if has_cuda() {
37+
let train_device = if tch::utils::has_cuda() {
3838
LibTorchDevice::Cuda(0)
3939
} else {
4040
LibTorchDevice::Cpu
@@ -43,7 +43,7 @@ fn main() {
4343
#[cfg(not(feature = "tch"))]
4444
let train_device = WgpuDevice::default();
4545

46-
sb3_seed::<TrainDevice>(1234, &train_device);
46+
sb3_seed::<B>(1234, &train_device);
4747

4848
let config_optimizer =
4949
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -63,7 +63,7 @@ fn main() {
6363
.with_train_every(256);
6464

6565
let env = CartpoleEnv::new(500);
66-
let q: LinearAdvDQNNet<TrainDevice> = LinearAdvDQNNet::init(
66+
let q: LinearAdvDQNNet<B> = LinearAdvDQNNet::init(
6767
&train_device,
6868
env.observation_space().shape().len(),
6969
env.action_space().shape(),

examples/dqn_gridworld.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,17 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, gridworld::GridWorldEnv},
1814
};
1915

16+
#[cfg(not(feature = "tch"))]
17+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18+
#[cfg(feature = "tch")]
19+
use burn::backend::{LibTorch, LibTorchDevice};
20+
21+
#[cfg(not(feature = "tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainingBacked = Autodiff<LibTorch>;
32+
#[cfg(feature = "tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
2738

28-
let train_device = LibTorchDevice::Cuda(0);
39+
#[cfg(not(feature = "tch"))]
40+
let train_device = WgpuDevice::default();
2941

30-
sb3_seed::<TrainingBacked>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3143

3244
let config_optimizer =
3345
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -40,7 +52,7 @@ fn main() {
4052
.with_lr(1e-3);
4153

4254
let env = GridWorldEnv::default();
43-
let q = LinearAdvDQNNet::<TrainingBacked>::init(
55+
let q = LinearAdvDQNNet::<B>::init(
4456
&train_device,
4557
env.observation_space().shape().len(),
4658
env.action_space().shape(),

examples/dqn_mountaincar.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,17 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, classic_control::mountain_car::MountainCarEnv},
1814
};
1915

16+
#[cfg(not(feature = "tch"))]
17+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18+
#[cfg(feature = "tch")]
19+
use burn::backend::{LibTorch, LibTorchDevice};
20+
21+
#[cfg(not(feature = "tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainingBacked = Autodiff<LibTorch>;
32+
#[cfg(feature = "tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
2738

28-
let train_device = LibTorchDevice::Cuda(0);
39+
#[cfg(not(feature = "tch"))]
40+
let train_device = WgpuDevice::default();
2941

30-
sb3_seed::<TrainingBacked>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3143

3244
let config_optimizer =
3345
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -46,7 +58,7 @@ fn main() {
4658
.with_train_every(16);
4759

4860
let env = MountainCarEnv::default();
49-
let q = LinearAdvDQNNet::<TrainingBacked>::init(
61+
let q = LinearAdvDQNNet::<B>::init(
5062
&train_device,
5163
env.observation_space().shape().len(),
5264
env.action_space().shape(),

examples/dqn_probe1.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,16 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, probe::ProbeEnvValueTest},
1814
};
1915

16+
#[cfg(not(feature = "tch"))]
17+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18+
#[cfg(feature = "tch")]
19+
use burn::backend::{LibTorch, LibTorchDevice};
20+
21+
#[cfg(not(feature = "tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainBackend = Autodiff<LibTorch>;
27-
let train_device = LibTorchDevice::default();
32+
#[cfg(feature = "tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
38+
39+
#[cfg(not(feature = "tch"))]
40+
let train_device = WgpuDevice::default();
2841

29-
sb3_seed::<TrainBackend>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3043

3144
let config_optimizer =
3245
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -42,7 +55,7 @@ fn main() {
4255
.with_evaluate_during_training(false);
4356

4457
let env = ProbeEnvValueTest::default();
45-
let q: LinearDQNNet<TrainBackend> = LinearDQNNet::init(
58+
let q: LinearDQNNet<B> = LinearDQNNet::init(
4659
&train_device,
4760
env.observation_space().shape().len(),
4861
env.action_space().shape(),

examples/dqn_probe2.rs

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,14 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, probe::ProbeEnvBackpropTest},
1814
};
1915

16+
#[cfg(not(feature = "tch"))]
17+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18+
#[cfg(feature = "tch")]
19+
use burn::backend::{LibTorch, LibTorchDevice};
20+
21+
#[cfg(not(feature = "tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
25-
type TrainBackend = Autodiff<LibTorch>;
26-
let train_device = LibTorchDevice::default();
27-
sb3_seed::<TrainBackend>(1234, &train_device);
31+
32+
#[cfg(feature = "tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
38+
39+
#[cfg(not(feature = "tch"))]
40+
let train_device = WgpuDevice::default();
41+
42+
sb3_seed::<B>(1234, &train_device);
2843

2944
let config_optimizer =
3045
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -40,7 +55,7 @@ fn main() {
4055
.with_evaluate_during_training(false);
4156

4257
let env = ProbeEnvBackpropTest::default();
43-
let q: LinearDQNNet<TrainBackend> = LinearDQNNet::init(
58+
let q: LinearDQNNet<B> = LinearDQNNet::init(
4459
&train_device,
4560
env.observation_space().shape(),
4661
env.action_space().shape(),

examples/dqn_probe3.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,16 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, probe::ProbeEnvDiscountingTest},
1814
};
1915

16+
#[cfg(not(feature = "tch"))]
17+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18+
#[cfg(feature = "tch")]
19+
use burn::backend::{LibTorch, LibTorchDevice};
20+
21+
#[cfg(not(feature = "tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainingBacked = Autodiff<LibTorch>;
32+
#[cfg(feature = "tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
38+
39+
#[cfg(not(feature = "tch"))]
40+
let train_device = WgpuDevice::default();
2741

28-
let train_device = LibTorchDevice::default();
29-
sb3_seed::<TrainingBacked>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3043

3144
let config_optimizer =
3245
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -42,7 +55,7 @@ fn main() {
4255
.with_evaluate_during_training(false);
4356

4457
let env = ProbeEnvDiscountingTest::default();
45-
let q = LinearAdvDQNNet::<TrainingBacked>::init(
58+
let q = LinearAdvDQNNet::<B>::init(
4659
&train_device,
4760
env.observation_space().shape(),
4861
env.action_space().shape(),

examples/dqn_probe4.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,17 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, probe::ProbeEnvActionTest},
1814
};
1915

16+
#[cfg(not(feature = "tch"))]
17+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18+
#[cfg(feature = "tch")]
19+
use burn::backend::{LibTorch, LibTorchDevice};
20+
21+
#[cfg(not(feature = "tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainingBacked = Autodiff<LibTorch>;
32+
#[cfg(feature = "tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
2738

28-
let train_device = LibTorchDevice::default();
39+
#[cfg(not(feature = "tch"))]
40+
let train_device = WgpuDevice::default();
2941

30-
sb3_seed::<TrainingBacked>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3143

3244
let config_optimizer =
3345
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -43,7 +55,7 @@ fn main() {
4355
.with_evaluate_during_training(false);
4456

4557
let env = ProbeEnvActionTest::default();
46-
let q = LinearAdvDQNNet::<TrainingBacked>::init(
58+
let q = LinearAdvDQNNet::<B>::init(
4759
&train_device,
4860
env.observation_space().shape(),
4961
env.action_space().shape(),

0 commit comments

Comments
 (0)