Skip to content

Commit 7875e9e

Browse files
zoyahavtfx-copybara
authored and committed
Enable store_frequency for compute_and_apply_vocabulary.
When file_format is 'text' then space characters in tokens are replaced in both the vocabulary and prior to lookup, as a workaround to allow TextFileInitializer to read the data properly. PiperOrigin-RevId: 523634481
1 parent 1e20b90 commit 7875e9e

File tree

5 files changed

+252
-133
lines changed

5 files changed

+252
-133
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
* `DatasetKey.non_cacheable` added to allow for some datasets to not produce
1414
cache. This may be useful for gradual cache generation when operating on a
1515
large rolling range of datasets.
16+
* Vocabularies produced by `compute_and_apply_vocabulary` can now store
17+
frequencies. Controlled by the `store_frequency` parameter.
1618

1719
## Bug Fixes and Other Changes
1820

tensorflow_transform/beam/vocabulary_integration_test.py

Lines changed: 105 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -31,100 +31,86 @@
3131
dict(
3232
testcase_name='sparse',
3333
input_data=[
34-
{
35-
'val': ['hello'],
36-
'idx0': [0],
37-
'idx1': [0]
38-
},
39-
{
40-
'val': ['world'],
41-
'idx0': [1],
42-
'idx1': [1]
43-
},
44-
{
45-
'val': ['hello', 'goodbye'],
46-
'idx0': [0, 1],
47-
'idx1': [1, 2]
48-
},
34+
{'val': ['hello'], 'idx0': [0], 'idx1': [0]},
35+
{'val': ['world'], 'idx0': [1], 'idx1': [1]},
36+
{'val': ['hello', 'goodbye'], 'idx0': [0, 1], 'idx1': [1, 2]},
4937
{
5038
'val': ['hello', 'goodbye', ' '],
5139
'idx0': [0, 1, 1],
52-
'idx1': [0, 1, 2]
40+
'idx1': [0, 1, 2],
5341
},
5442
],
55-
input_metadata=tft.DatasetMetadata.from_feature_spec({
56-
'x': tf.io.SparseFeature(['idx0', 'idx1'], 'val', tf.string, [2, 3])
57-
}),
58-
expected_data=[{
59-
'index$sparse_indices_0': [0],
60-
'index$sparse_indices_1': [0],
61-
'index$sparse_values': [0],
62-
}, {
63-
'index$sparse_indices_0': [1],
64-
'index$sparse_indices_1': [1],
65-
'index$sparse_values': [2],
66-
}, {
67-
'index$sparse_indices_0': [0, 1],
68-
'index$sparse_indices_1': [1, 2],
69-
'index$sparse_values': [0, 1],
70-
}, {
71-
'index$sparse_indices_0': [0, 1, 1],
72-
'index$sparse_indices_1': [0, 1, 2],
73-
'index$sparse_values': [0, 1, 3],
74-
}],
75-
expected_vocab_file_contents={
76-
'my_vocab': [b'hello', b'goodbye', b'world', b' ']
77-
}),
78-
dict(
79-
testcase_name='ragged',
80-
input_data=[
43+
input_metadata=tft.DatasetMetadata.from_feature_spec(
44+
{
45+
'x': tf.io.SparseFeature(
46+
['idx0', 'idx1'], 'val', tf.string, [2, 3]
47+
)
48+
}
49+
),
50+
expected_data=[
8151
{
82-
'val': ['hello', ' '],
83-
'row_lengths': [1, 0, 1]
52+
'index$sparse_indices_0': [0],
53+
'index$sparse_indices_1': [0],
54+
'index$sparse_values': [0],
8455
},
8556
{
86-
'val': ['world'],
87-
'row_lengths': [0, 1]
57+
'index$sparse_indices_0': [1],
58+
'index$sparse_indices_1': [1],
59+
'index$sparse_values': [2],
8860
},
8961
{
90-
'val': ['hello', 'goodbye'],
91-
'row_lengths': [2, 0, 0]
62+
'index$sparse_indices_0': [0, 1],
63+
'index$sparse_indices_1': [1, 2],
64+
'index$sparse_values': [0, 1],
9265
},
9366
{
94-
'val': ['hello', 'goodbye', ' '],
95-
'row_lengths': [0, 2, 1]
67+
'index$sparse_indices_0': [0, 1, 1],
68+
'index$sparse_indices_1': [0, 1, 2],
69+
'index$sparse_values': [0, 1, 3],
9670
},
9771
],
98-
input_metadata=tft.DatasetMetadata.from_feature_spec({
99-
'x':
100-
tf.io.RaggedFeature(
72+
expected_vocab_contents={
73+
b'hello': 3,
74+
b'goodbye': 2,
75+
b'world': 1,
76+
b' ': 1,
77+
},
78+
),
79+
dict(
80+
testcase_name='ragged',
81+
input_data=[
82+
{'val': ['hello', ' '], 'row_lengths': [1, 0, 1]},
83+
{'val': ['world'], 'row_lengths': [0, 1]},
84+
{'val': ['hello', 'goodbye'], 'row_lengths': [2, 0, 0]},
85+
{'val': ['hello', 'goodbye', ' '], 'row_lengths': [0, 2, 1]},
86+
],
87+
input_metadata=tft.DatasetMetadata.from_feature_spec(
88+
{
89+
'x': tf.io.RaggedFeature(
10190
tf.string,
10291
value_key='val',
10392
partitions=[
10493
tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error
105-
])
106-
}),
94+
],
95+
)
96+
}
97+
),
10798
expected_data=[
108-
{
109-
'index$ragged_values': [0, 2],
110-
'index$row_lengths_1': [1, 0, 1]
111-
},
112-
{
113-
'index$ragged_values': [3],
114-
'index$row_lengths_1': [0, 1]
115-
},
116-
{
117-
'index$ragged_values': [0, 1],
118-
'index$row_lengths_1': [2, 0, 0]
119-
},
99+
{'index$ragged_values': [0, 2], 'index$row_lengths_1': [1, 0, 1]},
100+
{'index$ragged_values': [3], 'index$row_lengths_1': [0, 1]},
101+
{'index$ragged_values': [0, 1], 'index$row_lengths_1': [2, 0, 0]},
120102
{
121103
'index$ragged_values': [0, 1, 2],
122-
'index$row_lengths_1': [0, 2, 1]
104+
'index$row_lengths_1': [0, 2, 1],
123105
},
124106
],
125-
expected_vocab_file_contents={
126-
'my_vocab': [b'hello', b'goodbye', b' ', b'world']
127-
}),
107+
expected_vocab_contents={
108+
b'hello': 3,
109+
b'goodbye': 2,
110+
b' ': 2,
111+
b'world': 1,
112+
},
113+
),
128114
]
129115

