
Commit 8e3ca6d

Move optimizer components to burn-optim (#3773)
* Move optim, grad_clipping and lr_scheduler to burn-optim
* Add publish workflow
* Fix docs link
1 parent: c339df5
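For downstream users, the practical effect of this commit is an import-path change. A minimal sketch, assuming the module layout inside the crate carries over unchanged (the `AdamConfig` name comes from Burn's existing optimizer API, not from this diff):

```rust
// Path migration sketch: optimizer modules moved from burn-core to burn-optim.
//   before: use burn_core::optim::AdamConfig;
//   after:  use burn_optim::optim::AdamConfig;
use burn_optim::optim::AdamConfig;

fn main() {
    // The builder API itself is untouched by the move; only the owning
    // crate changed.
    let _adam = AdamConfig::new();
}
```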

File tree

49 files changed (+451, -240 lines)

Large commits have some content hidden by default; only a subset of the 49 changed files is shown below.

.github/workflows/publish.yml

Lines changed: 17 additions & 1 deletion
```diff
@@ -256,7 +256,6 @@ jobs:
       - publish-burn-derive
       - publish-burn-tensor
       - publish-burn-vision
-      - publish-burn-collective
       # dev dependencies
       - publish-burn-autodiff
       - publish-burn-wgpu
@@ -286,6 +285,23 @@ jobs:
     secrets:
       CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
 
+  publish-burn-optim:
+    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v4
+    needs:
+      - publish-burn-core
+      - publish-burn-collective
+      # dev dependencies
+      - publish-burn-autodiff
+      - publish-burn-wgpu
+      - publish-burn-tch
+      - publish-burn-ndarray
+      - publish-burn-candle
+      - publish-burn-remote
+    with:
+      crate: burn-optim
+    secrets:
+      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
+
   publish-burn-train:
     uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v4
     needs:
```

Cargo.lock

Lines changed: 26 additions & 1 deletion
Some generated files are not rendered by default.

crates/burn-core/Cargo.toml

Lines changed: 0 additions & 2 deletions
```diff
@@ -50,7 +50,6 @@ std = [
 ]
 vision = ["burn-vision", "burn-dataset?/vision", "burn-common/network"]
 audio = ["burn-dataset?/audio"]
-collective = ["burn-collective"]
 
 # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
 record-item-custom-serde = ["thiserror"]
@@ -91,7 +90,6 @@ burn-dataset = { path = "../burn-dataset", version = "0.19.0", optional = true,
 burn-derive = { path = "../burn-derive", version = "0.19.0" }
 burn-tensor = { path = "../burn-tensor", version = "0.19.0", default-features = false }
 burn-vision = { path = "../burn-vision", version = "0.19.0", optional = true, default-features = false }
-burn-collective = { path = "../burn-collective", version = "0.19.0", optional = true, default-features = false }
 
 data-encoding = { workspace = true }
 uuid = { workspace = true }
```

crates/burn-core/src/lib.rs

Lines changed: 0 additions & 25 deletions
```diff
@@ -18,16 +18,6 @@ pub mod config;
 #[cfg(feature = "std")]
 pub mod data;
 
-/// Optimizer module.
-pub mod optim;
-
-/// Learning rate scheduler module.
-#[cfg(feature = "std")]
-pub mod lr_scheduler;
-
-/// Gradient clipping module.
-pub mod grad_clipping;
-
 /// Module for the neural network module.
 pub mod module;
 
@@ -87,7 +77,6 @@ mod test_utils {
     use crate::module::Param;
     use burn_tensor::Tensor;
     use burn_tensor::backend::Backend;
-    use burn_tensor::module::linear;
 
     /// Simple linear module.
     #[derive(Module, Debug)]
@@ -110,23 +99,9 @@
                 bias: Some(Param::from_tensor(bias)),
             }
         }
-
-        pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
-            linear(
-                input,
-                self.weight.val(),
-                self.bias.as_ref().map(|b| b.val()),
-            )
-        }
     }
 }
 
-/// Type alias for the learning rate.
-///
-/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
-/// can be used for constant learning rate.
-pub type LearningRate = f64; // We could potentially change the type.
-
 pub mod prelude {
     //! Structs and macros used by most projects. Add `use
     //! burn::prelude::*` to your code to quickly get started with
```
crates/burn-nn/src/lib.rs

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,7 +1,6 @@
 #![cfg_attr(not(feature = "std"), no_std)]
 #![warn(missing_docs)]
 #![cfg_attr(docsrs, feature(doc_auto_cfg))]
-// #![recursion_limit = "135"]
 
 //! Burn neural network module.
 
```

crates/burn-optim/Cargo.toml

Lines changed: 91 additions & 0 deletions
```toml
[package]
authors = ["nathanielsimard <[email protected]>"]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Optimizer building blocks for the Burn deep learning framework"
documentation = "https://docs.rs/burn-optim"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-optim"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-optim"
version.workspace = true

[lints]
workspace = true

[features]
default = [
    "std",
    "burn-core/default",
]
doc = [
    "std",
    # Doc features
    "burn-core/doc",
]
std = [
    "burn-core/std",
    "num-traits/std",
    "serde/std",
    "log",
]

collective = ["burn-collective"]

test-cuda = [
    "burn-cuda/default",
] # To use cuda during testing, default uses ndarray.
test-rocm = [
    "burn-rocm/default",
] # To use hip during testing, default uses ndarray.
test-tch = [
    "burn-tch/default",
] # To use tch during testing, default uses ndarray.
test-wgpu = [
    "burn-wgpu/default",
] # To use wgpu during testing, default uses ndarray.
test-vulkan = [
    "test-wgpu",
    "burn-wgpu/vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.
test-metal = [
    "test-wgpu",
    "burn-wgpu/metal",
] # To use wgpu-metal during testing, default uses ndarray.

# Memory checks are disabled by default
test-memory-checks = ["burn-fusion/memory-checks"]

[dependencies]

# ** Please make sure all dependencies support no_std when std is disabled **
burn-core = { path = "../burn-core", version = "0.19.0", default-features = false }
burn-collective = { path = "../burn-collective", version = "0.19.0", optional = true, default-features = false }

num-traits = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }

# The same implementation of HashMap as in std but with no_std support (only the alloc crate is needed)
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible

# FOR TESTING
burn-cuda = { path = "../burn-cuda", version = "0.19.0", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "0.19.0", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "0.19.0", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "0.19.0", default-features = false, optional = true }
burn-tch = { path = "../burn-tch", version = "0.19.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.19.0", optional = true, default-features = false }
burn-fusion = { path = "../burn-fusion", version = "0.19.0", optional = true }

[dev-dependencies]
burn-nn = { path = "../burn-nn", version = "0.19.0" }
burn-ndarray = { path = "../burn-ndarray", version = "0.19.0" }
burn-autodiff = { path = "../burn-autodiff", version = "0.19.0" }
rstest = { workspace = true }

[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]
```

crates/burn-optim/README.md

Lines changed: 3 additions & 0 deletions
```md
# Burn Optimizers

Core building blocks for Burn optimizers.
```
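As a sketch of what these building blocks look like in use, the following hypothetical training step wires the relocated optimizer API to the crates listed under `[dev-dependencies]` above; it assumes the `Optimizer`/`GradientsParams` interface and the `burn_nn::Linear` names carry over from burn-core unchanged:

```rust
use burn_autodiff::Autodiff;
use burn_core::tensor::Tensor;
use burn_ndarray::NdArray;
use burn_nn::{Linear, LinearConfig};
use burn_optim::optim::{GradientsParams, Optimizer, SgdConfig};

type B = Autodiff<NdArray<f32>>;

fn main() {
    let device = Default::default();
    let mut model: Linear<B> = LinearConfig::new(4, 2).init(&device);
    let mut optim = SgdConfig::new().init();

    // One optimization step: forward, backward, convert raw gradients into
    // per-parameter gradients, then update the module.
    let input = Tensor::<B, 2>::ones([1, 4], &device);
    let loss = model.forward(input).sum();
    let grads = GradientsParams::from_grads(loss.backward(), &model);
    model = optim.step(1e-2, model, grads);
}
```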

crates/burn-core/src/grad_clipping/base.rs renamed to crates/burn-optim/src/grad_clipping/base.rs

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,7 +1,7 @@
-use crate as burn;
+use burn_core as burn;
 
-use crate::{config::Config, tensor::Tensor};
-use burn_tensor::backend::Backend;
+use burn::tensor::backend::Backend;
+use burn::{config::Config, tensor::Tensor};
 
 /// Gradient Clipping provides a way to mitigate exploding gradients
 #[derive(Config, Debug)]
@@ -91,7 +91,7 @@ impl GradientClipping {
 mod tests {
     use super::*;
     use crate::TestBackend;
-    use crate::tensor::Tensor;
+    use burn::tensor::Tensor;
 
     #[test]
     fn test_clip_by_value() {
```
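For context, a small sketch of the two clipping strategies this file configures; the `GradientClippingConfig` variants and the `with_grad_clipping` setter are from Burn's existing API, and their behavior is assumed unchanged by the move:

```rust
use burn_optim::grad_clipping::GradientClippingConfig;
use burn_optim::optim::AdamConfig;

fn main() {
    // Clip each gradient element into [-0.5, 0.5]...
    let by_value = GradientClippingConfig::Value(0.5);
    // ...or rescale the whole gradient when its L2 norm exceeds 1.0.
    let by_norm = GradientClippingConfig::Norm(1.0);

    // Either config can be attached to an optimizer config.
    let _adam = AdamConfig::new().with_grad_clipping(Some(by_value));
    let _ = by_norm;
}
```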

crates/burn-optim/src/lib.rs

Lines changed: 62 additions & 0 deletions
```rust
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

//! Burn optimizers.

#[macro_use]
extern crate derive_new;

extern crate alloc;

/// Optimizer module.
pub mod optim;
pub use optim::*;

/// Gradient clipping module.
pub mod grad_clipping;

/// Learning rate scheduler module.
#[cfg(feature = "std")]
pub mod lr_scheduler;

/// Type alias for the learning rate.
///
/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
/// can be used for constant learning rate.
pub type LearningRate = f64; // We could potentially change the type.

/// Backend for test cases
#[cfg(all(
    test,
    not(feature = "test-tch"),
    not(feature = "test-wgpu"),
    not(feature = "test-cuda"),
    not(feature = "test-rocm")
))]
pub type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(all(test, feature = "test-tch"))]
/// Backend for test cases
pub type TestBackend = burn_tch::LibTorch<f32>;

#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases
pub type TestBackend = burn_wgpu::Wgpu;

#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases
pub type TestBackend = burn_cuda::Cuda;

#[cfg(all(test, feature = "test-rocm"))]
/// Backend for test cases
pub type TestBackend = burn_rocm::Rocm;

/// Backend for autodiff test cases
#[cfg(test)]
pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;

#[cfg(all(test, feature = "test-memory-checks"))]
mod tests {
    burn_fusion::memory_checks!();
}
```
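The `LearningRate` alias and its doc comment port over verbatim. As a sketch of what "can be used for constant learning rate" means, assuming `LrScheduler::step(&mut self) -> LearningRate` keeps the shape it had in burn-core:

```rust
use burn_optim::LearningRate;
use burn_optim::lr_scheduler::LrScheduler;

fn main() {
    // A bare f64 doubles as a constant-rate scheduler: every step yields
    // the same value.
    let mut lr: LearningRate = 1e-3;
    assert_eq!(LrScheduler::step(&mut lr), 1e-3);
    assert_eq!(LrScheduler::step(&mut lr), 1e-3);
}
```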
