Skip to content

Commit 5dc5ae7

Browse files
authored
Feature flags (#49)
* playing with features and compiler flags * installing deps in ci * installing deps in ci * installing deps in ci * updating examples * cleaning warnings * only action on main * cleaning up examples for sb3-tch feature
1 parent 9644697 commit 5dc5ae7

19 files changed

+335
-130
lines changed

.github/workflows/cov.yml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
name: Coverage
22

3-
on: [pull_request, push]
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
push:
8+
branches:
9+
- main
410

511
jobs:
612
coverage:
@@ -9,15 +15,17 @@ jobs:
915
CARGO_TERM_COLOR: always
1016
steps:
1117
- uses: actions/checkout@v4
18+
- name: Install system deps
19+
run: sudo apt install pkg-config libfreetype6-dev libfontconfig1-dev -y
1220
- name: Install Rust
1321
run: rustup update stable
1422
- name: Install cargo-llvm-cov
1523
uses: taiki-e/install-action@cargo-llvm-cov
1624
- name: Generate code coverage
17-
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
25+
run: cargo llvm-cov --workspace --lcov --output-path lcov.info
1826
- name: Upload coverage to Codecov
1927
uses: codecov/codecov-action@v3
2028
with:
2129
token: ${{ secrets.CODECOV_TOKEN }}
2230
files: lcov.info
23-
fail_ci_if_error: true
31+
fail_ci_if_error: true

.github/workflows/rust.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
name: Continuous Integration
22

3-
on: [push, pull_request]
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
push:
8+
branches:
9+
- main
410

511
jobs:
612
build_and_test:
713
runs-on: ubuntu-latest
814

915
steps:
1016
- uses: actions/checkout@v2
17+
- name: Install system deps
18+
run: sudo apt install pkg-config libfreetype6-dev libfontconfig1-dev -y
1119
- name: ⚡ Cache
1220
uses: actions/cache@v4
1321
with:

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@ name = "sb3-burn"
33
version = "0.1.0"
44
edition = "2021"
55

6+
[features]
7+
default = ["ndarray", "wgpu", "sb3-tch"]
8+
sb3-tch = ["burn/tch", "tch"]
9+
ndarray = ["burn/ndarray"]
10+
wgpu = ["burn/wgpu"]
11+
612
[dependencies]
713
assert_approx_eq = "1.1.0"
8-
burn = { version = "0.19.0", features = ["ndarray", "wgpu", "autodiff", "train", "tch"]}
14+
burn = { version = "0.19.0", features = ["autodiff", "train"]}
15+
tch = {version="0.22.0", optional=true}
916
csv = "1.3.0"
1017
dyn-clone = "1.0.17"
1118
indicatif = "0.17.8"

examples/dqn_cartpole.rs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,16 +13,35 @@ use sb3_burn::{
1713
env::{base::Env, classic_control::cartpole::CartpoleEnv},
1814
};
1915

16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
19+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
20+
21+
#[cfg(not(feature = "sb3-tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "sb3-tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainDevice = Autodiff<LibTorch>;
27-
let train_device = LibTorchDevice::Cuda(0);
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
println!("Using LibTorch (GPU)");
35+
LibTorchDevice::Cuda(0)
36+
} else {
37+
println!("Using LibTorch (CPU)");
38+
LibTorchDevice::Cpu
39+
};
40+
41+
#[cfg(not(feature = "sb3-tch"))]
42+
let train_device = WgpuDevice::default();
2843

29-
sb3_seed::<TrainDevice>(1234, &train_device);
44+
sb3_seed::<B>(1234, &train_device);
3045