130116

@@ -733,7 +719,11 @@ def preprocessing_fn(inputs):
733719
'my_approximate_vocab': expected_vocab_file_contents
734720
})
735721

736-
def testComputeAndApplyApproximateVocabulary(self):
722+
@tft_unit.named_parameters([
723+
dict(testcase_name='no_frequency', store_frequency=False),
724+
dict(testcase_name='with_frequency', store_frequency=True),
725+
])
726+
def testComputeAndApplyApproximateVocabulary(self, store_frequency):
737727
input_data = [{'x': 'a'}] * 2 + [{'x': 'b'}] * 3
738728
input_metadata = tft.DatasetMetadata.from_feature_spec(
739729
{'x': tf.io.FixedLenFeature([], tf.string)})
@@ -743,7 +733,9 @@ def preprocessing_fn(inputs):
743733
inputs['x'],
744734
top_k=2,
745735
file_format=self._VocabFormat(),
746-
num_oov_buckets=1)
736+
store_frequency=store_frequency,
737+
num_oov_buckets=1,
738+
)
747739
return {'index': index}
748740

749741
expected_data = [{'index': 1}] * 2 + [{'index': 0}] * 3 + [{'index': 2}]
@@ -1355,19 +1347,49 @@ def preprocessing_fn(inputs):
13551347
expected_metadata,
13561348
expected_vocab_file_contents=expected_vocab_file_contents)
13571349

1358-
@tft_unit.named_parameters(*_COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES
1359-
)
1360-
def testCompositeComputeAndApplyVocabulary(self, input_data, input_metadata,
1361-
expected_data,
1362-
expected_vocab_file_contents):
1363-
1350+
@tft_unit.named_parameters(
1351+
*tft_unit.cross_named_parameters(
1352+
_COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES,
1353+
[
1354+
dict(testcase_name='no_frequency', store_frequency=False),
1355+
dict(testcase_name='with_frequency', store_frequency=True),
1356+
],
1357+
)
1358+
)
1359+
def testCompositeComputeAndApplyVocabulary(
1360+
self,
1361+
input_data,
1362+
input_metadata,
1363+
expected_data,
1364+
expected_vocab_contents,
1365+
store_frequency,
1366+
):
13641367
def preprocessing_fn(inputs):
13651368
index = tft.compute_and_apply_vocabulary(
13661369
inputs['x'],
13671370
file_format=self._VocabFormat(),
1368-
vocab_filename='my_vocab')
1371+
store_frequency=store_frequency,
1372+
vocab_filename='my_vocab',
1373+
)
13691374
return {'index': index}
13701375

