Skip to content

Commit 004e3d0

Browse files
zoyahavtfx-copybara
authored and committed
Fixing DatasetKey.__new__ pytype, allowing object parameter type in case of a "flattened dataset key".
PiperOrigin-RevId: 506424291
1 parent f1ea609 commit 004e3d0

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

tensorflow_transform/beam/analyzer_cache.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ class DatasetKey(tfx_namedtuple.namedtuple('DatasetKey', ['key'])):
5858
__slots__ = ()
5959
_FLATTENED_DATASET_KEY = object()
6060

61-
def __new__(cls, dataset_key):
61+
def __new__(cls, dataset_key: Union[str, object]) -> 'DatasetKey':
6262
if dataset_key is not DatasetKey._FLATTENED_DATASET_KEY:
63-
# TODO(b/267425539): Remove pytype directive.
64-
dataset_key = _make_valid_cache_component(dataset_key) # pytype: disable=wrong-arg-types # always-use-return-annotations
63+
if not isinstance(dataset_key, str):
64+
raise ValueError(
65+
f'User provided dataset_key must be a str. Got: {dataset_key}')
66+
dataset_key = _make_valid_cache_component(dataset_key)
6567
return super(DatasetKey, cls).__new__(cls, key=dataset_key)
6668

6769
def __str__(self):

tensorflow_transform/beam/analyzer_cache_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,17 @@ def write_to_file(value):
268268
analyzer_cache.DatasetKey('a'):
269269
analyzer_cache.DatasetCache({'b': [bytes([17, 19, 27, 31])]}, None)
270270
}
271+
dataset_keys = list(test_cache_dict.keys())
271272

272273
class LocalSource(beam.PTransform):
273274

274275
def __init__(self, path):
275276
del path
276277

277278
def expand(self, pbegin):
278-
return pbegin | beam.Create([test_cache_dict['a'].cache_dict['b']])
279+
return pbegin | beam.Create(
280+
[test_cache_dict[k].cache_dict['b'] for k in dataset_keys])
279281

280-
dataset_keys = list(test_cache_dict.keys())
281282
cache_dir = self.get_temp_dir()
282283
with beam.Pipeline() as p:
283284
_ = test_cache_dict | analyzer_cache.WriteAnalysisCacheToFS(
@@ -289,9 +290,10 @@ def expand(self, pbegin):
289290
self.assertCountEqual(read_cache.keys(), ['a'])
290291
self.assertCountEqual(read_cache['a'].cache_dict.keys(), ['b'])
291292

292-
beam_test_util.assert_that(
293-
read_cache['a'].cache_dict['b'],
294-
beam_test_util.equal_to([test_cache_dict['a'].cache_dict['b']]))
293+
for key in dataset_keys:
294+
beam_test_util.assert_that(
295+
read_cache[key].cache_dict['b'],
296+
beam_test_util.equal_to([test_cache_dict[key].cache_dict['b']]))
295297

296298

297299
if __name__ == '__main__':

0 commit comments

Comments
 (0)