@@ -8,128 +8,10 @@ use candle_core::{DType, Device, Tensor};
 use candle_nn::{
     linear, Linear, Module, VarBuilder, VarMap, Optimizer,
 };
-mod mytest;
-use mytest::Number;
 mod mlp;
 use mlp::MLP;
 #[pymodule]
 fn rust_mlp(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
-    m.add_function(wrap_pyfunction!(let_me_try, m)?)?;
-    m.add_class::<Number>()?;
     m.add_class::<MLP>()?;
     Ok(())
-}
-
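-/// Stock pyo3 template example: adds two integers and returns the sum as a string.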
-#[pyfunction]
-fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
-    Ok((a + b).to_string())
-}
-
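-/// Python entry point: runs the training pipeline, mapping any error to a PyRuntimeError.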
-#[pyfunction]
-fn let_me_try() -> PyResult<()> {
-    run_training()
-        .map_err(|e| PyRuntimeError::new_err(e.to_string()))
-}
-
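-/// Loads the pre-split California-housing data from JSON, builds tensors,
-/// trains `SimpleNN` for 200 epochs with AdamW (lr = 1e-2), then evaluates on the test split.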
-fn run_training() -> anyhow::Result<()> {
-    let device = Device::Cpu;
-
-    let file = std::fs::File::open("fetch_california_housing.json")?;
-    let reader = std::io::BufReader::new(file);
-    let data: Data = serde_json::from_reader(reader)?;
-
-    let train_d1 = data.X_train.len();
-    let train_d2 = data.X_train[0].len();
-    let test_d1 = data.X_test.len();
-    let test_d2 = data.X_test[0].len();
-
-    let x_train_vec = data.X_train.into_iter().flatten().collect::<Vec<_>>();
-    let x_test_vec = data.X_test.into_iter().flatten().collect::<Vec<_>>();
-
-    let y_train_vec = data.y_train;
-    let y_test_vec = data.y_test;
-
-    let x_train = Tensor::from_vec(x_train_vec, (train_d1, train_d2), &device)?;
-    let y_train = Tensor::from_vec(y_train_vec, (train_d1, 1), &device)?;
-    let x_test = Tensor::from_vec(x_test_vec, (test_d1, test_d2), &device)?;
-    let y_test = Tensor::from_vec(y_test_vec, (test_d1, 1), &device)?;
-
-    let varmap = VarMap::new();
-    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
-
-    let model = SimpleNN::new(train_d2, vb)?;
-
-    let mut optimizer = candle_nn::AdamW::new_lr(varmap.all_vars(), 1e-2)?;
-
-    train_model(&model, &x_train, &y_train, &mut optimizer, 200)?;
-    evaluate_model(&model, &x_test, &y_test)?;
-
-    Ok(())
-}
-
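-/// Two-layer perceptron for regression: `in_dim -> 64 -> 1`.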
-#[derive(Debug)]
-struct SimpleNN {
-    fc1: Linear,
-    fc2: Linear,
-}
-
-impl SimpleNN {
-    fn new(in_dim: usize, vb: VarBuilder) -> candle_core::Result<Self> {
-        let fc1 = linear(in_dim, 64, vb.pp("fc1"))?;
-        let fc2 = linear(64, 1, vb.pp("fc2"))?;
-        Ok(Self { fc1, fc2 })
-    }
-}
-
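-// Forward pass: fc1 -> ReLU -> fc2, producing a single regression output.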
-impl Module for SimpleNN {
-    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
-        let x = self.fc1.forward(xs)?;
-        let x = x.relu()?;
-        let x = self.fc2.forward(&x)?;
-        Ok(x)
-    }
-}
-
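-/// Full-batch training loop: one MSE loss and AdamW `backward_step` per epoch,
-/// logging the train loss every 10 epochs.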
-fn train_model(
-    model: &SimpleNN,
-    x_train: &Tensor,
-    y_train: &Tensor,
-    optimizer: &mut candle_nn::AdamW,
-    epochs: usize,
-) -> anyhow::Result<()> {
-    for epoch in 0..epochs {
-        let output = model.forward(x_train)?;
-        let loss = candle_nn::loss::mse(&output, y_train)?;
-        optimizer.backward_step(&loss)?;
-
-        if (epoch + 1) % 10 == 0 {
-            println!(
-                "Epoch {} | Train Loss: {:.6}",
-                epoch + 1,
-                loss.to_scalar::<f32>()?
-            );
-        }
-    }
-    Ok(())
-}
-
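-/// Reports the MSE of the trained model on the held-out test set.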
-fn evaluate_model(
-    model: &SimpleNN,
-    x_test: &Tensor,
-    y_test: &Tensor,
-) -> anyhow::Result<()> {
-    let output = model.forward(x_test)?;
-    let loss = candle_nn::loss::mse(&output, y_test)?;
-    println!("Test Loss: {:.6}", loss.to_scalar::<f32>()?);
-    Ok(())
-}
-
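-/// Mirrors the JSON layout: nested feature rows for `X_*`, flat targets for `y_*`.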
-#[derive(Debug, serde::Deserialize)]
-struct Data {
-    X_train: Vec<Vec<f32>>,
-    X_test: Vec<Vec<f32>>,
-    y_train: Vec<f32>,
-    y_test: Vec<f32>,
-}
+}
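
After this change the `#[pymodule]` body registers only the `MLP` class, so the whole Python-facing surface now lives in `mlp.rs`, which is not part of this diff. Below is a minimal sketch of the shape such a module takes; the field, constructor, and method are illustrative assumptions, not the commit's actual `MLP`.

```rust
use pyo3::prelude::*;

// Hypothetical sketch only: lib.rs calls `m.add_class::<MLP>()?`, so
// mlp.rs must expose a #[pyclass] type, but its real fields and methods
// are not visible in this commit.
#[pyclass]
pub struct MLP {
    hidden_dim: usize, // assumed field, for illustration
}

#[pymethods]
impl MLP {
    #[new]
    fn new(hidden_dim: usize) -> Self {
        // Constructed from Python as `rust_mlp.MLP(64)`.
        MLP { hidden_dim }
    }

    /// Illustrative getter, callable from Python as `mlp.hidden_dim()`.
    fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }
}
```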