@@ -788,13 +788,14 @@ def preprocessing_fn(inputs):
788788 preprocessing_fn , expected_data ,
789789 expected_metadata )
790790
791- def testScaleUnitIntervalPerKey (self ):
791+ @tft_unit .parameters ((True ,), (False ,))
792+ def testScaleUnitIntervalPerKey (self , elementwise ):
792793
793794 def preprocessing_fn (inputs ):
794795 outputs = {}
795796 stacked_input = tf .stack ([inputs ['x' ], inputs ['y' ]], axis = 1 )
796797 result = tft .scale_to_0_1_per_key (
797- stacked_input , inputs ['key' ], elementwise = False )
798+ stacked_input , inputs ['key' ], elementwise )
798799 outputs ['x_scaled' ], outputs ['y_scaled' ] = tf .unstack (result , axis = 1 )
799800 return outputs
800801
@@ -828,25 +829,46 @@ def preprocessing_fn(inputs):
828829 'y' : tf .io .FixedLenFeature ([], tf .float32 ),
829830 'key' : tf .io .FixedLenFeature ([], tf .string )
830831 })
831- expected_data = [{
832- 'x_scaled' : 0.6 ,
833- 'y_scaled' : 0.8
834- }, {
835- 'x_scaled' : 0.0 ,
836- 'y_scaled' : 0.2
837- }, {
838- 'x_scaled' : 0.8 ,
839- 'y_scaled' : 1.0
840- }, {
841- 'x_scaled' : 0.2 ,
842- 'y_scaled' : 0.4
843- }, {
844- 'x_scaled' : 1.0 ,
845- 'y_scaled' : 0.0
846- }, {
847- 'x_scaled' : 0.6 ,
848- 'y_scaled' : 0.5
849- }]
832+ if elementwise :
833+ expected_data = [{
834+ 'x_scaled' : 0.75 ,
835+ 'y_scaled' : 0.75
836+ }, {
837+ 'x_scaled' : 0.0 ,
838+ 'y_scaled' : 0.0
839+ }, {
840+ 'x_scaled' : 1.0 ,
841+ 'y_scaled' : 1.0
842+ }, {
843+ 'x_scaled' : 0.25 ,
844+ 'y_scaled' : 0.25
845+ }, {
846+ 'x_scaled' : 1.0 ,
847+ 'y_scaled' : 0.0
848+ }, {
849+ 'x_scaled' : 0.0 ,
850+ 'y_scaled' : 1.0
851+ }]
852+ else :
853+ expected_data = [{
854+ 'x_scaled' : 0.6 ,
855+ 'y_scaled' : 0.8
856+ }, {
857+ 'x_scaled' : 0.0 ,
858+ 'y_scaled' : 0.2
859+ }, {
860+ 'x_scaled' : 0.8 ,
861+ 'y_scaled' : 1.0
862+ }, {
863+ 'x_scaled' : 0.2 ,
864+ 'y_scaled' : 0.4
865+ }, {
866+ 'x_scaled' : 1.0 ,
867+ 'y_scaled' : 0.0
868+ }, {
869+ 'x_scaled' : 0.6 ,
870+ 'y_scaled' : 0.5
871+ }]
850872 expected_metadata = tft .DatasetMetadata .from_feature_spec ({
851873 'x_scaled' : tf .io .FixedLenFeature ([], tf .float32 ),
852874 'y_scaled' : tf .io .FixedLenFeature ([], tf .float32 )
@@ -919,14 +941,24 @@ def preprocessing_fn(inputs):
919941 expected_metadata )
920942
921943 @tft_unit .named_parameters (
922- dict (testcase_name = '_empty_filename' ,
923- key_vocabulary_filename = '' ),
924- dict (testcase_name = '_nonempty_filename' ,
925- key_vocabulary_filename = 'per_key' ),
926- dict (testcase_name = '_none_filename' ,
927- key_vocabulary_filename = None )
928- )
929- def testScaleMinMaxPerKey (self , key_vocabulary_filename ):
944+ dict (
945+ testcase_name = '_empty_filename' ,
946+ elementwise = False ,
947+ key_vocabulary_filename = '' ),
948+ dict (
949+ testcase_name = '_nonempty_filename' ,
950+ elementwise = False ,
951+ key_vocabulary_filename = 'per_key' ),
952+ dict (
953+ testcase_name = '_none_filename' ,
954+ elementwise = False ,
955+ key_vocabulary_filename = None ),
956+ dict (
957+ testcase_name = '_elementwise_none_filename' ,
958+ elementwise = True ,
959+ key_vocabulary_filename = None ))
960+ def testScaleMinMaxPerKey (self , elementwise , key_vocabulary_filename ):
961+
930962 def preprocessing_fn (inputs ):
931963 outputs = {}
932964 stacked_input = tf .stack ([inputs ['x' ], inputs ['y' ]], axis = 1 )
@@ -935,7 +967,7 @@ def preprocessing_fn(inputs):
935967 inputs ['key' ],
936968 output_min = - 1 ,
937969 output_max = 1 ,
938- elementwise = False ,
970+ elementwise = elementwise ,
939971 key_vocabulary_filename = key_vocabulary_filename )
940972 outputs ['x_scaled' ], outputs ['y_scaled' ] = tf .unstack (result , axis = 1 )
941973 return outputs
@@ -970,37 +1002,61 @@ def preprocessing_fn(inputs):
9701002 'y' : tf .io .FixedLenFeature ([], tf .float32 ),
9711003 'key' : tf .io .FixedLenFeature ([], tf .string )
9721004 })
973-
974- expected_data = [{
975- 'x_scaled' : - 0.25 ,
976- 'y_scaled' : 0.75
977- }, {
978- 'x_scaled' : - 1.0 ,
979- 'y_scaled' : 0.0
980- }, {
981- 'x_scaled' : 0.0 ,
982- 'y_scaled' : 1.0
983- }, {
984- 'x_scaled' : - 0.75 ,
985- 'y_scaled' : 0.25
986- }, {
987- 'x_scaled' : - 1.0 ,
988- 'y_scaled' : 0.0
989- }, {
990- 'x_scaled' : 0.0 ,
991- 'y_scaled' : 1.0
992- }]
1005+ if elementwise :
1006+ expected_data = [{
1007+ 'x_scaled' : 0.5 ,
1008+ 'y_scaled' : 0.5
1009+ }, {
1010+ 'x_scaled' : - 1.0 ,
1011+ 'y_scaled' : - 1.0
1012+ }, {
1013+ 'x_scaled' : 1.0 ,
1014+ 'y_scaled' : 1.0
1015+ }, {
1016+ 'x_scaled' : - 0.5 ,
1017+ 'y_scaled' : - 0.5
1018+ }, {
1019+ 'x_scaled' : - 1.0 ,
1020+ 'y_scaled' : - 1.0
1021+ }, {
1022+ 'x_scaled' : 1.0 ,
1023+ 'y_scaled' : 1.0
1024+ }]
1025+ else :
1026+ expected_data = [{
1027+ 'x_scaled' : - 0.25 ,
1028+ 'y_scaled' : 0.75
1029+ }, {
1030+ 'x_scaled' : - 1.0 ,
1031+ 'y_scaled' : 0.0
1032+ }, {
1033+ 'x_scaled' : 0.0 ,
1034+ 'y_scaled' : 1.0
1035+ }, {
1036+ 'x_scaled' : - 0.75 ,
1037+ 'y_scaled' : 0.25
1038+ }, {
1039+ 'x_scaled' : - 1.0 ,
1040+ 'y_scaled' : 0.0
1041+ }, {
1042+ 'x_scaled' : 0.0 ,
1043+ 'y_scaled' : 1.0
1044+ }]
9931045 expected_metadata = tft .DatasetMetadata .from_feature_spec ({
9941046 'x_scaled' : tf .io .FixedLenFeature ([], tf .float32 ),
9951047 'y_scaled' : tf .io .FixedLenFeature ([], tf .float32 )
9961048 })
9971049 if key_vocabulary_filename :
998- per_key_vocab_contents = {key_vocabulary_filename :
999- [(b'a' , [- 1.0 , 9.0 ]), (b'b' , [2.0 , 2.0 ])]}
1050+ per_key_vocab_contents = {
1051+ key_vocabulary_filename : [(b'a' , [- 1.0 , 9.0 ]), (b'b' , [2.0 , 2.0 ])]
1052+ }
10001053 else :
10011054 per_key_vocab_contents = None
10021055 self .assertAnalyzeAndTransformResults (
1003- input_data , input_metadata , preprocessing_fn , expected_data ,
1056+ input_data ,
1057+ input_metadata ,
1058+ preprocessing_fn ,
1059+ expected_data ,
10041060 expected_metadata ,
10051061 expected_vocab_file_contents = per_key_vocab_contents )
10061062
0 commit comments