1- //! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format
2-
1+ //! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
2+ //! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
3+ //! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
4+ //! When one wants to save/restore from or into a session, one calls the save/restore methods
5+ //! # Example
6+ //! let mut scope = Scope::new_root_scope();
7+ //! // add operations to define the graph
8+ //! // ...
9+ //! // let "w" and "b" the name of the variables that we wish to save
10+ //! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11+ //! vec![String::from("w"), String::from("b")].into_boxed_slice(),
12+ //! );
13+ //! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14+ //! // run some training
15+ //! // ...
16+ //! // to save the training
17+ //! checkpoint_maker.save(&session, "data/checkpoint")?;
18+ //! // then we restore in a different session to continue there
19+ //! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20+ //! checkpoint_maker.save(&new_session, "data/checkpoint")?;
321use crate :: option_insert_result:: OptionInsertWithResult ;
422use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor } ;
523
624#[ derive( Debug ) ]
725struct SaveRestoreOps {
8- pub prefix_save : Operation ,
9- pub prefix_restore : Operation ,
10- pub save_op : Operation ,
11- pub restore_op : Operation ,
26+ prefix_save : Operation ,
27+ prefix_restore : Operation ,
28+ save_op : Operation ,
29+ restore_op : Operation ,
1230}
1331
14- /// Checkpointing and restoring struct
32+ /// Checkpointing and restoring support for Tensorflow.
33+ /// This struct is manages a scope, adds lazily the Tensorflow ops
34+ /// to perform the save/restore operations
1535#[ derive( Debug ) ]
1636pub struct CheckpointMaker {
1737 scope : Scope ,
@@ -33,19 +53,20 @@ impl CheckpointMaker {
3353 }
3454 }
3555
56+ fn make_all_variable_ops ( & mut self ) -> Result < Vec < Operation > , Status > {
57+ let graph = self . scope . graph ( ) ;
58+ Ok ( self
59+ . variables
60+ . iter ( )
61+ . map ( |v : & String | -> Result < Operation , Status > {
62+ Ok ( graph. operation_by_name_required ( v. as_str ( ) ) ?. clone ( ) )
63+ } )
64+ . collect :: < Result < Vec < _ > , Status > > ( ) ?)
65+ }
66+
3667 /// Add save and restore ops to the graph
3768 fn build_save_ops ( & mut self ) -> Result < SaveRestoreOps , Status > {
3869 let mut all_variable_ops_opt: Option < Vec < Operation > > = None ;
39- fn make_all_variable_ops ( myself : & mut CheckpointMaker ) -> Result < Vec < Operation > , Status > {
40- let graph = myself. scope . graph ( ) ;
41- Ok ( myself
42- . variables
43- . iter ( )
44- . map ( |v : & String | -> Result < Operation , Status > {
45- Ok ( graph. operation_by_name_required ( v. as_str ( ) ) ?. clone ( ) )
46- } )
47- . collect :: < Result < Vec < _ > , Status > > ( ) ?)
48- }
4970
5071 let existing_save_op = self . scope . graph ( ) . operation_by_name ( "save" ) ?;
5172 let ( prefix_save, save_op) = if let Some ( op) = existing_save_op {
@@ -56,7 +77,7 @@ impl CheckpointMaker {
5677 ( prefix_save_op, op)
5778 } else {
5879 let all_variable_ops =
59- all_variable_ops_opt. get_or_insert_with_result ( || make_all_variable_ops ( self ) ) ?;
80+ all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
6081 let prefix_save = ops:: Placeholder :: new ( )
6182 . dtype ( crate :: DataType :: String )
6283 . build ( & mut self . scope . with_op_name ( "prefix_save" ) ) ?;
@@ -105,39 +126,37 @@ impl CheckpointMaker {
105126 ( the_prefix_restore, op)
106127 } else {
107128 let all_variable_ops =
108- all_variable_ops_opt. get_or_insert_with_result ( || make_all_variable_ops ( self ) ) ?;
129+ all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
109130 let prefix_restore = ops:: Placeholder :: new ( )
110131 . dtype ( crate :: DataType :: String )
111132 . build ( & mut self . scope . with_op_name ( "prefix_restore" ) ) ?;
112- let restore_op = {
113- let all_var_names = self
133+ let all_var_names = self
134+ . variables
135+ . iter ( )
136+ . map ( |v| v. to_string ( ) )
137+ . collect :: < Vec < _ > > ( ) ;
138+ let tensor_names = ops:: constant ( & all_var_names[ ..] , & mut self . scope ) ?;
139+ let shape_and_slices = ops:: constant (
140+ & self
114141 . variables
115142 . iter ( )
116- . map ( |v| v. to_string ( ) )
117- . collect :: < Vec < _ > > ( ) ;
118- let tensor_names = ops:: constant ( & all_var_names[ ..] , & mut self . scope ) ?;
119- let shape_and_slices = ops:: constant (
120- & self
121- . variables
122- . iter ( )
123- . map ( |_| "" . to_string ( ) )
124- . collect :: < Vec < _ > > ( ) [ ..] ,
125- & mut self . scope ,
126- ) ?;
127- let mut g = self . scope . graph_mut ( ) ;
128- let mut nd = g. new_operation ( "RestoreV2" , "restore" ) ?;
129- nd. add_input ( prefix_restore. clone ( ) ) ;
130- nd. add_input ( tensor_names) ;
131- nd. add_input ( shape_and_slices) ;
132- let dtypes = all_variable_ops
133- . iter ( )
134- . map ( |v| v. get_attr_type ( "dtype" ) )
135- . collect :: < Result < Vec < _ > , Status > > ( ) ?;
136- nd. set_attr_type_list ( "dtypes" , & dtypes[ ..] ) ?;
137- nd. finish ( ) ?
138- } ;
139- {
140- let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
143+ . map ( |_| "" . to_string ( ) )
144+ . collect :: < Vec < _ > > ( ) [ ..] ,
145+ & mut self . scope ,
146+ ) ?;
147+ let mut g = self . scope . graph_mut ( ) ;
148+ let mut nd = g. new_operation ( "RestoreV2" , "restore" ) ?;
149+ nd. add_input ( prefix_restore. clone ( ) ) ;
150+ nd. add_input ( tensor_names) ;
151+ nd. add_input ( shape_and_slices) ;
152+ let dtypes = all_variable_ops
153+ . iter ( )
154+ . map ( |v| v. get_attr_type ( "dtype" ) )
155+ . collect :: < Result < Vec < _ > , Status > > ( ) ?;
156+ nd. set_attr_type_list ( "dtypes" , & dtypes[ ..] ) ?;
157+ let restore_op = nd. finish ( ) ?;
158+ drop ( g) ;
159+ let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
141160 for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
142161 let var_op = self
143162 . scope
@@ -157,7 +176,6 @@ impl CheckpointMaker {
157176 no_op = no_op. add_control_input ( op) ;
158177 }
159178 ( prefix_restore, no_op. build ( & mut self . scope ) ?)
160- }
161179 } ;
162180 Ok ( SaveRestoreOps {
163181 prefix_save,
@@ -236,8 +254,8 @@ mod tests {
236254 }
237255
238256 struct MyScopeData {
239- pub scope : Scope ,
240- pub variables : [ Variable ; 3 ] ,
257+ scope : Scope ,
258+ variables : [ Variable ; 3 ] ,
241259 }
242260
243261 // Initialize a scope and place same variables in it
0 commit comments