1- //! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
2- //! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
3- //! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
4- //! When one wants to save/restore from or into a session, one calls the save/restore methods
5- //! # Example
6- //! let mut scope = Scope::new_root_scope();
7- //! // add operations to define the graph
8- //! // ...
9- //! // let w and b the variables that we wish to save
10- //! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11- //! vec![w.clone(), b.clone()].into_boxed_slice(),
12- //! );
13- //! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14- //! // run some training
15- //! // ...
16- //! // to save the training
17- //! checkpoint_maker.save(&session, "data/checkpoint")?;
18- //! // then we restore in a different session to continue there
19- //! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20- //! checkpoint_maker.save(&new_session, "data/checkpoint")?;
211use crate :: option_insert_result:: OptionInsertWithResult ;
222use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor , Variable } ;
233
@@ -29,9 +9,29 @@ struct SaveRestoreOps {
299 restore_op : Operation ,
3010}
3111
32- /// Checkpointing and restoring support for Tensorflow.
33- /// This struct is manages a scope, adds lazily the Tensorflow ops
34- /// to perform the save/restore operations
12+ /// This struct supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
13+ /// First, the user creates a [`CheckpointMaker`] attached to a [`Scope`] indicating the list of variables to be saved/restored.
14+ /// The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
15+ /// When one wants to save/restore from or into a session, one calls the save/restore methods
16+ /// # Example
17+ /// ```
18+ /// let mut scope = Scope::new_root_scope();
19+ /// // add operations to define the graph
20+ /// // ...
22+ /// // let w and b be the variables that we wish to save
22+ /// let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
23+ /// vec![w.clone(), b.clone()].into_boxed_slice(),
24+ /// );
25+ /// let session = Session::new(&SessionOptions::new(), &scope.graph())?;
26+ /// // run some training
27+ /// // ...
28+ /// // to save the training
29+ /// checkpoint_maker.save(&session, "data/checkpoint")?;
30+ /// // then we restore in a different session to continue there
31+ /// let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
32+ /// checkpoint_maker.restore(&new_session, "data/checkpoint")?;
33+ /// ```
34+ ///
3535#[ derive( Debug ) ]
3636pub struct CheckpointMaker {
3737 scope : Scope ,
@@ -44,7 +44,7 @@ impl CheckpointMaker {
4444 /// The scope is used to modify the graph to add the save and restore ops.
4545 ///
4646 /// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint")
47- /// in order to create the nodes with scoped names
47+ /// in order to create the nodes with scoped names.
4848 pub fn new ( scope : Scope , variables : Box < [ Variable ] > ) -> CheckpointMaker {
4949 CheckpointMaker {
5050 scope,
@@ -53,18 +53,7 @@ impl CheckpointMaker {
5353 }
5454 }
5555
56- /* fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
57- let graph = self.scope.graph();
58- Ok(self
59- .variables
60- .iter()
61- .map(|v: &String| -> Result<Operation, Status> {
62- Ok(graph.operation_by_name_required(v.as_str())?.clone())
63- })
64- .collect::<Result<Vec<_>, Status>>()?)
65- }*/
66-
67- /// Add save and restore ops to the graph
56+ // Add save and restore ops to the graph.
6857 fn build_save_ops ( & mut self ) -> Result < SaveRestoreOps , Status > {
6958 let mut all_variable_ops_opt: Option < Vec < Operation > > = None ;
7059
@@ -76,18 +65,21 @@ impl CheckpointMaker {
7665 . operation_by_name_required ( "prefix_save" ) ?;
7766 ( prefix_save_op, op)
7867 } else {
79- let all_variable_ops =
80- all_variable_ops_opt. get_or_insert_with (
81- || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
68+ let all_variable_ops = all_variable_ops_opt. get_or_insert_with ( || {
69+ self . variables
70+ . iter ( )
71+ . map ( |v| v. output . operation . clone ( ) )
72+ . collect :: < Vec < _ > > ( )
73+ } ) ;
8274 let prefix_save = ops:: Placeholder :: new ( )
8375 . dtype ( crate :: DataType :: String )
8476 . build ( & mut self . scope . with_op_name ( "prefix_save" ) ) ?;
8577 let tensor_names = ops:: constant (
86- self
87- . variables
78+ self . variables
8879 . iter ( )
8980 . map ( |v| String :: from ( v. name ( ) ) )
90- . collect :: < Vec < _ > > ( ) . as_slice ( ) ,
81+ . collect :: < Vec < _ > > ( )
82+ . as_slice ( ) ,
9183 & mut self . scope ,
9284 ) ?;
9385 let shape_and_slices = ops:: constant (
@@ -126,9 +118,12 @@ impl CheckpointMaker {
126118 . operation_by_name_required ( "prefix_restore" ) ?;
127119 ( the_prefix_restore, op)
128120 } else {
129- let all_variable_ops =
130- all_variable_ops_opt. get_or_insert_with (
131- || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
121+ let all_variable_ops = all_variable_ops_opt. get_or_insert_with ( || {
122+ self . variables
123+ . iter ( )
124+ . map ( |v| v. output . operation . clone ( ) )
125+ . collect :: < Vec < _ > > ( )
126+ } ) ;
132127 let prefix_restore = ops:: Placeholder :: new ( )
133128 . dtype ( crate :: DataType :: String )
134129 . build ( & mut self . scope . with_op_name ( "prefix_restore" ) ) ?;
@@ -159,22 +154,22 @@ impl CheckpointMaker {
159154 let restore_op = nd. finish ( ) ?;
160155 drop ( g) ;
161156 let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
162- for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
163- let var_op = var. output . operation . clone ( ) ;
164- restore_var_ops. push ( ops:: assign (
165- var_op,
166- crate :: Output {
167- operation : restore_op. clone ( ) ,
168- index : i as i32 ,
169- } ,
170- & mut self . scope . new_sub_scope ( format ! ( "restore{}" , i) . as_str ( ) ) ,
171- ) ?) ;
172- }
173- let mut no_op = ops:: NoOp :: new ( ) ;
174- for op in restore_var_ops {
175- no_op = no_op. add_control_input ( op) ;
176- }
177- ( prefix_restore, no_op. build ( & mut self . scope ) ?)
157+ for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
158+ let var_op = var. output . operation . clone ( ) ;
159+ restore_var_ops. push ( ops:: assign (
160+ var_op,
161+ crate :: Output {
162+ operation : restore_op. clone ( ) ,
163+ index : i as i32 ,
164+ } ,
165+ & mut self . scope . new_sub_scope ( format ! ( "restore{}" , i) . as_str ( ) ) ,
166+ ) ?) ;
167+ }
168+ let mut no_op = ops:: NoOp :: new ( ) ;
169+ for op in restore_var_ops {
170+ no_op = no_op. add_control_input ( op) ;
171+ }
172+ ( prefix_restore, no_op. build ( & mut self . scope ) ?)
178173 } ;
179174 Ok ( SaveRestoreOps {
180175 prefix_save,
@@ -194,7 +189,7 @@ impl CheckpointMaker {
194189 Ok ( save_r_op)
195190 }
196191
197- /// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base
192+ /// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base.
198193 pub fn save ( & mut self , session : & Session , backup_filename_base : & str ) -> Result < ( ) , Status > {
199194 let save_restore_ops = self . get_save_operation ( ) ?;
200195 let prefix_arg = Tensor :: from ( backup_filename_base. to_string ( ) ) ;
@@ -206,7 +201,7 @@ impl CheckpointMaker {
206201 }
207202
208203 /// Restore into the session the variables listed in this CheckpointMaker from the checkpoint
209- /// in path_base
204+ /// in path_base.
210205 pub fn restore ( & mut self , session : & Session , path_base : & str ) -> Result < ( ) , Status > {
211206 let save_restore_ops = self . get_save_operation ( ) ?;
212207 let prefix_arg = Tensor :: from ( path_base. to_string ( ) ) ;
@@ -372,7 +367,8 @@ mod tests {
372367 variables : second_variables,
373368 } = create_scope ( ) ?;
374369 let second_session = Session :: new ( & SessionOptions :: new ( ) , & second_scope. graph ( ) ) ?;
375- let mut second_checkpoint = CheckpointMaker :: new ( second_scope, Box :: new ( second_variables. clone ( ) ) ) ;
370+ let mut second_checkpoint =
371+ CheckpointMaker :: new ( second_scope, Box :: new ( second_variables. clone ( ) ) ) ;
376372 second_checkpoint. restore ( & second_session, checkpoint_path_str. as_str ( ) ) ?;
377373 check_variables ( & second_session, & second_variables, & new_values) ?;
378374 Ok ( ( ) )
0 commit comments