2525from tensorboard_plugin_fairness_indicators import plugin
2626from tensorboard_plugin_fairness_indicators import summary_v2
2727import six
28- import tensorflow .compat .v1 as tf
29- import tensorflow .compat .v2 as tf2
28+ import tensorflow as tf2
29+ from tensorflow .keras import layers
30+ from tensorflow .keras import models
3031import tensorflow_model_analysis as tfma
31- from tensorflow_model_analysis .eval_saved_model .example_trainers import linear_classifier
3232from werkzeug import test as werkzeug_test
3333from werkzeug import wrappers
3434
3535from tensorboard .backend import application
3636from tensorboard .backend .event_processing import plugin_event_multiplexer as event_multiplexer
3737from tensorboard .plugins import base_plugin
3838
39- tf .enable_eager_execution ()
39+ Sequential = models .Sequential
40+ Dense = layers .Dense
41+
4042tf = tf2
4143
4244
45+ # Define keras based linear classifier.
46+ def create_linear_classifier (model_dir ):
47+
48+ inputs = tf .keras .Input (shape = (2 ,))
49+ outputs = layers .Dense (1 , activation = "sigmoid" )(inputs )
50+ model = tf .keras .Model (inputs = inputs , outputs = outputs )
51+
52+ model .compile (
53+ optimizer = "adam" , loss = "binary_crossentropy" , metrics = ["accuracy" ]
54+ )
55+
56+ tf .saved_model .save (model , model_dir )
57+ return model
58+
59+
4360class PluginTest (tf .test .TestCase ):
4461 """Tests for Fairness Indicators plugin server."""
4562
@@ -74,19 +91,19 @@ def tearDown(self):
7491 super (PluginTest , self ).tearDown ()
7592 shutil .rmtree (self ._log_dir , ignore_errors = True )
7693
77- def _exportEvalSavedModel (self , classifier ):
94+ def _export_eval_saved_model (self ):
95+ """Export the evaluation saved model."""
7896 temp_eval_export_dir = os .path .join (self .get_temp_dir (), "eval_export_dir" )
79- _ , eval_export_dir = classifier (None , temp_eval_export_dir )
80- return eval_export_dir
97+ return create_linear_classifier (temp_eval_export_dir )
8198
82- def _writeTFExamplesToTFRecords (self , examples ):
99+ def _write_tf_examples_to_tfrecords (self , examples ):
83100 data_location = os .path .join (self .get_temp_dir (), "input_data.rio" )
84101 with tf .io .TFRecordWriter (data_location ) as writer :
85102 for example in examples :
86103 writer .write (example .SerializeToString ())
87104 return data_location
88105
89- def _makeExample (self , age , language , label ):
106+ def _make_tf_example (self , age , language , label ):
90107 example = tf .train .Example ()
91108 example .features .feature ["age" ].float_list .value [:] = [age ]
92109 example .features .feature ["language" ].bytes_list .value [:] = [
@@ -112,14 +129,14 @@ def testRoutes(self):
112129 "foo" : "" .encode ("utf-8" )
113130 }},
114131 )
115- def testIsActive (self , get_random_stub ):
132+ def testIsActive (self ):
116133 self .assertTrue (self ._plugin .is_active ())
117134
118135 @mock .patch .object (
119136 event_multiplexer .EventMultiplexer ,
120137 "PluginRunToTagToContent" ,
121138 return_value = {})
122- def testIsInactive (self , get_random_stub ):
139+ def testIsInactive (self ):
123140 self .assertFalse (self ._plugin .is_active ())
124141
125142 def testIndexJsRoute (self ):
@@ -134,16 +151,15 @@ def testVulcanizedTemplateRoute(self):
134151 self .assertEqual (200 , response .status_code )
135152
136153 def testGetEvalResultsRoute (self ):
137- model_location = self ._exportEvalSavedModel (
138- linear_classifier .simple_linear_classifier )
154+ model_location = self ._export_eval_saved_model () # Call the method
139155 examples = [
140- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
141- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
142- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
143- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
144- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
156+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
157+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
158+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
159+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
160+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
145161 ]
146- data_location = self ._writeTFExamplesToTFRecords (examples )
162+ data_location = self ._write_tf_examples_to_tfrecords (examples )
147163 _ = tfma .run_model_analysis (
148164 eval_shared_model = tfma .default_eval_shared_model (
149165 eval_saved_model_path = model_location , example_weight_key = "age" ),
@@ -155,32 +171,36 @@ def testGetEvalResultsRoute(self):
155171 self .assertEqual (200 , response .status_code )
156172
157173 def testGetEvalResultsFromURLRoute (self ):
158- model_location = self ._exportEvalSavedModel (
159- linear_classifier .simple_linear_classifier )
174+ model_location = self ._export_eval_saved_model () # Call the method
160175 examples = [
161- self ._makeExample (age = 3.0 , language = "english" , label = 1.0 ),
162- self ._makeExample (age = 3.0 , language = "chinese" , label = 0.0 ),
163- self ._makeExample (age = 4.0 , language = "english" , label = 1.0 ),
164- self ._makeExample (age = 5.0 , language = "chinese" , label = 1.0 ),
165- self ._makeExample (age = 5.0 , language = "hindi" , label = 1.0 )
176+ self ._make_tf_example (age = 3.0 , language = "english" , label = 1.0 ),
177+ self ._make_tf_example (age = 3.0 , language = "chinese" , label = 0.0 ),
178+ self ._make_tf_example (age = 4.0 , language = "english" , label = 1.0 ),
179+ self ._make_tf_example (age = 5.0 , language = "chinese" , label = 1.0 ),
180+ self ._make_tf_example (age = 5.0 , language = "hindi" , label = 1.0 ),
166181 ]
167- data_location = self ._writeTFExamplesToTFRecords (examples )
182+ data_location = self ._write_tf_examples_to_tfrecords (examples )
168183 _ = tfma .run_model_analysis (
169184 eval_shared_model = tfma .default_eval_shared_model (
170185 eval_saved_model_path = model_location , example_weight_key = "age" ),
171186 data_location = data_location ,
172187 output_path = self ._eval_result_output_dir )
173188
174189 response = self ._server .get (
175- "/data/plugin/fairness_indicators/" +
176- "get_evaluation_result_from_remote_path?evaluation_output_path=" +
177- os .path .join (self ._eval_result_output_dir , tfma .METRICS_KEY ))
190+ "/data/plugin/fairness_indicators/"
191+ + "get_evaluation_result_from_remote_path?evaluation_output_path="
192+ + self ._eval_result_output_dir
193+ )
178194 self .assertEqual (200 , response .status_code )
179195
180- def testGetOutputFileFormat (self ):
181- self .assertEqual ("" , self ._plugin ._get_output_file_format ("abc_path" ))
182- self .assertEqual ("tfrecord" ,
183- self ._plugin ._get_output_file_format ("abc_path.tfrecord" ))
196+ def test_get_output_file_format (self ):
197+ evaluation_output_path = os .path .join (
198+ self ._eval_result_output_dir , "eval_result.tfrecord"
199+ )
200+ self .assertEqual (
201+ self ._plugin ._get_output_file_format (evaluation_output_path ),
202+ "tfrecord" ,
203+ )
184204
185205
186206if __name__ == "__main__" :
0 commit comments