Skip to content

Commit 1e20b90

Browse files
zoyahavtfx-copybara
authored and committed
Accounting for the empty vocabulary dummy token in VocabularyCount. This fixes an issue where get_vocabulary_size_by_name returns 0 for an empty vocabulary, while the actual vocabulary size in this case is 1.
PiperOrigin-RevId: 523356276
1 parent 135d77d commit 1e20b90

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

RELEASE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
* Depends on `numpy~=1.22.0`.
2020
* Depends on `tensorflow>=2.12.0,<2.13`.
2121
* Depends on `protobuf>=3.20.3,<5`.
22+
* Modifies `get_vocabulary_size_by_name` to return a minimum of 1.
2223

2324
## Breaking Changes
2425

tensorflow_transform/beam/analyzer_impls.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,18 @@ class _VocabularyCountImpl(beam.PTransform):
302302
def __init__(self, operation, extra_args):
303303
super().__init__()
304304

305+
def _format_count(self, count):
306+
# Count should be at least one because empty vocabularies get populated with
307+
# a single dummy value when written.
308+
# TODO(b/62272023) remove this workaround if/when fixed on tensorflow.
309+
return np.int64(np.maximum(count, 1))
310+
305311
def expand(self, inputs):
306312
pcoll, = inputs
307313

308314
return (pcoll
309315
| 'TotalVocabSize' >> beam.combiners.Count.Globally()
310-
| 'ToInt64' >> beam.Map(np.int64))
316+
| 'FormatCount' >> beam.Map(self._format_count))
311317

312318

313319
@common.register_ptransform(analyzer_nodes.VocabularyMerge)

tensorflow_transform/beam/annotators_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,48 @@ def preprocessing_fn(inputs):
216216
force_tf_compat_v1=use_tf_compat_v1,
217217
)
218218

219+
@tft_unit.named_parameters(
220+
dict(
221+
testcase_name='sanity',
222+
values=['hello', 'world', 'world'],
223+
expected_size=2,
224+
),
225+
dict(
226+
testcase_name='single_token',
227+
values=['hello', 'hello', 'hello'],
228+
expected_size=1,
229+
),
230+
dict(
231+
testcase_name='empty',
232+
values=['', '', ''],
233+
expected_size=1,
234+
),
235+
)
236+
def test_get_vocabulary_size_by_name(self, values, expected_size):
237+
vocab_filename = 'vocab'
238+
239+
def preprocessing_fn(inputs):
240+
tft.vocabulary(inputs['s'], vocab_filename=vocab_filename)
241+
size = tf.zeros_like(
242+
inputs['s'], dtype=tf.int64
243+
) + tft.experimental.get_vocabulary_size_by_name(vocab_filename)
244+
return {'size': size}
245+
246+
input_data_dicts = [dict(s=v) for v in values]
247+
input_metadata = tft.DatasetMetadata.from_feature_spec({
248+
's': tf.io.FixedLenFeature([], tf.string),
249+
})
250+
expected_data = [{
251+
'size': expected_size,
252+
}] * len(values)
253+
self.assertAnalyzeAndTransformResults(
254+
input_data_dicts,
255+
input_metadata,
256+
preprocessing_fn,
257+
force_tf_compat_v1=False,
258+
expected_data=expected_data,
259+
)
260+
219261

220262
if __name__ == '__main__':
221263
tft_unit.main()

0 commit comments

Comments (0)