1919results can be visualized using tools like TensorBoard.
2020"""
2121
22+ from typing import Any
23+
2224from fairness_indicators import fairness_indicators_metrics # pylint: disable=unused-import
2325from tensorflow import keras
2426import tensorflow .compat .v1 as tf
@@ -40,41 +42,58 @@ class ExampleParser(keras.layers.Layer):
4042
4143 def __init__ (self , input_feature_key ):
4244 self ._input_feature_key = input_feature_key
45+ self .input_spec = keras .layers .InputSpec (shape = (1 ,), dtype = tf .string )
4346 super ().__init__ ()
4447
48+ def compute_output_shape (self , input_shape : Any ):
49+ return [1 , 1 ]
50+
4551 def call (self , serialized_examples ):
4652 def get_feature (serialized_example ):
4753 parsed_example = tf .io .parse_single_example (
4854 serialized_example , features = FEATURE_MAP
4955 )
5056 return parsed_example [self ._input_feature_key ]
51-
57+ serialized_examples = tf . cast ( serialized_examples , tf . string )
5258 return tf .map_fn (get_feature , serialized_examples )
5359
5460
55- class ExampleModel (keras .Model ):
56- """A Example Keras NLP model ."""
61+ class Reshaper (keras .layers . Layer ):
62+ """A Keras layer that reshapes the input ."""
5763
58- def __init__ (self , input_feature_key ):
59- super ().__init__ ()
60- self .parser = ExampleParser (input_feature_key )
61- self .text_vectorization = keras .layers .TextVectorization (
62- max_tokens = 32 ,
63- output_mode = 'int' ,
64- output_sequence_length = 32 ,
65- )
66- self .text_vectorization .adapt (
67- ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
68- )
69- self .dense1 = keras .layers .Dense (32 , activation = 'relu' )
70- self .dense2 = keras .layers .Dense (1 )
71-
72- def call (self , inputs , training = True , mask = None ):
73- parsed_example = self .parser (inputs )
74- text_vector = self .text_vectorization (parsed_example )
75- output1 = self .dense1 (tf .cast (text_vector , tf .float32 ))
76- output2 = self .dense2 (output1 )
77- return output2
64+ def call (self , inputs ):
65+ return tf .reshape (inputs , (1 , 32 ))
66+
67+
68+ class Caster (keras .layers .Layer ):
69+ """A Keras layer that reshapes the input."""
70+
71+ def call (self , inputs ):
72+ return tf .cast (inputs , tf .float32 )
73+
74+
75+ def get_example_model (input_feature_key : str ):
76+ """Returns a Keras model for testing."""
77+ parser = ExampleParser (input_feature_key )
78+ text_vectorization = keras .layers .TextVectorization (
79+ max_tokens = 32 ,
80+ output_mode = 'int' ,
81+ output_sequence_length = 32 ,
82+ )
83+ text_vectorization .adapt (
84+ ['nontoxic' , 'toxic comment' , 'test comment' , 'abc' , 'abcdef' , 'random' ]
85+ )
86+ dense1 = keras .layers .Dense (32 , activation = 'relu' )
87+ dense2 = keras .layers .Dense (1 )
88+
89+ inputs = tf .keras .Input (shape = (), dtype = tf .string )
90+ parsed_example = parser (inputs )
91+ text_vector = text_vectorization (parsed_example )
92+ text_vector = Reshaper ()(text_vector )
93+ text_vector = Caster ()(text_vector )
94+ output1 = dense1 (text_vector )
95+ output2 = dense2 (output1 )
96+ return tf .keras .Model (inputs = inputs , outputs = output2 )
7897
7998
8099def evaluate_model (
0 commit comments