2121
2222from typing import Any
2323
24- from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
25- from tensorflow import keras
2624import tensorflow .compat .v1 as tf
2725import tensorflow_model_analysis as tfma
26+ from tensorflow import keras
2827
28+ from fairness_indicators import fairness_indicators_metrics # noqa: F401
2929
30- TEXT_FEATURE = ' comment_text'
31- LABEL = ' toxicity'
32- SLICE = ' slice'
30+ TEXT_FEATURE = " comment_text"
31+ LABEL = " toxicity"
32+ SLICE = " slice"
3333FEATURE_MAP = {
3434 LABEL : tf .io .FixedLenFeature ([], tf .float32 ),
3535 TEXT_FEATURE : tf .io .FixedLenFeature ([], tf .string ),
3838
3939
4040class ExampleParser (keras .layers .Layer ):
41- """A Keras layer that parses the tf.Example."""
41+ """A Keras layer that parses the tf.Example."""
42+
43+ def __init__ (self , input_feature_key ):
44+ self ._input_feature_key = input_feature_key
45+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
46+ super ().__init__ ()
4247
43- def __init__ (self , input_feature_key ):
44- self ._input_feature_key = input_feature_key
45- self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
46- super ().__init__ ()
48+ def compute_output_shape (self , input_shape : Any ):
49+ return [1 , 1 ]
4750
48- def compute_output_shape (self , input_shape : Any ):
49- return [1 , 1 ]
51+ def call (self , serialized_examples ):
52+ def get_feature (serialized_example ):
53+ parsed_example = tf .io .parse_single_example (
54+ serialized_example , features = FEATURE_MAP
55+ )
56+ return parsed_example [self ._input_feature_key ]
5057
51- def call (self , serialized_examples ):
52- def get_feature (serialized_example ):
53- parsed_example = tf .io .parse_single_example (
54- serialized_example , features = FEATURE_MAP
55- )
56- return parsed_example [self ._input_feature_key ]
57- serialized_examples = tf .cast (serialized_examples , tf .string )
58- return tf .map_fn (get_feature , serialized_examples )
58+ serialized_examples = tf .cast (serialized_examples , tf .string )
59+ return tf .map_fn (get_feature , serialized_examples )
5960
6061
6162class Reshaper (keras .layers .Layer ):
62- """A Keras layer that reshapes the input."""
63+ """A Keras layer that reshapes the input."""
6364
64- def call (self , inputs ):
65- return tf .reshape (inputs , (1 , 32 ))
65+ def call (self , inputs ):
66+ return tf .reshape (inputs , (1 , 32 ))
6667
6768
6869class Caster (keras .layers .Layer ):
69- """A Keras layer that reshapes the input."""
70+ """A Keras layer that reshapes the input."""
7071
71- def call (self , inputs ):
72- return tf .cast (inputs , tf .float32 )
72+ def call (self , inputs ):
73+ return tf .cast (inputs , tf .float32 )
7374
7475
7576def get_example_model (input_feature_key : str ):
76- """Returns a Keras model for testing."""
77- parser = ExampleParser (input_feature_key )
78- text_vectorization = keras .layers .TextVectorization (
79- max_tokens = 32 ,
80- output_mode = ' int' ,
81- output_sequence_length = 32 ,
82- )
83- text_vectorization .adapt (
84- [ ' nontoxic' , ' toxic comment' , ' test comment' , ' abc' , ' abcdef' , ' random' ]
85- )
86- dense1 = keras .layers .Dense (
87- 32 ,
88- activation = None ,
89- use_bias = True ,
90- kernel_initializer = ' glorot_uniform' ,
91- bias_initializer = ' zeros' ,
92- )
93- dense2 = keras .layers .Dense (
94- 1 ,
95- activation = None ,
96- use_bias = False ,
97- kernel_initializer = ' glorot_uniform' ,
98- bias_initializer = ' zeros' ,
99- )
100-
101- inputs = tf .keras .Input (shape = (), dtype = tf .string )
102- parsed_example = parser (inputs )
103- text_vector = text_vectorization (parsed_example )
104- text_vector = Reshaper ()(text_vector )
105- text_vector = Caster ()(text_vector )
106- output1 = dense1 (text_vector )
107- output2 = dense2 (output1 )
108- return tf .keras .Model (inputs = inputs , outputs = output2 )
77+ """Returns a Keras model for testing."""
78+ parser = ExampleParser (input_feature_key )
79+ text_vectorization = keras .layers .TextVectorization (
80+ max_tokens = 32 ,
81+ output_mode = " int" ,
82+ output_sequence_length = 32 ,
83+ )
84+ text_vectorization .adapt (
85+ [ " nontoxic" , " toxic comment" , " test comment" , " abc" , " abcdef" , " random" ]
86+ )
87+ dense1 = keras .layers .Dense (
88+ 32 ,
89+ activation = None ,
90+ use_bias = True ,
91+ kernel_initializer = " glorot_uniform" ,
92+ bias_initializer = " zeros" ,
93+ )
94+ dense2 = keras .layers .Dense (
95+ 1 ,
96+ activation = None ,
97+ use_bias = False ,
98+ kernel_initializer = " glorot_uniform" ,
99+ bias_initializer = " zeros" ,
100+ )
101+
102+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
103+ parsed_example = parser (inputs )
104+ text_vector = text_vectorization (parsed_example )
105+ text_vector = Reshaper ()(text_vector )
106+ text_vector = Caster ()(text_vector )
107+ output1 = dense1 (text_vector )
108+ output2 = dense2 (output1 )
109+ return tf .keras .Model (inputs = inputs , outputs = output2 )
109110
110111
111112def evaluate_model (
@@ -114,23 +115,23 @@ def evaluate_model(
114115 tfma_eval_result_path ,
115116 eval_config ,
116117):
117- """Evaluate Model using Tensorflow Model Analysis.
118-
119- Args:
120- classifier_model_path: Trained classifier model to be evaluted.
121- validate_tf_file_path: File containing validation TFRecordDataset .
122- tfma_eval_result_path: Path to export tfma-related eval path .
123- eval_config: tfma eval_config .
124- """
125-
126- eval_shared_model = tfma .default_eval_shared_model (
127- eval_saved_model_path = classifier_model_path , eval_config = eval_config
128- )
129-
130- # Run the fairness evaluation.
131- tfma .run_model_analysis (
132- eval_shared_model = eval_shared_model ,
133- data_location = validate_tf_file_path ,
134- output_path = tfma_eval_result_path ,
135- eval_config = eval_config ,
136- )
118+ """Evaluate Model using Tensorflow Model Analysis.
119+
120+ Args:
121+ ----
122+ classifier_model_path: Trained classifier model to be evaluted .
123+ validate_tf_file_path: File containing validation TFRecordDataset .
124+ tfma_eval_result_path: Path to export tfma-related eval path .
125+ eval_config: tfma eval_config.
126+ """
127+ eval_shared_model = tfma .default_eval_shared_model (
128+ eval_saved_model_path = classifier_model_path , eval_config = eval_config
129+ )
130+
131+ # Run the fairness evaluation.
132+ tfma .run_model_analysis (
133+ eval_shared_model = eval_shared_model ,
134+ data_location = validate_tf_file_path ,
135+ output_path = tfma_eval_result_path ,
136+ eval_config = eval_config ,
137+ )
0 commit comments