1+ use pyo3:: prelude:: * ;
2+ use pyo3:: exceptions:: PyRuntimeError ;
3+ use pyo3:: types:: PyModule ;
4+ use candle_core:: { DType , Device , Tensor } ;
5+ use candle_nn:: {
6+ linear, Linear , Module , VarBuilder , VarMap , Optimizer ,
7+ } ;
8+
/// Python-visible wrapper around a small feed-forward regression network.
///
/// Exposed to Python as the `MLP` class; construction, training and
/// evaluation are provided by the `#[pymethods]` impl below.
#[pyclass]
pub struct MLP {
    // The underlying two-layer network (fc1 -> ReLU -> fc2).
    model: SimpleNN,
    // AdamW optimizer holding references to the model's trainable vars.
    optimizer: candle_nn::AdamW,
    // Device all tensors are created on (CPU only, set in `new`).
    device: Device,
}
15+
16+ #[ pymethods]
17+ impl MLP {
18+ #[ new]
19+ fn new ( input_dim : usize , lr : f64 ) -> PyResult < Self > {
20+ let device = Device :: Cpu ;
21+ let varmap = VarMap :: new ( ) ;
22+ let vb = VarBuilder :: from_varmap ( & varmap, DType :: F32 , & device) ;
23+ let model = SimpleNN :: new ( input_dim, vb)
24+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
25+ let optimizer = candle_nn:: AdamW :: new_lr ( varmap. all_vars ( ) , lr)
26+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
27+ Ok ( Self { model, optimizer, device } )
28+ }
29+
30+ fn train ( & mut self , x : Vec < Vec < f32 > > , y : Vec < f32 > , epochs : usize ) -> PyResult < ( ) > {
31+ let n = x. len ( ) ;
32+ let d = x[ 0 ] . len ( ) ;
33+ let x_flat = x. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
34+ let x_tensor = Tensor :: from_vec ( x_flat, ( n, d) , & self . device )
35+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
36+ let y_tensor = Tensor :: from_vec ( y, ( n, 1 ) , & self . device )
37+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
38+ for _ in 0 ..epochs {
39+ let output = self . model . forward ( & x_tensor)
40+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
41+ let loss = candle_nn:: loss:: mse ( & output, & y_tensor)
42+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
43+ self . optimizer . backward_step ( & loss)
44+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
45+ }
46+ Ok ( ( ) )
47+ }
48+
49+ fn evaluate ( & self , x : Vec < Vec < f32 > > , y : Vec < f32 > ) -> PyResult < f32 > {
50+ let n = x. len ( ) ;
51+ let d = x[ 0 ] . len ( ) ;
52+ let x_flat = x. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
53+ let x_tensor = Tensor :: from_vec ( x_flat, ( n, d) , & self . device )
54+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
55+ let y_tensor = Tensor :: from_vec ( y, ( n, 1 ) , & self . device )
56+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
57+ let output = self . model . forward ( & x_tensor)
58+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
59+ let loss = candle_nn:: loss:: mse ( & output, & y_tensor)
60+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
61+ Ok ( loss. to_scalar :: < f32 > ( )
62+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?)
63+ }
64+ }
65+
/// A minimal two-layer feed-forward network: fc1 -> ReLU -> fc2.
///
/// Produces a single output per input row (see the `Module` impl below).
#[derive(Debug)]
struct SimpleNN {
    // First linear layer: input features -> hidden units.
    fc1: Linear,
    // Second linear layer: hidden units -> 1 output.
    fc2: Linear,
}
71+
72+ impl SimpleNN {
73+ fn new ( in_dim : usize , vb : VarBuilder ) -> candle_core:: Result < Self > {
74+ let fc1 = linear ( in_dim, 64 , vb. pp ( "fc1" ) ) ?;
75+ let fc2 = linear ( 64 , 1 , vb. pp ( "fc2" ) ) ?;
76+ Ok ( Self { fc1, fc2 } )
77+ }
78+ }
79+
80+ impl Module for SimpleNN {
81+ fn forward ( & self , xs : & Tensor ) -> candle_core:: Result < Tensor > {
82+ let x = self . fc1 . forward ( xs) ?;
83+ let x = x. relu ( ) ?;
84+ let x = self . fc2 . forward ( & x) ?;
85+ Ok ( x)
86+ }
87+ }