1376+
if store_frequency:
1377+
def format_pair(t: bytes, c: int) -> str:
1378+
t = t.decode('utf-8')
1379+
if t != ' ' or self._VocabFormat() != 'text':
1380+
suffix = ' ' + t
1381+
else:
1382+
suffix = ' __SPACE__'
1383+
return f'{c}{suffix}'
1384+
contents = [
1385+
format_pair(t, c).encode('utf-8')
1386+
for t, c in expected_vocab_contents.items()
1387+
]
1388+
else:
1389+
contents = [t for t in expected_vocab_contents]
1390+
1391+
expected_vocab_file_contents = {'my_vocab': contents}
1392+
13711393
self.assertAnalyzeAndTransformResults(
13721394
input_data,
13731395
input_metadata,
@@ -1650,7 +1672,9 @@ def preprocessing_fn(inputs):
16501672
coverage_top_k=1,
16511673
key_fn=key_fn,
16521674
frequency_threshold=4,
1653-
file_format=self._VocabFormat())
1675+
store_frequency=True,
1676+
file_format=self._VocabFormat(),
1677+
)
16541678

16551679
# Return input unchanged, this preprocessing_fn is a no-op except for
16561680
# computing uniques.

tensorflow_transform/experimental/mappers.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ def compute_and_apply_approximate_vocabulary(
4545
num_oov_buckets: int = 0,
4646
vocab_filename: Optional[str] = None,
4747
weights: Optional[tf.Tensor] = None,
48-
file_format: common_types.VocabularyFileFormatType = analyzers
49-
.DEFAULT_VOCABULARY_FILE_FORMAT,
50-
name: Optional[str] = None) -> common_types.ConsistentTensorType:
48+
file_format: common_types.VocabularyFileFormatType = analyzers.DEFAULT_VOCABULARY_FILE_FORMAT,
49+
store_frequency: Optional[bool] = False,
50+
name: Optional[str] = None,
51+
) -> common_types.ConsistentTensorType:
5152
"""Generates an approximate vocabulary for `x` and maps it to an integer.
5253
5354
Args:
@@ -70,7 +71,12 @@ def compute_and_apply_approximate_vocabulary(
7071
same shape as x.
7172
file_format: (Optional) A str. The format of the resulting vocabulary file.
7273
Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires
73-
tensorflow>=2.4. The default value is 'text'.
74+
tensorflow>=2.4. The default value is 'text'.
75+
store_frequency: If True, frequency of the words is stored in the vocabulary
76+
file. In the case labels are provided, the mutual information is stored in
77+
the file instead. Each line in the file will be of the form 'frequency
78+
word'. NOTE: if True and file_format is 'text' then spaces will be
79+
replaced to avoid information loss.
7480
name: (Optional) A name for this operation.
7581
7682
Returns:
@@ -90,19 +96,27 @@ def compute_and_apply_approximate_vocabulary(
9096
"""
9197
with tf.compat.v1.name_scope(name,
9298
'compute_and_apply_approximate_vocabulary'):
99+
if store_frequency and file_format == 'text':
100+
x = tf_utils.maybe_format_vocabulary_input(x)
93101
deferred_vocab_and_filename = experimental_analyzers.approximate_vocabulary(
94102
x=x,
95103
top_k=top_k,
96104
vocab_filename=vocab_filename,
97105
weights=weights,
98106
file_format=file_format,
99-
name=name)
100-
return mappers.apply_vocabulary(
107+
store_frequency=store_frequency,
108+
name=name,
109+
)
110+
return mappers._apply_vocabulary_internal( # pylint: disable=protected-access
101111
x,
102112
deferred_vocab_and_filename,
103113
default_value,
104114
num_oov_buckets,
105-
file_format=file_format)
115+
lookup_fn=None,
116+
file_format=file_format,
117+
store_frequency=store_frequency,
118+
name=None,
119+
)
106120

107121

108122
@common.log_api_use(common.MAPPER_COLLECTION)

0 commit comments

Comments (0)