1717from __future__ import division
1818from __future__ import print_function
1919
20+ import os
21+
2022from absl .testing import parameterized
2123import neural_structured_learning .configs as configs
2224from neural_structured_learning .keras import graph_regularization
23-
2425import numpy as np
2526import tensorflow as tf
2627
@@ -88,10 +89,12 @@ def build_linear_functional_model(input_shape, weights, num_output=1):
8889def build_linear_subclass_model (input_shape , weights , num_output = 1 ):
8990 del input_shape
9091
91- class LinearModel (tf .keras .Model ):
92+ class CustomLinearModel (tf .keras .Model ):
9293
93- def __init__ (self ):
94- super (LinearModel , self ).__init__ ()
94+ def __init__ (self , weights , num_output , name = None ):
95+ super (CustomLinearModel , self ).__init__ (name = name )
96+ self .init_weights = weights
97+ self .num_output = num_output
9598 self .dense = tf .keras .layers .Dense (
9699 num_output ,
97100 use_bias = False ,
@@ -101,7 +104,14 @@ def __init__(self):
101104 def call (self , inputs ):
102105 return self .dense (inputs [FEATURE_NAME ])
103106
104- return LinearModel ()
107+ def get_config (self ):
108+ return {
109+ 'name' : self .name ,
110+ 'weights' : self .init_weights ,
111+ 'num_output' : self .num_output
112+ }
113+
114+ return CustomLinearModel (weights , num_output )
105115
106116
107117def make_dataset (example_proto , input_shape , training , max_neighbors ):
@@ -481,6 +491,47 @@ def test_graph_reg_model_evaluate(self, model_fn):
481491 weight = w ,
482492 distributed_strategy = None )
483493
494+ def _test_graph_reg_model_save (self , model_fn ):
495+ """Template for testing model saving and loading."""
496+ w = np .array ([[4.0 ], [- 3.0 ]])
497+ base_model = model_fn ((2 ,), w )
498+ graph_reg_config = configs .make_graph_reg_config (
499+ max_neighbors = 1 , multiplier = 1 )
500+ graph_reg_model = graph_regularization .GraphRegularization (
501+ base_model , graph_reg_config )
502+ graph_reg_model .compile (
503+ optimizer = tf .keras .optimizers .SGD (LEARNING_RATE ),
504+ loss = 'MSE' ,
505+ metrics = ['accuracy' ])
506+
507+ # Run the model before saving it. This is necessary for subclassed models.
508+ inputs = {FEATURE_NAME : tf .constant ([[5.0 , 3.0 ]])}
509+ graph_reg_model .predict (inputs , steps = 1 , batch_size = 1 )
510+ saved_model_dir = os .path .join (self .get_temp_dir (), 'saved_model' )
511+ graph_reg_model .save (saved_model_dir )
512+
513+ loaded_model = tf .keras .models .load_model (saved_model_dir )
514+ self .assertEqual (
515+ len (loaded_model .trainable_weights ),
516+ len (graph_reg_model .trainable_weights ))
517+ for w_loaded , w_graph_reg in zip (loaded_model .trainable_weights ,
518+ graph_reg_model .trainable_weights ):
519+ self .assertAllClose (
520+ tf .keras .backend .get_value (w_loaded ),
521+ tf .keras .backend .get_value (w_graph_reg ))
522+
523+ @parameterized .named_parameters ([
524+ ('_sequential' , build_linear_sequential_model ),
525+ ('_functional' , build_linear_functional_model ),
526+ ])
527+ def test_graph_reg_model_save (self , model_fn ):
528+ self ._test_graph_reg_model_save (model_fn )
529+
530+ # Saving subclassed models are only supported in TF v2.
531+ @test_util .run_v2_only
532+ def test_graph_reg_model_save_subclass (self ):
533+ self ._test_graph_reg_model_save (build_linear_subclass_model )
534+
484535
485536if __name__ == '__main__' :
486537 tf .test .main ()
0 commit comments