66//! let mut scope = Scope::new_root_scope();
77//! // add operations to define the graph
88//! // ...
9- //! // let "w" and "b" the name of the variables that we wish to save
9+ //! // let w and b the variables that we wish to save
1010//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11- //! vec![String::from("w" ), String::from("b" )].into_boxed_slice(),
11+ //! vec![w.clone( ), b.clone( )].into_boxed_slice(),
1212//! );
1313//! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
1414//! // run some training
1919//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
2020//! checkpoint_maker.save(&new_session, "data/checkpoint")?;
2121use crate :: option_insert_result:: OptionInsertWithResult ;
22- use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor } ;
22+ use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor , Variable } ;
2323
2424#[ derive( Debug ) ]
2525struct SaveRestoreOps {
@@ -35,25 +35,25 @@ struct SaveRestoreOps {
3535#[ derive( Debug ) ]
3636pub struct CheckpointMaker {
3737 scope : Scope ,
38- variables : Box < [ String ] > ,
38+ variables : Box < [ Variable ] > ,
3939 save_restore_ops : Option < SaveRestoreOps > ,
4040}
4141
4242impl CheckpointMaker {
4343 /// Creates a new CheckpointMaker for a Scope, with a list of variables to save/restore.
4444 /// The scope is used to modify the graph to add the save and restore ops.
4545 ///
46- /// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("")
47- /// as Scope does not support the Clone trait at present
48- pub fn new ( scope : Scope , variables : Box < [ String ] > ) -> CheckpointMaker {
46+ /// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint ")
47+ /// in order to create the nodes with scoped names
48+ pub fn new ( scope : Scope , variables : Box < [ Variable ] > ) -> CheckpointMaker {
4949 CheckpointMaker {
5050 scope,
5151 variables,
5252 save_restore_ops : None ,
5353 }
5454 }
5555
56- fn make_all_variable_ops ( & mut self ) -> Result < Vec < Operation > , Status > {
56+ /* fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
5757 let graph = self.scope.graph();
5858 Ok(self
5959 .variables
@@ -62,7 +62,7 @@ impl CheckpointMaker {
6262 Ok(graph.operation_by_name_required(v.as_str())?.clone())
6363 })
6464 .collect::<Result<Vec<_>, Status>>()?)
65- }
65+ }*/
6666
6767 /// Add save and restore ops to the graph
6868 fn build_save_ops ( & mut self ) -> Result < SaveRestoreOps , Status > {
@@ -77,16 +77,17 @@ impl CheckpointMaker {
7777 ( prefix_save_op, op)
7878 } else {
7979 let all_variable_ops =
80- all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
80+ all_variable_ops_opt. get_or_insert_with (
81+ || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
8182 let prefix_save = ops:: Placeholder :: new ( )
8283 . dtype ( crate :: DataType :: String )
8384 . build ( & mut self . scope . with_op_name ( "prefix_save" ) ) ?;
8485 let tensor_names = ops:: constant (
85- & self
86+ self
8687 . variables
8788 . iter ( )
88- . map ( |v| ( * v ) . to_string ( ) )
89- . collect :: < Vec < _ > > ( ) [ .. ] ,
89+ . map ( |v| String :: from ( v . name ( ) ) )
90+ . collect :: < Vec < _ > > ( ) . as_slice ( ) ,
9091 & mut self . scope ,
9192 ) ?;
9293 let shape_and_slices = ops:: constant (
@@ -126,14 +127,15 @@ impl CheckpointMaker {
126127 ( the_prefix_restore, op)
127128 } else {
128129 let all_variable_ops =
129- all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
130+ all_variable_ops_opt. get_or_insert_with (
131+ || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
130132 let prefix_restore = ops:: Placeholder :: new ( )
131133 . dtype ( crate :: DataType :: String )
132134 . build ( & mut self . scope . with_op_name ( "prefix_restore" ) ) ?;
133135 let all_var_names = self
134136 . variables
135137 . iter ( )
136- . map ( |v| v. to_string ( ) )
138+ . map ( |v| v. name . clone ( ) )
137139 . collect :: < Vec < _ > > ( ) ;
138140 let tensor_names = ops:: constant ( & all_var_names[ ..] , & mut self . scope ) ?;
139141 let shape_and_slices = ops:: constant (
@@ -158,10 +160,7 @@ impl CheckpointMaker {
158160 drop ( g) ;
159161 let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
160162 for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
161- let var_op = self
162- . scope
163- . graph ( )
164- . operation_by_name_required ( var. as_str ( ) ) ?;
163+ let var_op = var. output . operation . clone ( ) ;
165164 restore_var_ops. push ( ops:: assign (
166165 var_op,
167166 crate :: Output {
@@ -357,16 +356,9 @@ mod tests {
357356 & [ 11.0 , 12.0 , 13.6 , 17.1 , 18.4 , 19.5 ] ,
358357 ] ;
359358 assign_variables ( & first_session, & first_scope_data, & assign_data, & new_values) ?;
360- let variable_names = first_scope_data
361- . variables
362- . as_ref ( )
363- . iter ( )
364- . map ( |v| String :: from ( v. name ( ) ) )
365- . collect :: < Vec < _ > > ( )
366- . into_boxed_slice ( ) ;
367359 let mut checkpoint = CheckpointMaker :: new (
368- first_scope_data. scope . new_sub_scope ( "" ) ,
369- variable_names . clone ( ) ,
360+ first_scope_data. scope . new_sub_scope ( "checkpoint " ) ,
361+ Box :: from ( first_scope_data . variables . clone ( ) ) ,
370362 ) ;
371363 let temp_dir = tempdir:: TempDir :: new ( "test-tensorflow" ) ?;
372364 let checkpoint_path = temp_dir. path ( ) . join ( "checkpoint-vars" ) ;
@@ -380,7 +372,7 @@ mod tests {
380372 variables : second_variables,
381373 } = create_scope ( ) ?;
382374 let second_session = Session :: new ( & SessionOptions :: new ( ) , & second_scope. graph ( ) ) ?;
383- let mut second_checkpoint = CheckpointMaker :: new ( second_scope, variable_names ) ;
375+ let mut second_checkpoint = CheckpointMaker :: new ( second_scope, Box :: new ( second_variables . clone ( ) ) ) ;
384376 second_checkpoint. restore ( & second_session, checkpoint_path_str. as_str ( ) ) ?;
385377 check_variables ( & second_session, & second_variables, & new_values) ?;
386378 Ok ( ( ) )
0 commit comments