Skip to content

Commit 7d7444d

Browse files
committed
cleaning up examples for sb3-tch feature
1 parent 294d944 commit 7d7444d

15 files changed

+113
-113
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ version = "0.1.0"
44
edition = "2021"
55

66
[features]
7-
default = ["ndarray", "wgpu"]
8-
tch = ["burn/tch"]
7+
default = ["ndarray", "wgpu", "sb3-tch"]
8+
sb3-tch = ["burn/tch", "tch"]
99
ndarray = ["burn/ndarray"]
1010
wgpu = ["burn/wgpu"]
1111

1212
[dependencies]
1313
assert_approx_eq = "1.1.0"
1414
burn = { version = "0.19.0", features = ["autodiff", "train"]}
15+
tch = {version="0.22.0", optional=true}
1516
csv = "1.3.0"
1617
dyn-clone = "1.0.17"
1718
indicatif = "0.17.8"

examples/dqn_cartpole.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{wgpu::WgpuDevice, Autodiff},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,14 +13,14 @@ use sb3_burn::{
1713
env::{base::Env, classic_control::cartpole::CartpoleEnv},
1814
};
1915

20-
#[cfg(not(feature = "tch"))]
21-
use burn::backend::Wgpu;
22-
#[cfg(feature = "tch")]
23-
use burn::backend::{LibTorch, LibTorchDevice};
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
19+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
2420

25-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2622
type B = Autodiff<Wgpu>;
27-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2824
type B = Autodiff<LibTorch>;
2925

3026
extern crate sb3_burn;
@@ -33,14 +29,16 @@ fn main() {
3329
// Using parameters from:
3430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3531

36-
#[cfg(feature = "tch")]
37-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
println!("Using LibTorch (GPU)");
3835
LibTorchDevice::Cuda(0)
3936
} else {
37+
println!("Using LibTorch (CPU)");
4038
LibTorchDevice::Cpu
4139
};
4240

43-
#[cfg(not(feature = "tch"))]
41+
#[cfg(not(feature = "sb3-tch"))]
4442
let train_device = WgpuDevice::default();
4543

4644
sb3_seed::<B>(1234, &train_device);
@@ -103,7 +101,7 @@ fn main() {
103101
buffer,
104102
Box::new(logger),
105103
None,
106-
EvalConfig::new().with_n_eval_episodes(100),
104+
EvalConfig::new().with_n_eval_episodes(5),
107105
&train_device,
108106
);
109107

examples/dqn_gridworld.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, gridworld::GridWorldEnv},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

examples/dqn_mountaincar.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, classic_control::mountain_car::MountainCarEnv},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

examples/dqn_probe1.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, probe::ProbeEnvValueTest},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

examples/dqn_probe2.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, probe::ProbeEnvBackpropTest},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

examples/dqn_probe3.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, probe::ProbeEnvDiscountingTest},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

examples/dqn_probe4.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, probe::ProbeEnvActionTest},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

examples/dqn_probe5.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ use sb3_burn::{
1313
env::{base::Env, probe::ProbeEnvStateActionTest},
1414
};
1515

16-
#[cfg(not(feature = "tch"))]
16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
1719
use burn::backend::{wgpu::WgpuDevice, Wgpu};
18-
#[cfg(feature = "tch")]
19-
use burn::backend::{LibTorch, LibTorchDevice};
2020

21-
#[cfg(not(feature = "tch"))]
21+
#[cfg(not(feature = "sb3-tch"))]
2222
type B = Autodiff<Wgpu>;
23-
#[cfg(feature = "tch")]
23+
#[cfg(feature = "sb3-tch")]
2424
type B = Autodiff<LibTorch>;
2525

2626
extern crate sb3_burn;
@@ -29,14 +29,14 @@ fn main() {
2929
// Using parameters from:
3030
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
3131

32-
#[cfg(feature = "tch")]
33-
let train_device = if tch::utils::has_cuda()() {
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
3434
LibTorchDevice::Cuda(0)
3535
} else {
3636
LibTorchDevice::Cpu
3737
};
3838

39-
#[cfg(not(feature = "tch"))]
39+
#[cfg(not(feature = "sb3-tch"))]
4040
let train_device = WgpuDevice::default();
4141

4242
sb3_seed::<B>(1234, &train_device);

0 commit comments

Comments
 (0)