
Commit a4d3fb8

add wrap_module in lib.rs

Authored and committed by 张扬发
1 parent 0b3a045 · commit a4d3fb8

3 files changed: 37 additions (+) and 30 deletions (−)

src/lib.rs

Lines changed: 3 additions & 2 deletions
@@ -3,6 +3,8 @@
 use pyo3::prelude::*;
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::types::PyModule;
+use pyo3::wrap_pymodule;
+

 use candle_core::{DType, Device, Tensor};
 use candle_nn::{
@@ -11,11 +13,10 @@ use candle_nn::{

 pub mod base;
 pub mod mlp;
-use mlp::MLP;


 #[pymodule]
 fn rust_mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_class::<MLP>()?;
+    m.add_wrapped(wrap_pymodule!(mlp::mlp))?;
     Ok(())
 }
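
With this change, mlp becomes a nested submodule of rust_mlp, so the class is reached as rust_mlp.mlp.PyMLP (as the updated notebook below shows). The following is a minimal, self-contained sketch of that pattern, assuming PyO3 0.21+ with the Bound API; the inline PyMLP stub and the optional sys.modules registration are illustrative and not part of this commit.

// A minimal sketch of the nested-module pattern introduced here, assuming
// PyO3 0.21+ (Bound API). The unit-struct PyMLP stub and the sys.modules
// registration are illustrative additions, not part of this commit.
use pyo3::prelude::*;
use pyo3::types::PyModule;
use pyo3::wrap_pymodule;

#[pyclass]
struct PyMLP; // stand-in for the real class defined in src/mlp.rs

#[pymodule]
fn mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Child module: exposes the class, mirroring src/mlp.rs below.
    m.add_class::<PyMLP>()?;
    Ok(())
}

#[pymodule]
fn rust_mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Parent module: nest the child so Python sees rust_mlp.mlp.PyMLP.
    m.add_wrapped(wrap_pymodule!(mlp))?;
    // Optional extra (not in this commit): register the child in sys.modules
    // so "from rust_mlp.mlp import PyMLP" also works, not only attribute access.
    m.py()
        .import_bound("sys")?
        .getattr("modules")?
        .set_item("rust_mlp.mlp", m.getattr("mlp")?)?;
    Ok(())
}

With add_wrapped alone, the class is reachable via attribute access (rust_mlp.mlp.PyMLP), which is exactly how the notebook below uses it.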

src/mlp.rs

Lines changed: 9 additions & 3 deletions
@@ -7,14 +7,14 @@ use candle_nn::{
 };
 use crate::base::Model;
 #[pyclass]
-pub struct MLP {
+pub struct PyMLP {
     model: SimpleNN,
     optimizer: candle_nn::AdamW,
     device: Device,
 }

 #[pymethods]
-impl MLP {
+impl PyMLP {
     #[new]
     fn new(input_dim: usize, lr: f64) -> PyResult<Self> {
         let device = Device::Cpu;
@@ -39,7 +39,7 @@ impl MLP {
         Ok(())
     }
 }
-impl Model for MLP {
+impl Model for PyMLP {
     fn train(&mut self, x: Vec<Vec<f32>>, y: Vec<f32>, epochs: usize) -> PyResult<()> {
         let n = x.len();
         let d = x[0].len();
@@ -101,4 +101,10 @@ impl Module for SimpleNN {
         let x = self.fc2.forward(&x)?;
         Ok(x)
     }
+}
+
+#[pymodule]
+pub fn mlp(m: &Bound<'_, PyModule>)-> PyResult<()>{
+    m.add_class::<PyMLP>()?;
+    Ok(())
 }
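
Renaming the struct from MLP to PyMLP also changes the Python-visible class name, which is why the notebook now calls rust_mlp.mlp.PyMLP. If the Python name were meant to stay MLP while the Rust type is PyMLP, PyO3's pyclass name attribute could do that; the sketch below is a hypothetical alternative (assuming PyO3 0.21+), not what this commit does.

// Hypothetical alternative, assuming PyO3 0.21+: keep the Python-facing name
// "MLP" while the Rust struct is renamed to PyMLP. Not part of this commit;
// the commit exposes the class under its Rust name, PyMLP.
use pyo3::prelude::*;
use pyo3::types::PyModule;

#[pyclass(name = "MLP")]
pub struct PyMLP;

#[pymodule]
pub fn mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMLP>()?; // visible from Python as mlp.MLP
    Ok(())
}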

test.ipynb

Lines changed: 25 additions & 25 deletions
@@ -46,45 +46,45 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 7,
    "id": "db84d87a",
    "metadata": {},
    "outputs": [],
    "source": [
     "input_dim = len(X_train[0])\n",
-    "model = rust_mlp.MLP(input_dim,lr = 0.01)"
+    "model = rust_mlp.mlp.PyMLP(input_dim,lr = 0.01)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 8,
    "id": "975ccf7f",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Loss: 11.602837\n",
-      "Loss: 9.104499\n",
-      "Loss: 7.028557\n",
-      "Loss: 5.354487\n",
-      "Loss: 4.0523024\n",
-      "Loss: 3.0819073\n",
-      "Loss: 2.3972578\n",
-      "Loss: 1.9500879\n",
-      "Loss: 1.6904843\n",
-      "Loss: 1.5687597\n",
-      "Loss: 1.5355564\n",
-      "Loss: 1.5495899\n",
-      "Loss: 1.5760833\n",
-      "Loss: 1.5907441\n",
-      "Loss: 1.5783068\n",
-      "Loss: 1.5354254\n",
-      "Loss: 1.4655299\n",
-      "Loss: 1.3767947\n",
-      "Loss: 1.2791892\n",
-      "Loss: 1.1817851\n"
+      "Loss: 6.421524\n",
+      "Loss: 4.9037886\n",
+      "Loss: 3.821808\n",
+      "Loss: 3.0536802\n",
+      "Loss: 2.4780674\n",
+      "Loss: 2.0242903\n",
+      "Loss: 1.6653357\n",
+      "Loss: 1.3944621\n",
+      "Loss: 1.2099363\n",
+      "Loss: 1.1070489\n",
+      "Loss: 1.074543\n",
+      "Loss: 1.0935023\n",
+      "Loss: 1.1390615\n",
+      "Loss: 1.185772\n",
+      "Loss: 1.2136548\n",
+      "Loss: 1.2125317\n",
+      "Loss: 1.1814777\n",
+      "Loss: 1.1264257\n",
+      "Loss: 1.0574386\n",
+      "Loss: 0.9845169\n"
      ]
     }
    ],
@@ -94,15 +94,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 9,
    "id": "fd3e1b4c",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Evaluation Loss: 1.2425249\n"
+      "Evaluation Loss: 0.9266705\n"
      ]
     }
    ],