3146
let config_optimizer =
3247
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -46,7 +61,7 @@ fn main() {
4661
.with_train_every(256);
4762

4863
let env = CartpoleEnv::new(500);
49-
let q: LinearAdvDQNNet<TrainDevice> = LinearAdvDQNNet::init(
64+
let q: LinearAdvDQNNet<B> = LinearAdvDQNNet::init(
5065
&train_device,
5166
env.observation_space().shape().len(),
5267
env.action_space().shape(),
@@ -86,7 +101,7 @@ fn main() {
86101
buffer,
87102
Box::new(logger),
88103
None,
89-
EvalConfig::new().with_n_eval_episodes(100),
104+
EvalConfig::new().with_n_eval_episodes(5),
90105
&train_device,
91106
);
92107

examples/dqn_gridworld.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,17 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, gridworld::GridWorldEnv},
1814
};
1915

16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
19+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
20+
21+
#[cfg(not(feature = "sb3-tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "sb3-tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainingBacked = Autodiff<LibTorch>;
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
2738

28-
let train_device = LibTorchDevice::Cuda(0);
39+
#[cfg(not(feature = "sb3-tch"))]
40+
let train_device = WgpuDevice::default();
2941

30-
sb3_seed::<TrainingBacked>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3143

3244
let config_optimizer =
3345
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -40,7 +52,7 @@ fn main() {
4052
.with_lr(1e-3);
4153

4254
let env = GridWorldEnv::default();
43-
let q = LinearAdvDQNNet::<TrainingBacked>::init(
55+
let q = LinearAdvDQNNet::<B>::init(
4456
&train_device,
4557
env.observation_space().shape().len(),
4658
env.action_space().shape(),

examples/dqn_mountaincar.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,17 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, classic_control::mountain_car::MountainCarEnv},
1814
};
1915

16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
19+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
20+
21+
#[cfg(not(feature = "sb3-tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "sb3-tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainingBacked = Autodiff<LibTorch>;
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
2738

28-
let train_device = LibTorchDevice::Cuda(0);
39+
#[cfg(not(feature = "sb3-tch"))]
40+
let train_device = WgpuDevice::default();
2941

30-
sb3_seed::<TrainingBacked>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3143

3244
let config_optimizer =
3345
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -46,7 +58,7 @@ fn main() {
4658
.with_train_every(16);
4759

4860
let env = MountainCarEnv::default();
49-
let q = LinearAdvDQNNet::<TrainingBacked>::init(
61+
let q = LinearAdvDQNNet::<B>::init(
5062
&train_device,
5163
env.observation_space().shape().len(),
5264
env.action_space().shape(),

examples/dqn_probe1.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,16 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, probe::ProbeEnvValueTest},
1814
};
1915

16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
19+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
20+
21+
#[cfg(not(feature = "sb3-tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "sb3-tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
2531

26-
type TrainBackend = Autodiff<LibTorch>;
27-
let train_device = LibTorchDevice::default();
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
38+
39+
#[cfg(not(feature = "sb3-tch"))]
40+
let train_device = WgpuDevice::default();
2841

29-
sb3_seed::<TrainBackend>(1234, &train_device);
42+
sb3_seed::<B>(1234, &train_device);
3043

3144
let config_optimizer =
3245
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -42,7 +55,7 @@ fn main() {
4255
.with_evaluate_during_training(false);
4356

4457
let env = ProbeEnvValueTest::default();
45-
let q: LinearDQNNet<TrainBackend> = LinearDQNNet::init(
58+
let q: LinearDQNNet<B> = LinearDQNNet::init(
4659
&train_device,
4760
env.observation_space().shape().len(),
4861
env.action_space().shape(),

examples/dqn_probe2.rs

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
use std::path::PathBuf;
22

3-
use burn::{
4-
backend::{libtorch::LibTorchDevice, Autodiff, LibTorch},
5-
grad_clipping::GradientClippingConfig,
6-
optim::AdamConfig,
7-
};
3+
use burn::{backend::Autodiff, grad_clipping::GradientClippingConfig, optim::AdamConfig};
84
use sb3_burn::{
95
common::{
106
algorithm::{OfflineAlgParams, OfflineTrainer},
@@ -17,14 +13,33 @@ use sb3_burn::{
1713
env::{base::Env, probe::ProbeEnvBackpropTest},
1814
};
1915

16+
#[cfg(feature = "sb3-tch")]
17+
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
18+
#[cfg(not(feature = "sb3-tch"))]
19+
use burn::backend::{wgpu::WgpuDevice, Wgpu};
20+
21+
#[cfg(not(feature = "sb3-tch"))]
22+
type B = Autodiff<Wgpu>;
23+
#[cfg(feature = "sb3-tch")]
24+
type B = Autodiff<LibTorch>;
25+
2026
extern crate sb3_burn;
2127

2228
fn main() {
2329
// Using parameters from:
2430
// https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/dqn.yml
25-
type TrainBackend = Autodiff<LibTorch>;
26-
let train_device = LibTorchDevice::default();
27-
sb3_seed::<TrainBackend>(1234, &train_device);
31+
32+
#[cfg(feature = "sb3-tch")]
33+
let train_device = if tch::utils::has_cuda() {
34+
LibTorchDevice::Cuda(0)
35+
} else {
36+
LibTorchDevice::Cpu
37+
};
38+
39+
#[cfg(not(feature = "sb3-tch"))]
40+
let train_device = WgpuDevice::default();
41+
42+
sb3_seed::<B>(1234, &train_device);
2843

2944
let config_optimizer =
3045
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(10.0)));
@@ -40,7 +55,7 @@ fn main() {
4055
.with_evaluate_during_training(false);
4156

4257
let env = ProbeEnvBackpropTest::default();
43-
let q: LinearDQNNet<TrainBackend> = LinearDQNNet::init(
58+
let q: LinearDQNNet<B> = LinearDQNNet::init(
4459
&train_device,
4560
env.observation_space().shape(),
4661
env.action_space().shape(),

0 commit comments

Comments
 (0)