File tree Expand file tree Collapse file tree 3 files changed +37
-30
lines changed
Expand file tree Collapse file tree 3 files changed +37
-30
lines changed Original file line number Diff line number Diff line change 33use pyo3:: prelude:: * ;
44use pyo3:: exceptions:: PyRuntimeError ;
55use pyo3:: types:: PyModule ;
6+ use pyo3:: wrap_pymodule;
7+
68
79use candle_core:: { DType , Device , Tensor } ;
810use candle_nn:: {
@@ -11,11 +13,10 @@ use candle_nn::{
1113
1214pub mod base;
1315pub mod mlp;
14- use mlp:: MLP ;
1516
1617
1718#[ pymodule]
1819fn rust_mlp ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
19- m. add_class :: < MLP > ( ) ?;
20+ m. add_wrapped ( wrap_pymodule ! ( mlp :: mlp ) ) ?;
2021 Ok ( ( ) )
2122}
Original file line number Diff line number Diff line change @@ -7,14 +7,14 @@ use candle_nn::{
77} ;
88use crate :: base:: Model ;
99#[ pyclass]
10- pub struct MLP {
10+ pub struct PyMLP {
1111 model : SimpleNN ,
1212 optimizer : candle_nn:: AdamW ,
1313 device : Device ,
1414}
1515
1616#[ pymethods]
17- impl MLP {
17+ impl PyMLP {
1818 #[ new]
1919 fn new ( input_dim : usize , lr : f64 ) -> PyResult < Self > {
2020 let device = Device :: Cpu ;
@@ -39,7 +39,7 @@ impl MLP {
3939 Ok ( ( ) )
4040 }
4141}
42- impl Model for MLP {
42+ impl Model for PyMLP {
4343 fn train ( & mut self , x : Vec < Vec < f32 > > , y : Vec < f32 > , epochs : usize ) -> PyResult < ( ) > {
4444 let n = x. len ( ) ;
4545 let d = x[ 0 ] . len ( ) ;
@@ -101,4 +101,10 @@ impl Module for SimpleNN {
101101 let x = self . fc2 . forward ( & x) ?;
102102 Ok ( x)
103103 }
104+ }
105+
106+ #[ pymodule]
107+ pub fn mlp ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
108+ m. add_class :: < PyMLP > ( ) ?;
109+ Ok ( ( ) )
104110}
Original file line number Diff line number Diff line change 4646 },
4747 {
4848 "cell_type" : " code" ,
49- "execution_count" : 5 ,
49+ "execution_count" : 7 ,
5050 "id" : " db84d87a" ,
5151 "metadata" : {},
5252 "outputs" : [],
5353 "source" : [
5454 " input_dim = len(X_train[0])\n " ,
55- " model = rust_mlp.MLP (input_dim,lr = 0.01)"
55+ " model = rust_mlp.mlp.PyMLP (input_dim,lr = 0.01)"
5656 ]
5757 },
5858 {
5959 "cell_type" : " code" ,
60- "execution_count" : 6 ,
60+ "execution_count" : 8 ,
6161 "id" : " 975ccf7f" ,
6262 "metadata" : {},
6363 "outputs" : [
6464 {
6565 "name" : " stdout" ,
6666 "output_type" : " stream" ,
6767 "text" : [
68- " Loss: 11.602837 \n " ,
69- " Loss: 9.104499 \n " ,
70- " Loss: 7.028557 \n " ,
71- " Loss: 5.354487 \n " ,
72- " Loss: 4.0523024 \n " ,
73- " Loss: 3.0819073 \n " ,
74- " Loss: 2.3972578 \n " ,
75- " Loss: 1.9500879 \n " ,
76- " Loss: 1.6904843 \n " ,
77- " Loss: 1.5687597 \n " ,
78- " Loss: 1.5355564 \n " ,
79- " Loss: 1.5495899 \n " ,
80- " Loss: 1.5760833 \n " ,
81- " Loss: 1.5907441 \n " ,
82- " Loss: 1.5783068 \n " ,
83- " Loss: 1.5354254 \n " ,
84- " Loss: 1.4655299 \n " ,
85- " Loss: 1.3767947 \n " ,
86- " Loss: 1.2791892 \n " ,
87- " Loss: 1.1817851 \n "
68+ " Loss: 6.421524 \n " ,
69+ " Loss: 4.9037886 \n " ,
70+ " Loss: 3.821808 \n " ,
71+ " Loss: 3.0536802 \n " ,
72+ " Loss: 2.4780674 \n " ,
73+ " Loss: 2.0242903 \n " ,
74+ " Loss: 1.6653357 \n " ,
75+ " Loss: 1.3944621 \n " ,
76+ " Loss: 1.2099363 \n " ,
77+ " Loss: 1.1070489 \n " ,
78+ " Loss: 1.074543 \n " ,
79+ " Loss: 1.0935023 \n " ,
80+ " Loss: 1.1390615 \n " ,
81+ " Loss: 1.185772 \n " ,
82+ " Loss: 1.2136548 \n " ,
83+ " Loss: 1.2125317 \n " ,
84+ " Loss: 1.1814777 \n " ,
85+ " Loss: 1.1264257 \n " ,
86+ " Loss: 1.0574386 \n " ,
87+ " Loss: 0.9845169 \n "
8888 ]
8989 }
9090 ],
9494 },
9595 {
9696 "cell_type" : " code" ,
97- "execution_count" : 7 ,
97+ "execution_count" : 9 ,
9898 "id" : " fd3e1b4c" ,
9999 "metadata" : {},
100100 "outputs" : [
101101 {
102102 "name" : " stdout" ,
103103 "output_type" : " stream" ,
104104 "text" : [
105- " Evaluation Loss: 1.2425249 \n "
105+ " Evaluation Loss: 0.9266705 \n "
106106 ]
107107 }
108108 ],
You can’t perform that action at this time.
0 commit comments