3131 dict (
3232 testcase_name = 'sparse' ,
3333 input_data = [
34- {
35- 'val' : ['hello' ],
36- 'idx0' : [0 ],
37- 'idx1' : [0 ]
38- },
39- {
40- 'val' : ['world' ],
41- 'idx0' : [1 ],
42- 'idx1' : [1 ]
43- },
44- {
45- 'val' : ['hello' , 'goodbye' ],
46- 'idx0' : [0 , 1 ],
47- 'idx1' : [1 , 2 ]
48- },
34+ {'val' : ['hello' ], 'idx0' : [0 ], 'idx1' : [0 ]},
35+ {'val' : ['world' ], 'idx0' : [1 ], 'idx1' : [1 ]},
36+ {'val' : ['hello' , 'goodbye' ], 'idx0' : [0 , 1 ], 'idx1' : [1 , 2 ]},
4937 {
5038 'val' : ['hello' , 'goodbye' , ' ' ],
5139 'idx0' : [0 , 1 , 1 ],
52- 'idx1' : [0 , 1 , 2 ]
40+ 'idx1' : [0 , 1 , 2 ],
5341 },
5442 ],
55- input_metadata = tft .DatasetMetadata .from_feature_spec ({
56- 'x' : tf .io .SparseFeature (['idx0' , 'idx1' ], 'val' , tf .string , [2 , 3 ])
57- }),
58- expected_data = [{
59- 'index$sparse_indices_0' : [0 ],
60- 'index$sparse_indices_1' : [0 ],
61- 'index$sparse_values' : [0 ],
62- }, {
63- 'index$sparse_indices_0' : [1 ],
64- 'index$sparse_indices_1' : [1 ],
65- 'index$sparse_values' : [2 ],
66- }, {
67- 'index$sparse_indices_0' : [0 , 1 ],
68- 'index$sparse_indices_1' : [1 , 2 ],
69- 'index$sparse_values' : [0 , 1 ],
70- }, {
71- 'index$sparse_indices_0' : [0 , 1 , 1 ],
72- 'index$sparse_indices_1' : [0 , 1 , 2 ],
73- 'index$sparse_values' : [0 , 1 , 3 ],
74- }],
75- expected_vocab_file_contents = {
76- 'my_vocab' : [b'hello' , b'goodbye' , b'world' , b' ' ]
77- }),
78- dict (
79- testcase_name = 'ragged' ,
80- input_data = [
43+ input_metadata = tft .DatasetMetadata .from_feature_spec (
44+ {
45+ 'x' : tf .io .SparseFeature (
46+ ['idx0' , 'idx1' ], 'val' , tf .string , [2 , 3 ]
47+ )
48+ }
49+ ),
50+ expected_data = [
8151 {
82- 'val' : ['hello' , ' ' ],
83- 'row_lengths' : [1 , 0 , 1 ]
52+ 'index$sparse_indices_0' : [0 ],
53+ 'index$sparse_indices_1' : [0 ],
54+ 'index$sparse_values' : [0 ],
8455 },
8556 {
86- 'val' : ['world' ],
87- 'row_lengths' : [0 , 1 ]
57+ 'index$sparse_indices_0' : [1 ],
58+ 'index$sparse_indices_1' : [1 ],
59+ 'index$sparse_values' : [2 ],
8860 },
8961 {
90- 'val' : ['hello' , 'goodbye' ],
91- 'row_lengths' : [2 , 0 , 0 ]
62+ 'index$sparse_indices_0' : [0 , 1 ],
63+ 'index$sparse_indices_1' : [1 , 2 ],
64+ 'index$sparse_values' : [0 , 1 ],
9265 },
9366 {
94- 'val' : ['hello' , 'goodbye' , ' ' ],
95- 'row_lengths' : [0 , 2 , 1 ]
67+ 'index$sparse_indices_0' : [0 , 1 , 1 ],
68+ 'index$sparse_indices_1' : [0 , 1 , 2 ],
69+ 'index$sparse_values' : [0 , 1 , 3 ],
9670 },
9771 ],
98- input_metadata = tft .DatasetMetadata .from_feature_spec ({
99- 'x' :
100- tf .io .RaggedFeature (
72+ expected_vocab_contents = {
73+ b'hello' : 3 ,
74+ b'goodbye' : 2 ,
75+ b'world' : 1 ,
76+ b' ' : 1 ,
77+ },
78+ ),
79+ dict (
80+ testcase_name = 'ragged' ,
81+ input_data = [
82+ {'val' : ['hello' , ' ' ], 'row_lengths' : [1 , 0 , 1 ]},
83+ {'val' : ['world' ], 'row_lengths' : [0 , 1 ]},
84+ {'val' : ['hello' , 'goodbye' ], 'row_lengths' : [2 , 0 , 0 ]},
85+ {'val' : ['hello' , 'goodbye' , ' ' ], 'row_lengths' : [0 , 2 , 1 ]},
86+ ],
87+ input_metadata = tft .DatasetMetadata .from_feature_spec (
88+ {
89+ 'x' : tf .io .RaggedFeature (
10190 tf .string ,
10291 value_key = 'val' ,
10392 partitions = [
10493 tf .io .RaggedFeature .RowLengths ('row_lengths' ) # pytype: disable=attribute-error
105- ])
106- }),
94+ ],
95+ )
96+ }
97+ ),
10798 expected_data = [
108- {
109- 'index$ragged_values' : [0 , 2 ],
110- 'index$row_lengths_1' : [1 , 0 , 1 ]
111- },
112- {
113- 'index$ragged_values' : [3 ],
114- 'index$row_lengths_1' : [0 , 1 ]
115- },
116- {
117- 'index$ragged_values' : [0 , 1 ],
118- 'index$row_lengths_1' : [2 , 0 , 0 ]
119- },
99+ {'index$ragged_values' : [0 , 2 ], 'index$row_lengths_1' : [1 , 0 , 1 ]},
100+ {'index$ragged_values' : [3 ], 'index$row_lengths_1' : [0 , 1 ]},
101+ {'index$ragged_values' : [0 , 1 ], 'index$row_lengths_1' : [2 , 0 , 0 ]},
120102 {
121103 'index$ragged_values' : [0 , 1 , 2 ],
122- 'index$row_lengths_1' : [0 , 2 , 1 ]
104+ 'index$row_lengths_1' : [0 , 2 , 1 ],
123105 },
124106 ],
125- expected_vocab_file_contents = {
126- 'my_vocab' : [b'hello' , b'goodbye' , b' ' , b'world' ]
127- }),
107+ expected_vocab_contents = {
108+ b'hello' : 3 ,
109+ b'goodbye' : 2 ,
110+ b' ' : 2 ,
111+ b'world' : 1 ,
112+ },
113+ ),
128114]
129115
130116
@@ -733,7 +719,11 @@ def preprocessing_fn(inputs):
733719 'my_approximate_vocab' : expected_vocab_file_contents
734720 })
735721
736- def testComputeAndApplyApproximateVocabulary (self ):
722+ @tft_unit .named_parameters ([
723+ dict (testcase_name = 'no_frequency' , store_frequency = False ),
724+ dict (testcase_name = 'with_frequency' , store_frequency = True ),
725+ ])
726+ def testComputeAndApplyApproximateVocabulary (self , store_frequency ):
737727 input_data = [{'x' : 'a' }] * 2 + [{'x' : 'b' }] * 3
738728 input_metadata = tft .DatasetMetadata .from_feature_spec (
739729 {'x' : tf .io .FixedLenFeature ([], tf .string )})
@@ -743,7 +733,9 @@ def preprocessing_fn(inputs):
743733 inputs ['x' ],
744734 top_k = 2 ,
745735 file_format = self ._VocabFormat (),
746- num_oov_buckets = 1 )
736+ store_frequency = store_frequency ,
737+ num_oov_buckets = 1 ,
738+ )
747739 return {'index' : index }
748740
749741 expected_data = [{'index' : 1 }] * 2 + [{'index' : 0 }] * 3 + [{'index' : 2 }]
@@ -1355,19 +1347,49 @@ def preprocessing_fn(inputs):
13551347 expected_metadata ,
13561348 expected_vocab_file_contents = expected_vocab_file_contents )
13571349
1358- @tft_unit .named_parameters (* _COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES
1359- )
1360- def testCompositeComputeAndApplyVocabulary (self , input_data , input_metadata ,
1361- expected_data ,
1362- expected_vocab_file_contents ):
1363-
1350+ @tft_unit .named_parameters (
1351+ * tft_unit .cross_named_parameters (
1352+ _COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES ,
1353+ [
1354+ dict (testcase_name = 'no_frequency' , store_frequency = False ),
1355+ dict (testcase_name = 'with_frequency' , store_frequency = True ),
1356+ ],
1357+ )
1358+ )
1359+ def testCompositeComputeAndApplyVocabulary (
1360+ self ,
1361+ input_data ,
1362+ input_metadata ,
1363+ expected_data ,
1364+ expected_vocab_contents ,
1365+ store_frequency ,
1366+ ):
# Preprocessing function under test: passes feature 'x' to
# tft.compute_and_apply_vocabulary and returns the result under key 'index'.
13641367 def preprocessing_fn (inputs ):
13651368 index = tft .compute_and_apply_vocabulary (
13661369 inputs ['x' ],
13671370 file_format = self ._VocabFormat (),
# Diff view: the '-' line below is the pre-change end of the call; the '+'
# lines replace it and additionally forward the enclosing test's
# `store_frequency` parameter.
1368- vocab_filename = 'my_vocab' )
1371+ store_frequency = store_frequency ,
1372+ vocab_filename = 'my_vocab' ,
1373+ )
13691374 return {'index' : index }
13701375
1376+ if store_frequency :
def format_pair(t: bytes, c: int) -> str:
  """Render one expected vocabulary line as '<count> <token>'.

  Args:
    t: UTF-8 encoded vocabulary token.
    c: Count expected to be written in front of the token.

  Returns:
    The count, a single space, and the decoded token. When the token is
    exactly one space and `self._VocabFormat()` is 'text', the token is
    written as '__SPACE__' instead.
  """
  # Decode into a fresh name so `t` keeps the `bytes` type it is annotated
  # with (rebinding it to `str` contradicts the signature).
  token = t.decode('utf-8')
  # The format is only consulted for a lone-space token, matching the
  # short-circuit order of the original condition.
  if token == ' ' and self._VocabFormat() == 'text':
    token = '__SPACE__'
  return f'{c} {token}'
1384+ contents = [
1385+ format_pair (t , c ).encode ('utf-8' )
1386+ for t , c in expected_vocab_contents .items ()
1387+ ]
1388+ else :
1389+ contents = [t for t in expected_vocab_contents ]
1390+
1391+ expected_vocab_file_contents = {'my_vocab' : contents }
1392+
13711393 self .assertAnalyzeAndTransformResults (
13721394 input_data ,
13731395 input_metadata ,
@@ -1650,7 +1672,9 @@ def preprocessing_fn(inputs):
16501672 coverage_top_k = 1 ,
16511673 key_fn = key_fn ,
16521674 frequency_threshold = 4 ,
1653- file_format = self ._VocabFormat ())
1675+ store_frequency = True ,
1676+ file_format = self ._VocabFormat (),
1677+ )
16541678
16551679 # Return input unchanged, this preprocessing_fn is a no-op except for
16561680 # computing uniques.
0 commit comments