Skip to content

Commit a3894aa

Browse files
张扬发张扬发
authored andcommitted
add class
1 parent c594ad3 commit a3894aa

File tree

4 files changed

+129
-5
lines changed

4 files changed

+129
-5
lines changed

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ use candle_nn::{
1010
};
1111
mod mytest;
1212
use mytest::Number;
13-
13+
mod mlp;
14+
use mlp::MLP;
1415
#[pymodule]
1516
fn rust_mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
1617
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
1718
m.add_function(wrap_pyfunction!(let_me_try, m)?)?;
1819
m.add_class::<Number>()?;
20+
m.add_class::<MLP>()?;
1921
Ok(())
2022
}
2123

src/mlp.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use pyo3::prelude::*;
2+
use pyo3::exceptions::PyRuntimeError;
3+
use pyo3::types::PyModule;
4+
use candle_core::{DType, Device, Tensor};
5+
use candle_nn::{
6+
linear, Linear, Module, VarBuilder, VarMap, Optimizer,
7+
};
8+
9+
#[pyclass]
10+
pub struct MLP {
11+
model: SimpleNN,
12+
optimizer: candle_nn::AdamW,
13+
device: Device,
14+
}
15+
16+
#[pymethods]
17+
impl MLP {
18+
#[new]
19+
fn new(input_dim: usize, lr: f64) -> PyResult<Self> {
20+
let device = Device::Cpu;
21+
let varmap = VarMap::new();
22+
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
23+
let model = SimpleNN::new(input_dim, vb)
24+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
25+
let optimizer = candle_nn::AdamW::new_lr(varmap.all_vars(), lr)
26+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
27+
Ok(Self { model, optimizer, device })
28+
}
29+
30+
fn train(&mut self, x: Vec<Vec<f32>>, y: Vec<f32>, epochs: usize) -> PyResult<()> {
31+
let n = x.len();
32+
let d = x[0].len();
33+
let x_flat = x.into_iter().flatten().collect::<Vec<_>>();
34+
let x_tensor = Tensor::from_vec(x_flat, (n, d), &self.device)
35+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
36+
let y_tensor = Tensor::from_vec(y, (n, 1), &self.device)
37+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
38+
for _ in 0..epochs {
39+
let output = self.model.forward(&x_tensor)
40+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
41+
let loss = candle_nn::loss::mse(&output, &y_tensor)
42+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
43+
self.optimizer.backward_step(&loss)
44+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
45+
}
46+
Ok(())
47+
}
48+
49+
fn evaluate(&self, x: Vec<Vec<f32>>, y: Vec<f32>) -> PyResult<f32> {
50+
let n = x.len();
51+
let d = x[0].len();
52+
let x_flat = x.into_iter().flatten().collect::<Vec<_>>();
53+
let x_tensor = Tensor::from_vec(x_flat, (n, d), &self.device)
54+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
55+
let y_tensor = Tensor::from_vec(y, (n, 1), &self.device)
56+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
57+
let output = self.model.forward(&x_tensor)
58+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
59+
let loss = candle_nn::loss::mse(&output, &y_tensor)
60+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
61+
Ok(loss.to_scalar::<f32>()
62+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?)
63+
}
64+
}
65+
66+
#[derive(Debug)]
67+
struct SimpleNN {
68+
fc1: Linear,
69+
fc2: Linear,
70+
}
71+
72+
impl SimpleNN {
73+
fn new(in_dim: usize, vb: VarBuilder) -> candle_core::Result<Self> {
74+
let fc1 = linear(in_dim, 64, vb.pp("fc1"))?;
75+
let fc2 = linear(64, 1, vb.pp("fc2"))?;
76+
Ok(Self { fc1, fc2 })
77+
}
78+
}
79+
80+
impl Module for SimpleNN {
81+
fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
82+
let x = self.fc1.forward(xs)?;
83+
let x = x.relu()?;
84+
let x = self.fc2.forward(&x)?;
85+
Ok(x)
86+
}
87+
}

src/mytest.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use pyo3::prelude::*;
2+
use pyo3::types::PyModule;
23

34
#[pyclass]
45
pub struct Number(i32);
@@ -9,4 +10,5 @@ impl Number {
910
fn new(value: i32) -> Self {
1011
Number(value)
1112
}
12-
}
13+
}
14+

test.ipynb

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,45 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 2,
15+
"execution_count": 3,
1616
"id": "bedc6943",
1717
"metadata": {},
18-
"outputs": [],
18+
"outputs": [
19+
{
20+
"data": {
21+
"text/plain": [
22+
"<Number at 0x106bce4f0>"
23+
]
24+
},
25+
"execution_count": 3,
26+
"metadata": {},
27+
"output_type": "execute_result"
28+
}
29+
],
30+
"source": [
31+
"x = rust_mlp.Number(42)\n",
32+
"x"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 5,
38+
"id": "db84d87a",
39+
"metadata": {},
40+
"outputs": [
41+
{
42+
"data": {
43+
"text/plain": [
44+
"<MLP at 0x106bdd740>"
45+
]
46+
},
47+
"execution_count": 5,
48+
"metadata": {},
49+
"output_type": "execute_result"
50+
}
51+
],
1952
"source": [
20-
"x = rust_mlp.Number(42)"
53+
"rust_mlp.MLP(64,0.01)"
2154
]
2255
}
2356
],

0 commit comments

Comments
 (0)