Skip to content

Commit 7c1d08a

Browse files
张扬发
authored and committed
add class
1 parent a3894aa commit 7c1d08a

File tree

4 files changed

+119
-166
lines changed

4 files changed

+119
-166
lines changed

README.md

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,40 @@ pip install Path/to/*.whl
1313
- 在python环境中使用
1414
```Python
1515
import rust_mlp
16-
rust_mlp.let_me_try()
16+
import json
17+
with open('fetch_california_housing.json', encoding='utf-8') as f:
18+
data = json.load(f)
19+
X_train = data['X_train']
20+
X_test = data['X_test']
21+
y_train = data['y_train']
22+
y_test = data['y_test']
23+
24+
input_dim = len(X_train[0])
25+
model = rust_mlp.MLP(input_dim,lr = 0.01)
26+
model.train(X_train,y_train, epochs = 20)
27+
model.evaluate(X_test,y_test)
1728
```
1829
预期结果如下:
1930
```
20-
Epoch 10 | Train Loss: 1.628042
21-
Epoch 20 | Train Loss: 1.206229
22-
Epoch 30 | Train Loss: 0.876774
23-
Epoch 40 | Train Loss: 0.737402
24-
Epoch 50 | Train Loss: 0.644500
25-
Epoch 60 | Train Loss: 0.576381
26-
Epoch 70 | Train Loss: 0.524079
27-
Epoch 80 | Train Loss: 0.482977
28-
Epoch 90 | Train Loss: 0.451547
29-
Epoch 100 | Train Loss: 0.431264
30-
Epoch 110 | Train Loss: 0.417604
31-
Epoch 120 | Train Loss: 0.407959
32-
Epoch 130 | Train Loss: 0.400804
33-
Epoch 140 | Train Loss: 0.395103
34-
Epoch 150 | Train Loss: 0.390289
35-
Epoch 160 | Train Loss: 0.386080
36-
Epoch 170 | Train Loss: 0.382233
37-
Epoch 180 | Train Loss: 0.378714
38-
Epoch 190 | Train Loss: 0.375436
39-
Epoch 200 | Train Loss: 0.372300
40-
Test Loss: 0.385307
31+
Loss: 4.4525466
32+
Loss: 3.3261764
33+
Loss: 2.539338
34+
Loss: 2.008424
35+
Loss: 1.6563666
36+
Loss: 1.4215927
37+
Loss: 1.2626717
38+
Loss: 1.1582772
39+
Loss: 1.0978976
40+
Loss: 1.0723116
41+
Loss: 1.0687258
42+
Loss: 1.073826
43+
Loss: 1.0768875
44+
Loss: 1.0705063
45+
Loss: 1.0518045
46+
Loss: 1.0223864
47+
Loss: 0.9862392
48+
Loss: 0.9472292
49+
Loss: 0.9076066
50+
Loss: 0.86793756
51+
Evaluation Loss: 0.82283765
4152
```

src/lib.rs

Lines changed: 1 addition & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -8,128 +8,10 @@ use candle_core::{DType, Device, Tensor};
88
use candle_nn::{
99
linear, Linear, Module, VarBuilder, VarMap, Optimizer,
1010
};
11-
mod mytest;
12-
use mytest::Number;
1311
mod mlp;
1412
use mlp::MLP;
1513
#[pymodule]
1614
fn rust_mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
17-
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
18-
m.add_function(wrap_pyfunction!(let_me_try, m)?)?;
19-
m.add_class::<Number>()?;
2015
m.add_class::<MLP>()?;
2116
Ok(())
22-
}
23-
24-
#[pyfunction]
25-
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
26-
Ok((a + b).to_string())
27-
}
28-
29-
#[pyfunction]
30-
fn let_me_try() -> PyResult<()> {
31-
run_training()
32-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))
33-
}
34-
35-
fn run_training() -> anyhow::Result<()> {
36-
let device = Device::Cpu;
37-
38-
let file = std::fs::File::open("fetch_california_housing.json")?;
39-
let reader = std::io::BufReader::new(file);
40-
let data: Data = serde_json::from_reader(reader)?;
41-
42-
let train_d1 = data.X_train.len();
43-
let train_d2 = data.X_train[0].len();
44-
let test_d1 = data.X_test.len();
45-
let test_d2 = data.X_test[0].len();
46-
47-
let x_train_vec = data.X_train.into_iter().flatten().collect::<Vec<_>>();
48-
let x_test_vec = data.X_test.into_iter().flatten().collect::<Vec<_>>();
49-
50-
let y_train_vec = data.y_train;
51-
let y_test_vec = data.y_test;
52-
53-
let x_train = Tensor::from_vec(x_train_vec, (train_d1, train_d2), &device)?;
54-
let y_train = Tensor::from_vec(y_train_vec, (train_d1, 1), &device)?;
55-
let x_test = Tensor::from_vec(x_test_vec, (test_d1, test_d2), &device)?;
56-
let y_test = Tensor::from_vec(y_test_vec, (test_d1, 1), &device)?;
57-
58-
let varmap = VarMap::new();
59-
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
60-
61-
let model = SimpleNN::new(train_d2, vb)?;
62-
63-
let mut optimizer =
64-
candle_nn::AdamW::new_lr(varmap.all_vars(), 1e-2)?;
65-
66-
train_model(&model, &x_train, &y_train, &mut optimizer, 200)?;
67-
evaluate_model(&model, &x_test, &y_test)?;
68-
69-
Ok(())
70-
}
71-
72-
#[derive(Debug)]
73-
struct SimpleNN {
74-
fc1: Linear,
75-
fc2: Linear,
76-
}
77-
78-
impl SimpleNN {
79-
fn new(in_dim: usize, vb: VarBuilder) -> candle_core::Result<Self> {
80-
let fc1 = linear(in_dim, 64, vb.pp("fc1"))?;
81-
let fc2 = linear(64, 1, vb.pp("fc2"))?;
82-
Ok(Self { fc1, fc2 })
83-
}
84-
}
85-
86-
impl Module for SimpleNN {
87-
fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
88-
let x = self.fc1.forward(xs)?;
89-
let x = x.relu()?;
90-
let x = self.fc2.forward(&x)?;
91-
Ok(x)
92-
}
93-
}
94-
95-
fn train_model(
96-
model: &SimpleNN,
97-
x_train: &Tensor,
98-
y_train: &Tensor,
99-
optimizer: &mut candle_nn::AdamW,
100-
epochs: usize,
101-
) -> anyhow::Result<()> {
102-
for epoch in 0..epochs {
103-
let output = model.forward(x_train)?;
104-
let loss = candle_nn::loss::mse(&output, y_train)?;
105-
optimizer.backward_step(&loss)?;
106-
107-
if (epoch + 1) % 10 == 0 {
108-
println!(
109-
"Epoch {} | Train Loss: {:.6}",
110-
epoch + 1,
111-
loss.to_scalar::<f32>()?
112-
);
113-
}
114-
}
115-
Ok(())
116-
}
117-
118-
fn evaluate_model(
119-
model: &SimpleNN,
120-
x_test: &Tensor,
121-
y_test: &Tensor,
122-
) -> anyhow::Result<()> {
123-
let output = model.forward(x_test)?;
124-
let loss = candle_nn::loss::mse(&output, y_test)?;
125-
println!("Test Loss: {:.6}", loss.to_scalar::<f32>()?);
126-
Ok(())
127-
}
128-
129-
#[derive(Debug, serde::Deserialize)]
130-
struct Data {
131-
X_train: Vec<Vec<f32>>,
132-
X_test: Vec<Vec<f32>>,
133-
y_train: Vec<f32>,
134-
y_test: Vec<f32>,
135-
}
17+
}

src/mlp.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@ impl MLP {
4040
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
4141
let loss = candle_nn::loss::mse(&output, &y_tensor)
4242
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
43+
println!("Loss: {}", loss.to_scalar::<f32>()
44+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?);
4345
self.optimizer.backward_step(&loss)
4446
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
4547
}
4648
Ok(())
4749
}
4850

49-
fn evaluate(&self, x: Vec<Vec<f32>>, y: Vec<f32>) -> PyResult<f32> {
51+
fn evaluate(&self, x: Vec<Vec<f32>>, y: Vec<f32>) -> PyResult<()> {
5052
let n = x.len();
5153
let d = x[0].len();
5254
let x_flat = x.into_iter().flatten().collect::<Vec<_>>();
@@ -58,8 +60,9 @@ impl MLP {
5860
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
5961
let loss = candle_nn::loss::mse(&output, &y_tensor)
6062
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
61-
Ok(loss.to_scalar::<f32>()
62-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?)
63+
println!("Evaluation Loss: {}", loss.to_scalar::<f32>()
64+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?);
65+
Ok(())
6366
}
6467
}
6568

test.ipynb

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,104 @@
1010
"import rust_mlp"
1111
]
1212
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 2,
16+
"id": "4ca52908",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"import json"
21+
]
22+
},
1323
{
1424
"cell_type": "code",
1525
"execution_count": 3,
16-
"id": "bedc6943",
26+
"id": "34d74ef8",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"with open('fetch_california_housing.json', encoding='utf-8') as f:\n",
31+
" data = json.load(f) "
32+
]
33+
},
34+
{
35+
"cell_type": "code",
36+
"execution_count": 4,
37+
"id": "d0bb55a4",
38+
"metadata": {},
39+
"outputs": [],
40+
"source": [
41+
"X_train = data['X_train']\n",
42+
"X_test = data['X_test']\n",
43+
"y_train = data['y_train']\n",
44+
"y_test = data['y_test']"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": 5,
50+
"id": "db84d87a",
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"input_dim = len(X_train[0])\n",
55+
"model = rust_mlp.MLP(input_dim,lr = 0.01)"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": 6,
61+
"id": "975ccf7f",
1762
"metadata": {},
1863
"outputs": [
1964
{
20-
"data": {
21-
"text/plain": [
22-
"<Number at 0x106bce4f0>"
23-
]
24-
},
25-
"execution_count": 3,
26-
"metadata": {},
27-
"output_type": "execute_result"
65+
"name": "stdout",
66+
"output_type": "stream",
67+
"text": [
68+
"Loss: 4.4525466\n",
69+
"Loss: 3.3261764\n",
70+
"Loss: 2.539338\n",
71+
"Loss: 2.008424\n",
72+
"Loss: 1.6563666\n",
73+
"Loss: 1.4215927\n",
74+
"Loss: 1.2626717\n",
75+
"Loss: 1.1582772\n",
76+
"Loss: 1.0978976\n",
77+
"Loss: 1.0723116\n",
78+
"Loss: 1.0687258\n",
79+
"Loss: 1.073826\n",
80+
"Loss: 1.0768875\n",
81+
"Loss: 1.0705063\n",
82+
"Loss: 1.0518045\n",
83+
"Loss: 1.0223864\n",
84+
"Loss: 0.9862392\n",
85+
"Loss: 0.9472292\n",
86+
"Loss: 0.9076066\n",
87+
"Loss: 0.86793756\n"
88+
]
2889
}
2990
],
3091
"source": [
31-
"x = rust_mlp.Number(42)\n",
32-
"x"
92+
"model.train(X_train,y_train, epochs = 20)"
3393
]
3494
},
3595
{
3696
"cell_type": "code",
37-
"execution_count": 5,
38-
"id": "db84d87a",
97+
"execution_count": 7,
98+
"id": "fd3e1b4c",
3999
"metadata": {},
40100
"outputs": [
41101
{
42-
"data": {
43-
"text/plain": [
44-
"<MLP at 0x106bdd740>"
45-
]
46-
},
47-
"execution_count": 5,
48-
"metadata": {},
49-
"output_type": "execute_result"
102+
"name": "stdout",
103+
"output_type": "stream",
104+
"text": [
105+
"Evaluation Loss: 0.82283765\n"
106+
]
50107
}
51108
],
52109
"source": [
53-
"rust_mlp.MLP(64,0.01)"
110+
"model.evaluate(X_test,y_test)"
54111
]
55112
}
56113
],

0 commit comments

Comments (0)