2525from tensorboard_plugin_fairness_indicators import plugin
2626from tensorboard_plugin_fairness_indicators import summary_v2
2727import six
28- import tensorflow .compat .v1 as tf
29- import tensorflow .compat .v2 as tf2
28+ import tensorflow as tf2
29+ from tensorflow .keras import layers
30+ from tensorflow .keras import models
3031import tensorflow_model_analysis as tfma
31- from tensorflow_model_analysis .eval_saved_model .example_trainers import linear_classifier
3232from werkzeug import test as werkzeug_test
3333from werkzeug import wrappers
3434
3535from tensorboard .backend import application
3636from tensorboard .backend .event_processing import plugin_event_multiplexer as event_multiplexer
3737from tensorboard .plugins import base_plugin
3838
39- tf .enable_eager_execution ()
39+ Sequential = models .Sequential
40+ Dense = layers .Dense
41+
4042tf = tf2
4143
4244
45+ # Define keras based linear classifier.
46+ def create_linear_classifier (model_dir ):
47+
48+ model = Sequential ([Dense (1 , activation = "sigmoid" , input_shape = (2 ,))])
49+ model .compile (
50+ optimizer = "adam" , loss = "binary_crossentropy" , metrics = ["accuracy" ]
51+ )
52+ # Convert the Sequential model to a tf.Module before saving
53+ model = tf .keras .models .Model (inputs = model .inputs , outputs = model .outputs )
54+ tf .saved_model .save (model , model_dir )
55+ return model
56+
57+
4358class PluginTest (tf .test .TestCase ):
4459 """Tests for Fairness Indicators plugin server."""
4560
@@ -74,19 +89,19 @@ def tearDown(self):
7489 super (PluginTest , self ).tearDown ()
7590 shutil .rmtree (self ._log_dir , ignore_errors = True )
7691
77- def _exportEvalSavedModel (self , classifier ):
92+ def _export_eval_saved_model (self ):
93+ """Export the evaluation saved model."""
7894 temp_eval_export_dir = os .path .join (self .get_temp_dir (), "eval_export_dir" )
79- _ , eval_export_dir = classifier (None , temp_eval_export_dir )
80- return eval_export_dir
95+ return create_linear_classifier (temp_eval_export_dir )
8196
82- def _writeTFExamplesToTFRecords (self , examples ):
97+ def _write_tf_examples_to_tfrecords (self , examples ):
8398 data_location = os .path .join (self .get_temp_dir (), "input_data.rio" )
8499 with tf .io .TFRecordWriter (data_location ) as writer :
85100 for example in examples :
86101 writer .write (example .SerializeToString ())
87102 return data_location
88103
89- def _makeExample (self , age , language , label ):
104+ def _make_tf_example (self , age , language , label ):
90105 example = tf .train .Example ()
91106 example .features .feature ["age" ].float_list .value [:] = [age ]
92107 example .features .feature ["language" ].bytes_list .value [:] = [
@@ -112,14 +127,14 @@ def testRoutes(self):
112127 "foo" : "" .encode ("utf-8" )
113128 }},
114129 )
115- def testIsActive (self , get_random_stub ):
130+ def testIsActive (self ):
116131 self .assertTrue (self ._plugin .is_active ())
117132
118133 @mock .patch .object (
119134 event_multiplexer .EventMultiplexer ,
120135 "PluginRunToTagToContent" ,
121136 return_value = {})
122- def testIsInactive (self , get_random_stub ):
137+ def testIsInactive (self ):
123138 self .assertFalse (self ._plugin .is_active ())
124139
125140 def testIndexJsRoute (self ):
@@ -134,16 +149,15 @@ def testVulcanizedTemplateRoute(self):
134149 self .assertEqual (200 , response .status_code )
135150
136151 def testGetEvalResultsRoute (self ):
137- model_location = self ._exportEvalSavedModel (
138- linear_classifier .simple_linear_classifier )
152+ model_location = self ._export_eval_saved_model () # Call the method
139153 examples = [
140- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
141- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
142- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
143- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
144- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
154+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
155+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
156+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
157+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
158+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
145159 ]
146- data_location = self ._writeTFExamplesToTFRecords (examples )
160+ data_location = self ._write_tf_examples_to_tfrecords (examples )
147161 _ = tfma .run_model_analysis (
148162 eval_shared_model = tfma .default_eval_shared_model (
149163 eval_saved_model_path = model_location , example_weight_key = "age" ),
@@ -155,32 +169,36 @@ def testGetEvalResultsRoute(self):
155169 self .assertEqual (200 , response .status_code )
156170
157171 def testGetEvalResultsFromURLRoute (self ):
158- model_location = self ._exportEvalSavedModel (
159- linear_classifier .simple_linear_classifier )
172+ model_location = self ._export_eval_saved_model () # Call the method
160173 examples = [
161- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
162- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
163- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
164- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
165- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
174+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
175+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
176+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
177+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
178+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
166179 ]
167- data_location = self ._writeTFExamplesToTFRecords (examples )
180+ data_location = self ._write_tf_examples_to_tfrecords (examples )
168181 _ = tfma .run_model_analysis (
169182 eval_shared_model = tfma .default_eval_shared_model (
170183 eval_saved_model_path = model_location , example_weight_key = "age" ),
171184 data_location = data_location ,
172185 output_path = self ._eval_result_output_dir )
173186
174187 response = self ._server .get (
175- "/data/plugin/fairness_indicators/" +
176- "get_evaluation_result_from_remote_path?evaluation_output_path=" +
177- os .path .join (self ._eval_result_output_dir , tfma .METRICS_KEY ))
188+ "/data/plugin/fairness_indicators/"
189+ + "get_evaluation_result_from_remote_path?evaluation_output_path="
190+ + self ._eval_result_output_dir
191+ )
178192 self .assertEqual (200 , response .status_code )
179193
180- def testGetOutputFileFormat (self ):
181- self .assertEqual ("" , self ._plugin ._get_output_file_format ("abc_path" ))
182- self .assertEqual ("tfrecord" ,
183- self ._plugin ._get_output_file_format ("abc_path.tfrecord" ))
194+ def test_get_output_file_format (self ):
195+ evaluation_output_path = os .path .join (
196+ self ._eval_result_output_dir , "eval_result.tfrecord"
197+ )
198+ self .assertEqual (
199+ self ._plugin ._get_output_file_format (evaluation_output_path ),
200+ "tfrecord" ,
201+ )
184202
185203
186204if __name__ == "__main__" :
0 commit comments